chore: initial public snapshot for github upload
This commit is contained in:
2
llm-gateway-competitors/litellm-wheel-src/litellm/proxy/.gitignore
vendored
Normal file
2
llm-gateway-competitors/litellm-wheel-src/litellm/proxy/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
.env
|
||||
secrets.toml
|
||||
@@ -0,0 +1,44 @@
|
||||
# litellm-proxy
|
||||
|
||||
A local, fast, and lightweight **OpenAI-compatible server** to call 100+ LLM APIs.
|
||||
|
||||
## usage
|
||||
|
||||
```shell
|
||||
$ pip install litellm
|
||||
```
|
||||
```shell
|
||||
$ litellm --model ollama/codellama
|
||||
|
||||
#INFO: Ollama running on http://0.0.0.0:8000
|
||||
```
|
||||
|
||||
## replace openai base
|
||||
```python
|
||||
import openai # openai v1.0.0+
|
||||
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:8000") # set proxy to base_url
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
[**See how to call Huggingface, Bedrock, TogetherAI, Anthropic, etc.**](https://docs.litellm.ai/docs/simple_proxy)
|
||||
|
||||
|
||||
---
|
||||
|
||||
### Folder Structure
|
||||
|
||||
**Routes**
|
||||
- `proxy_server.py` - all openai-compatible routes - `/v1/chat/completion`, `/v1/embedding` + model info routes - `/v1/models`, `/v1/model/info`, `/v1/model_group_info`.
|
||||
- `health_endpoints/` - `/health`, `/health/liveliness`, `/health/readiness`
|
||||
- `management_endpoints/key_management_endpoints.py` - all `/key/*` routes
|
||||
- `management_endpoints/team_endpoints.py` - all `/team/*` routes
|
||||
- `management_endpoints/internal_user_endpoints.py` - all `/user/*` routes
|
||||
- `management_endpoints/ui_sso.py` - all `/sso/*` routes
|
||||
@@ -0,0 +1 @@
|
||||
from . import *
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
class MCPAuthenticatedUser(AuthenticatedUser):
    """
    Wrapper class to make LiteLLM's authentication and configuration compatible with MCP's AuthenticatedUser.

    This class handles:
    1. User API key authentication information
    2. MCP authentication header (deprecated)
    3. MCP server configuration (can include access groups)
    4. Server-specific authentication headers
    5. OAuth2 headers
    6. Raw headers - allows forwarding specific headers to the MCP server, specified by the admin.
    """

    def __init__(
        self,
        user_api_key_auth: UserAPIKeyAuth,
        mcp_auth_header: Optional[str] = None,
        mcp_servers: Optional[List[str]] = None,
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None,
        oauth2_headers: Optional[Dict[str, str]] = None,
        mcp_protocol_version: Optional[str] = None,
        raw_headers: Optional[Dict[str, str]] = None,
        client_ip: Optional[str] = None,
    ) -> None:
        # NOTE(review): AuthenticatedUser.__init__ is not invoked here — confirm
        # the base class requires no initialization of its own.
        # Resolved LiteLLM key/user/team authentication for this request.
        self.user_api_key_auth = user_api_key_auth
        # Single shared auth header (deprecated; see per-server headers below).
        self.mcp_auth_header = mcp_auth_header
        # Server identifiers (may include access groups) this user may reach.
        self.mcp_servers = mcp_servers
        # Per-server auth headers; defaults to {} so lookups never hit None.
        self.mcp_server_auth_headers = mcp_server_auth_headers or {}
        # MCP protocol version negotiated with the client, if any.
        self.mcp_protocol_version = mcp_protocol_version
        # OAuth2 headers to forward, if the request carried them.
        self.oauth2_headers = oauth2_headers
        # Admin-selected raw request headers to forward to the MCP server.
        self.raw_headers = raw_headers
        # Originating client IP, when known.
        self.client_ip = client_ip
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,789 @@
|
||||
"""
|
||||
BYOK (Bring Your Own Key) OAuth 2.1 Authorization Server endpoints for MCP servers.
|
||||
|
||||
When an MCP client connects to a BYOK-enabled server and no stored credential exists,
|
||||
LiteLLM runs a minimal OAuth 2.1 authorization code flow. The "authorization page" is
|
||||
just a form that asks the user for their API key — not a full identity-provider OAuth.
|
||||
|
||||
Endpoints implemented here:
|
||||
GET /.well-known/oauth-authorization-server — OAuth authorization server metadata
|
||||
GET /.well-known/oauth-protected-resource — OAuth protected resource metadata
|
||||
GET /v1/mcp/oauth/authorize — Shows HTML form to collect the API key
|
||||
POST /v1/mcp/oauth/authorize — Stores temp auth code and redirects
|
||||
POST /v1/mcp/oauth/token — Exchanges code for a bearer JWT token
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import html as _html_module
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Optional, cast
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._experimental.mcp_server.db import store_user_credential
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
get_request_base_url,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
# In-memory store for pending authorization codes.
# Each entry: {code: {api_key, server_id, code_challenge, redirect_uri, user_id, expires_at}}
# NOTE(review): this dict is per-process — codes issued by one worker are
# invisible to others; confirm the proxy runs the OAuth flow on a single worker.
# ---------------------------------------------------------------------------
_byok_auth_codes: Dict[str, dict] = {}

# Authorization codes expire after 5 minutes.
_AUTH_CODE_TTL_SECONDS = 300
# Hard cap to prevent memory exhaustion from incomplete OAuth flows.
_AUTH_CODES_MAX_SIZE = 1000

# All endpoints below are registered on this router under the "mcp" tag.
router = APIRouter(tags=["mcp"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PKCE helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
|
||||
"""Return True iff SHA-256(code_verifier) == code_challenge (base64url, no padding)."""
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return computed == code_challenge
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cleanup of expired auth codes (called lazily on each request)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _purge_expired_codes() -> None:
    """Drop every pending authorization code whose TTL has elapsed."""
    cutoff = time.time()
    # Snapshot the expired keys first; deleting while iterating the dict view
    # would raise RuntimeError.
    stale = [
        auth_code
        for auth_code, entry in _byok_auth_codes.items()
        if entry["expires_at"] < cutoff
    ]
    for auth_code in stale:
        _byok_auth_codes.pop(auth_code, None)
|
||||
|
||||
|
||||
def _build_authorize_html(
    server_name: str,
    server_initial: str,
    client_id: str,
    redirect_uri: str,
    code_challenge: str,
    code_challenge_method: str,
    state: str,
    server_id: str,
    access_items: list,
    help_url: str,
) -> str:
    """Build the 2-step BYOK OAuth authorization page HTML.

    Step 1 shows the "Connect <server>" consent screen; step 2 collects the
    user's API key and POSTs it (plus all OAuth parameters as hidden inputs)
    back to the same /v1/mcp/oauth/authorize path.

    Args:
        server_name: Display name of the MCP server being connected.
        server_initial: Single character rendered in the server logo tile.
        client_id: OAuth client_id (carries the LiteLLM user id).
        redirect_uri: Where the POST handler will redirect with the code.
        code_challenge: PKCE S256 challenge to echo back in a hidden input.
        code_challenge_method: PKCE method (expected "S256").
        state: Opaque client state echoed back on redirect.
        server_id: LiteLLM-internal id of the MCP server.
        access_items: Human-readable "requested access" checklist entries.
        help_url: Optional link explaining where to find the API key.

    Returns:
        A complete, self-contained HTML document as a string.
    """

    # Escape all user-supplied / externally-derived values before interpolation
    e = _html_module.escape
    server_name = e(server_name)
    server_initial = e(server_initial)
    client_id = e(client_id)
    redirect_uri = e(redirect_uri)
    code_challenge = e(code_challenge)
    code_challenge_method = e(code_challenge_method)
    state = e(state)
    server_id = e(server_id)

    # Build access checklist rows
    access_rows = "".join(
        f'<div class="access-item"><span class="check">✓</span>{e(item)}</div>'
        for item in access_items
    )
    access_section = ""
    if access_rows:
        access_section = f"""
      <div class="access-box">
        <div class="access-header">
          <span class="shield">▮</span>
          <span>Requested Access</span>
        </div>
        {access_rows}
      </div>"""

    # Help link for step 2
    help_link_html = ""
    if help_url:
        help_link_html = f'<a class="help-link" href="{e(help_url)}" target="_blank">Where do I find my API key? ↗</a>'

    # NOTE: literal CSS/JS braces are doubled ({{ }}) because this is an f-string.
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Connect {server_name} — LiteLLM</title>
<style>
  *, *::before, *::after {{ box-sizing: border-box; margin: 0; padding: 0; }}
  body {{
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    background: #0f172a;
    min-height: 100vh;
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 24px;
  }}
  .modal {{
    background: #ffffff;
    border-radius: 20px;
    padding: 36px 32px 32px;
    width: 440px;
    max-width: 100%;
    position: relative;
    box-shadow: 0 25px 60px rgba(0,0,0,0.35);
  }}
  /* Progress dots */
  .dots {{
    display: flex;
    justify-content: center;
    gap: 7px;
    margin-bottom: 28px;
  }}
  .dot {{
    width: 8px; height: 8px;
    border-radius: 50%;
    background: #e2e8f0;
  }}
  .dot.active {{ background: #38bdf8; }}
  /* Close button */
  .close-btn {{
    position: absolute;
    top: 16px; right: 16px;
    background: none; border: none;
    font-size: 16px; color: #94a3b8;
    cursor: pointer; line-height: 1;
    width: 28px; height: 28px;
    border-radius: 6px;
    display: flex; align-items: center; justify-content: center;
  }}
  .close-btn:hover {{ background: #f1f5f9; color: #475569; }}
  /* Logo pair */
  .logos {{
    display: flex; align-items: center; justify-content: center;
    gap: 12px; margin-bottom: 20px;
  }}
  .logo {{
    width: 52px; height: 52px;
    border-radius: 14px;
    display: flex; align-items: center; justify-content: center;
    font-size: 22px; font-weight: 800; color: white;
  }}
  .logo-img {{
    width: 52px; height: 52px;
    border-radius: 14px;
    object-fit: cover;
    border: 1.5px solid #e2e8f0;
  }}
  .logo-s {{ background: linear-gradient(135deg, #818cf8 0%, #4f46e5 100%); }}
  .logo-arrow {{ color: #cbd5e1; font-size: 20px; font-weight: 300; }}
  /* Headings */
  .step-title {{
    text-align: center;
    font-size: 21px; font-weight: 700;
    color: #0f172a; margin-bottom: 8px;
  }}
  .step-subtitle {{
    text-align: center;
    font-size: 14px; color: #64748b;
    line-height: 1.55; margin-bottom: 22px;
  }}
  /* Info box */
  .info-box {{
    background: #f8fafc;
    border-radius: 12px;
    padding: 14px 16px;
    display: flex; gap: 12px;
    margin-bottom: 14px;
  }}
  .info-icon {{ font-size: 17px; flex-shrink: 0; margin-top: 1px; color: #38bdf8; }}
  .info-box h4 {{ font-size: 13px; font-weight: 600; color: #1e293b; margin-bottom: 4px; }}
  .info-box p {{ font-size: 13px; color: #64748b; line-height: 1.5; }}
  /* Access checklist */
  .access-box {{
    background: #f8fafc;
    border-radius: 12px;
    padding: 14px 16px;
    margin-bottom: 22px;
  }}
  .access-header {{
    display: flex; align-items: center; gap: 8px;
    margin-bottom: 10px;
  }}
  .shield {{ color: #22c55e; font-size: 15px; }}
  .access-header > span:last-child {{
    font-size: 11px; font-weight: 700;
    letter-spacing: 0.07em;
    text-transform: uppercase;
    color: #475569;
  }}
  .access-item {{
    display: flex; align-items: center; gap: 9px;
    font-size: 13.5px; color: #374151;
    padding: 3px 0;
  }}
  .check {{ color: #22c55e; font-weight: 700; font-size: 13px; }}
  /* Primary CTA */
  .btn-primary {{
    width: 100%; padding: 15px;
    background: #0f172a; color: white;
    border: none; border-radius: 12px;
    font-size: 15px; font-weight: 600;
    cursor: pointer; margin-bottom: 10px;
  }}
  .btn-primary:hover {{ background: #1e293b; }}
  .btn-cancel {{
    width: 100%; padding: 8px;
    background: none; border: none;
    font-size: 13.5px; color: #94a3b8;
    cursor: pointer;
  }}
  .btn-cancel:hover {{ color: #64748b; }}
  /* Step 2 nav */
  .step2-nav {{
    display: flex; align-items: center;
    justify-content: space-between;
    margin-bottom: 24px;
  }}
  .back-btn {{
    background: none; border: none;
    font-size: 13.5px; color: #64748b;
    cursor: pointer; display: flex; align-items: center; gap: 4px;
  }}
  .back-btn:hover {{ color: #374151; }}
  /* Key icon */
  .key-icon-wrap {{
    width: 46px; height: 46px;
    background: #e0f2fe;
    border-radius: 12px;
    display: flex; align-items: center; justify-content: center;
    margin-bottom: 14px;
  }}
  .key-icon-wrap svg {{ width: 22px; height: 22px; color: #0284c7; }}
  /* Form elements */
  .field-label {{
    font-size: 13.5px; font-weight: 600;
    color: #1e293b; display: block;
    margin-bottom: 7px;
  }}
  .key-input {{
    width: 100%; padding: 11px 13px;
    border: 1.5px solid #e2e8f0;
    border-radius: 10px;
    font-size: 14px; color: #0f172a;
    outline: none; transition: border-color 0.15s, box-shadow 0.15s;
  }}
  .key-input:focus {{
    border-color: #38bdf8;
    box-shadow: 0 0 0 3px rgba(56,189,248,0.12);
  }}
  .help-link {{
    display: inline-flex; align-items: center; gap: 4px;
    color: #0ea5e9; font-size: 13px;
    text-decoration: none; margin: 8px 0 16px;
  }}
  .help-link:hover {{ text-decoration: underline; }}
  /* Save toggle card */
  .save-card {{
    border: 1.5px solid #e2e8f0;
    border-radius: 12px;
    padding: 13px 15px;
    margin-bottom: 6px;
  }}
  .save-row {{
    display: flex; align-items: center; gap: 10px;
  }}
  .save-icon {{ font-size: 16px; }}
  .save-label {{
    flex: 1;
    font-size: 14px; font-weight: 500; color: #1e293b;
  }}
  /* Toggle switch */
  .toggle {{ position: relative; width: 44px; height: 24px; flex-shrink: 0; }}
  .toggle input {{ opacity: 0; width: 0; height: 0; }}
  .slider {{
    position: absolute; inset: 0;
    background: #e2e8f0;
    border-radius: 24px; cursor: pointer;
    transition: background 0.18s;
  }}
  .slider::before {{
    content: '';
    position: absolute;
    width: 18px; height: 18px;
    left: 3px; bottom: 3px;
    background: white;
    border-radius: 50%;
    transition: transform 0.18s;
    box-shadow: 0 1px 3px rgba(0,0,0,0.18);
  }}
  input:checked + .slider {{ background: #38bdf8; }}
  input:checked + .slider::before {{ transform: translateX(20px); }}
  /* Duration pills */
  .duration-section {{ margin-top: 14px; }}
  .duration-label {{
    font-size: 12px; font-weight: 600;
    color: #64748b; margin-bottom: 8px;
    text-transform: uppercase; letter-spacing: 0.05em;
  }}
  .pills {{ display: flex; flex-wrap: wrap; gap: 7px; }}
  .pill {{
    padding: 6px 13px;
    border: 1.5px solid #e2e8f0;
    border-radius: 20px;
    font-size: 13px; color: #475569;
    cursor: pointer; background: white;
    transition: all 0.13s;
    user-select: none;
  }}
  .pill:hover {{ border-color: #94a3b8; }}
  .pill.sel {{
    border-color: #38bdf8;
    color: #0284c7;
    background: #e0f2fe;
  }}
  /* Security note */
  .sec-note {{
    background: #f8fafc;
    border-radius: 10px;
    padding: 11px 14px;
    display: flex; gap: 9px; align-items: flex-start;
    margin: 16px 0;
  }}
  .sec-icon {{ font-size: 13px; color: #94a3b8; margin-top: 1px; flex-shrink: 0; }}
  .sec-note p {{ font-size: 12.5px; color: #64748b; line-height: 1.5; }}
  /* Connect button */
  .btn-connect {{
    width: 100%; padding: 15px;
    border: none; border-radius: 12px;
    font-size: 15px; font-weight: 600;
    cursor: pointer;
    background: #bae6fd; color: #0369a1;
    transition: background 0.15s, color 0.15s;
  }}
  .btn-connect.ready {{
    background: #0ea5e9; color: white;
  }}
  .btn-connect.ready:hover {{ background: #0284c7; }}
  /* Step visibility */
  .step {{ display: none; }}
  .step.show {{ display: block; }}
</style>
</head>
<body>
  <div class="modal">

    <!-- ── STEP 1: Connect ─────────────────────────────────────── -->
    <div id="s1" class="step show">
      <div class="dots">
        <div class="dot active"></div>
        <div class="dot"></div>
      </div>
      <button class="close-btn" type="button" onclick="doCancel()" title="Close">×</button>

      <div class="logos">
        <img src="/ui/assets/logos/litellm_logo.jpg" class="logo-img" alt="LiteLLM">
        <span class="logo-arrow">→</span>
        <div class="logo logo-s">{server_initial}</div>
      </div>

      <h2 class="step-title">Connect {server_name} MCP</h2>
      <p class="step-subtitle">LiteLLM needs access to {server_name} to complete your request.</p>

      <div class="info-box">
        <span class="info-icon">
          <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="12" y1="8" x2="12" y2="12"/><line x1="12" y1="16" x2="12.01" y2="16"/></svg>
        </span>
        <div>
          <h4>How it works</h4>
          <p>LiteLLM acts as a secure bridge. Your requests are routed through our MCP client directly to {server_name}’s API.</p>
        </div>
      </div>

      {access_section}

      <button class="btn-primary" type="button" onclick="goStep2()">
        Continue to Authentication →
      </button>
      <button class="btn-cancel" type="button" onclick="doCancel()">Cancel</button>
    </div>

    <!-- ── STEP 2: Provide API Key ──────────────────────────────── -->
    <div id="s2" class="step">
      <div class="step2-nav">
        <button class="back-btn" type="button" onclick="goStep1()">← Back</button>
        <div class="dots">
          <div class="dot active"></div>
          <div class="dot active"></div>
        </div>
        <button class="close-btn" style="position:static;" type="button" onclick="doCancel()" title="Close">×</button>
      </div>

      <div class="key-icon-wrap">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#0284c7" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21 2l-2 2m-7.61 7.61a5.5 5.5 0 1 1-7.778 7.778 5.5 5.5 0 0 1 7.777-7.777zm0 0L15.5 7.5m0 0l3 3L22 7l-3-3m-3.5 3.5L19 4"/></svg>
      </div>
      <h2 class="step-title" style="text-align:left;">Provide API Key</h2>
      <p class="step-subtitle" style="text-align:left;">Enter your {server_name} API key to authorize this connection.</p>

      <form method="POST" id="authForm" onsubmit="prepareSubmit()">
        <input type="hidden" name="client_id" value="{client_id}">
        <input type="hidden" name="redirect_uri" value="{redirect_uri}">
        <input type="hidden" name="code_challenge" value="{code_challenge}">
        <input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
        <input type="hidden" name="state" value="{state}">
        <input type="hidden" name="server_id" value="{server_id}">
        <input type="hidden" name="duration" id="durInput" value="until_revoked">

        <label class="field-label">{server_name} API Key</label>
        <input
          type="password"
          name="api_key"
          id="apiKey"
          class="key-input"
          placeholder="Enter your API key"
          required
          autofocus
          oninput="syncBtn()"
        >

        {help_link_html}

        <div class="save-card">
          <div class="save-row">
            <span class="save-label">Save key for future use</span>
            <label class="toggle">
              <input type="checkbox" id="saveToggle" onchange="toggleDur()">
              <span class="slider"></span>
            </label>
          </div>
          <div id="durSection" class="duration-section" style="display:none;">
            <div class="duration-label">Duration</div>
            <div class="pills">
              <div class="pill" onclick="selDur('1h',this)">1 hour</div>
              <div class="pill sel" onclick="selDur('24h',this)">24 hours</div>
              <div class="pill" onclick="selDur('7d',this)">7 days</div>
              <div class="pill" onclick="selDur('30d',this)">30 days</div>
              <div class="pill" onclick="selDur('until_revoked',this)">Until I revoke</div>
            </div>
          </div>
        </div>

        <div class="sec-note">
          <span class="sec-icon">
            <svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="11" width="18" height="11" rx="2" ry="2"/><path d="M7 11V7a5 5 0 0 1 10 0v4"/></svg>
          </span>
          <p>Your key is stored securely and transmitted over HTTPS. It is never shared with third parties.</p>
        </div>

        <button type="submit" class="btn-connect" id="connectBtn">
          Connect & Authorize
        </button>
      </form>
    </div>

  </div>
<script>
  function goStep2() {{
    document.getElementById('s1').classList.remove('show');
    document.getElementById('s2').classList.add('show');
  }}
  function goStep1() {{
    document.getElementById('s2').classList.remove('show');
    document.getElementById('s1').classList.add('show');
  }}
  function doCancel() {{
    if (window.opener) window.close();
    else window.history.back();
  }}
  function toggleDur() {{
    const on = document.getElementById('saveToggle').checked;
    document.getElementById('durSection').style.display = on ? 'block' : 'none';
  }}
  function selDur(val, el) {{
    document.querySelectorAll('.pill').forEach(p => p.classList.remove('sel'));
    el.classList.add('sel');
    document.getElementById('durInput').value = val;
  }}
  function syncBtn() {{
    const btn = document.getElementById('connectBtn');
    if (document.getElementById('apiKey').value.length > 0) {{
      btn.classList.add('ready');
    }} else {{
      btn.classList.remove('ready');
    }}
  }}
  function prepareSubmit() {{
    // nothing extra needed — duration is already in the hidden input
  }}
</script>
</body>
</html>"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth metadata discovery endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/.well-known/oauth-authorization-server", include_in_schema=False)
async def oauth_authorization_server_metadata(request: Request) -> JSONResponse:
    """RFC 8414 Authorization Server Metadata for the BYOK OAuth flow."""
    issuer = get_request_base_url(request)
    # Only the minimal metadata set the BYOK flow actually supports:
    # authorization-code grant with PKCE (S256).
    metadata = {
        "issuer": issuer,
        "authorization_endpoint": issuer + "/v1/mcp/oauth/authorize",
        "token_endpoint": issuer + "/v1/mcp/oauth/token",
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code"],
        "code_challenge_methods_supported": ["S256"],
    }
    return JSONResponse(metadata)
|
||||
|
||||
|
||||
@router.get("/.well-known/oauth-protected-resource", include_in_schema=False)
async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
    """RFC 9728 Protected Resource Metadata pointing back at this server."""
    resource_url = get_request_base_url(request)
    # This proxy is both the protected resource and its own authorization server.
    payload = {
        "resource": resource_url,
        "authorization_servers": [resource_url],
    }
    return JSONResponse(payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization endpoint — GET (show form) and POST (process form)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/v1/mcp/oauth/authorize", include_in_schema=False)
async def byok_authorize_get(
    request: Request,
    client_id: Optional[str] = None,
    redirect_uri: Optional[str] = None,
    response_type: Optional[str] = None,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    state: Optional[str] = None,
    server_id: Optional[str] = None,
) -> HTMLResponse:
    """
    Render the BYOK API-key entry form.

    The MCP client sends the user here; submitting the form POSTs the key and
    the echoed OAuth parameters back to this same path.
    """
    # Guard clauses — same order as the OAuth parameter requirements.
    if response_type != "code":
        raise HTTPException(status_code=400, detail="response_type must be 'code'")
    if not redirect_uri:
        raise HTTPException(status_code=400, detail="redirect_uri is required")
    if not code_challenge:
        raise HTTPException(status_code=400, detail="code_challenge is required")

    # Defaults used when the server cannot be resolved from the registry.
    display_name = "MCP Server"
    checklist: list = []
    key_help_url = ""
    if server_id:
        try:
            from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
                global_mcp_server_manager,
            )

            registry = global_mcp_server_manager.get_registry()
            if server_id in registry:
                server_config = registry[server_id]
                display_name = server_config.server_name or server_config.name
                checklist = list(server_config.byok_description or [])
                key_help_url = server_config.byok_api_key_help_url or ""
        except Exception:
            # Best-effort lookup: fall back to generic branding on any failure.
            pass

    initial = display_name[0].upper() if display_name else "S"

    return HTMLResponse(
        content=_build_authorize_html(
            server_name=display_name,
            server_initial=initial,
            client_id=client_id or "",
            redirect_uri=redirect_uri,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method or "S256",
            state=state or "",
            server_id=server_id or "",
            access_items=checklist,
            help_url=key_help_url,
        )
    )
|
||||
|
||||
|
||||
@router.post("/v1/mcp/oauth/authorize", include_in_schema=False)
async def byok_authorize_post(
    request: Request,
    client_id: str = Form(default=""),
    redirect_uri: str = Form(...),
    code_challenge: str = Form(...),
    code_challenge_method: str = Form(default="S256"),
    state: str = Form(default=""),
    server_id: str = Form(default=""),
    api_key: str = Form(...),
) -> RedirectResponse:
    """
    Handle the BYOK API-key form submission.

    Registers a short-lived one-time authorization code, then 302-redirects
    the client to redirect_uri with ?code=...&state=... appended.
    """
    _purge_expired_codes()

    # Open-redirect guard: only plain web schemes may receive the code.
    if urlparse(redirect_uri).scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="Invalid redirect_uri scheme")

    # Capacity guard: a burst of abandoned flows must not grow the in-memory
    # store without bound.
    if len(_byok_auth_codes) >= _AUTH_CODES_MAX_SIZE:
        raise HTTPException(
            status_code=503, detail="Too many pending authorization flows"
        )

    # PKCE: this server only implements the S256 transform.
    if code_challenge_method != "S256":
        raise HTTPException(
            status_code=400, detail="Only S256 code_challenge_method is supported"
        )

    issued_code = str(uuid.uuid4())
    _byok_auth_codes[issued_code] = {
        "api_key": api_key,
        "server_id": server_id,
        "code_challenge": code_challenge,
        "redirect_uri": redirect_uri,
        "user_id": client_id,  # external client passes LiteLLM user-id as client_id
        "expires_at": time.time() + _AUTH_CODE_TTL_SECONDS,
    }

    query = urlencode({"code": issued_code, "state": state})
    joiner = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(url=f"{redirect_uri}{joiner}{query}", status_code=302)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/v1/mcp/oauth/token", include_in_schema=False)
async def byok_token(
    request: Request,
    grant_type: str = Form(...),
    code: str = Form(...),
    redirect_uri: str = Form(default=""),
    code_verifier: str = Form(...),
    client_id: str = Form(default=""),
) -> JSONResponse:
    """
    Exchange an authorization code for a short-lived BYOK session JWT.

    1. Validates the authorization code and PKCE challenge.
    2. Stores the API key via store_user_credential().
    3. Issues a signed JWT with type="byok_session".

    All code-validation failures return the same generic "invalid_grant"
    detail so callers cannot distinguish unknown/expired/PKCE-failed codes.

    NOTE(review): redirect_uri is accepted but never compared against the one
    bound to the code at issuance — RFC 6749 §4.1.3 requires that check;
    confirm whether existing clients omit redirect_uri before enforcing it.
    """
    # Late import: proxy_server globals are populated at startup, not import time.
    from litellm.proxy.proxy_server import master_key, prisma_client

    _purge_expired_codes()

    if grant_type != "authorization_code":
        raise HTTPException(status_code=400, detail="unsupported_grant_type")

    record = _byok_auth_codes.get(code)
    if record is None:
        raise HTTPException(status_code=400, detail="invalid_grant")

    if time.time() > record["expires_at"]:
        del _byok_auth_codes[code]
        raise HTTPException(status_code=400, detail="invalid_grant")

    # PKCE verification
    if not _verify_pkce(code_verifier, record["code_challenge"]):
        raise HTTPException(status_code=400, detail="invalid_grant")

    # Consume the code (one-time use)
    del _byok_auth_codes[code]

    server_id: str = record["server_id"]
    api_key_value: str = record["api_key"]
    # Prefer the user_id that was stored when the code was issued; fall back to
    # whatever client_id the token request supplies (they should match).
    user_id: str = record.get("user_id") or client_id

    if not user_id:
        raise HTTPException(
            status_code=400,
            detail="Cannot determine user_id; pass LiteLLM user id as client_id",
        )

    # Persist the BYOK credential
    if prisma_client is not None:
        try:
            await store_user_credential(
                prisma_client=prisma_client,
                user_id=user_id,
                server_id=server_id,
                credential=api_key_value,
            )
            # Invalidate any cached negative result so the user isn't blocked
            # for up to the TTL period after completing the OAuth flow.
            from litellm.proxy._experimental.mcp_server.server import (
                _invalidate_byok_cred_cache,
            )

            _invalidate_byok_cred_cache(user_id, server_id)
        except Exception as exc:
            verbose_proxy_logger.error(
                "byok_token: failed to store user credential for user=%s server=%s: %s",
                user_id,
                server_id,
                exc,
            )
            raise HTTPException(status_code=500, detail="Failed to store credential")
    else:
        # Best-effort degradation: without a DB the flow still issues a token,
        # but the key will have to be re-entered next time.
        verbose_proxy_logger.warning(
            "byok_token: prisma_client is None — credential not persisted"
        )

    if master_key is None:
        raise HTTPException(
            status_code=500, detail="Master key not configured; cannot issue token"
        )

    now = int(time.time())
    payload = {
        "user_id": user_id,
        "server_id": server_id,
        # "type" distinguishes this from regular proxy auth tokens.
        # The proxy's SSO JWT path uses asymmetric keys (RS256/ES256), so an
        # HS256 token signed with master_key cannot be accepted there.
        "type": "byok_session",
        "iat": now,
        "exp": now + 3600,
    }
    # Token lifetime (3600s) intentionally matches "expires_in" below.
    access_token = jwt.encode(payload, cast(str, master_key), algorithm="HS256")

    return JSONResponse(
        {
            "access_token": access_token,
            "token_type": "bearer",
            "expires_in": 3600,
        }
    )
|
||||
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Cost calculator for MCP tools.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from litellm.types.mcp import MCPServerCostInfo
|
||||
from litellm.types.utils import StandardLoggingMCPToolCall
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
Logging as LitellmLoggingObject,
|
||||
)
|
||||
else:
|
||||
LitellmLoggingObject = Any
|
||||
|
||||
|
||||
class MCPCostCalculator:
    @staticmethod
    def calculate_mcp_tool_call_cost(
        litellm_logging_obj: Optional[LitellmLoggingObject],
    ) -> float:
        """
        Calculate the cost of an MCP tool call.

        Default is 0.0, unless the user configured a custom per-request cost
        for MCP tools (per-tool or per-server default).
        """
        if litellm_logging_obj is None:
            return 0.0

        call_details = litellm_logging_obj.model_call_details

        # A post_mcp_tool_call_hook may have set an explicit response cost on
        # the logging object; that value always takes precedence.
        hook_cost = call_details.get("response_cost", None)
        if hook_cost is not None:
            return hook_cost

        # Unpack the tool-call metadata and the per-server cost configuration.
        tool_metadata = (
            cast(
                StandardLoggingMCPToolCall,
                call_details.get("mcp_tool_call_metadata", {}),
            )
            or {}
        )
        cost_info: MCPServerCostInfo = (
            tool_metadata.get("mcp_server_cost_info") or MCPServerCostInfo()
        )

        per_tool_costs: dict = (
            cost_info.get("tool_name_to_cost_per_query", {}) or {}
        )
        fallback_cost = cost_info.get("default_cost_per_query", None)
        invoked_tool = tool_metadata.get("name", "")

        # Resolution order:
        #   1. per-tool override  2. server-wide default  3. 0.0
        if invoked_tool in per_tool_costs:
            return per_tool_costs[invoked_tool]
        if fallback_cost is not None:
            return fallback_cost
        return 0.0
|
||||
@@ -0,0 +1,767 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_MCPServerTable,
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_TeamTable,
|
||||
MCPApprovalStatus,
|
||||
MCPSubmissionsSummary,
|
||||
NewMCPServerRequest,
|
||||
SpecialMCPServerName,
|
||||
UpdateMCPServerRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
_get_salt_key,
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.mcp import MCPCredentials
|
||||
|
||||
|
||||
def _prepare_mcp_server_data(
    data: Union[NewMCPServerRequest, UpdateMCPServerRequest],
) -> Dict[str, Any]:
    """
    Helper function to prepare MCP server data for database operations.
    Handles JSON field serialization for mcp_info and env fields.

    Args:
        data: NewMCPServerRequest or UpdateMCPServerRequest object

    Returns:
        Dict with properly serialized JSON fields, ready to pass to the
        prisma create/update call.
    """
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    # Convert model to dict. exclude_none=True drops unset optional fields so
    # partial updates do not overwrite DB columns with NULLs.
    data_dict = data.model_dump(exclude_none=True)
    # Ensure alias is always present in the dict (even if None)
    if "alias" not in data_dict:
        data_dict["alias"] = getattr(data, "alias", None)

    # Handle credentials serialization: encrypt the secret fields with the
    # salt key first, then JSON-serialize the whole dict for storage.
    credentials = data_dict.get("credentials")
    if credentials is not None:
        data_dict["credentials"] = encrypt_credentials(
            credentials=credentials, encryption_key=_get_salt_key()
        )
        data_dict["credentials"] = safe_dumps(data_dict["credentials"])

    # Handle static_headers serialization
    if data.static_headers is not None:
        data_dict["static_headers"] = safe_dumps(data.static_headers)

    # Handle mcp_info serialization
    if data.mcp_info is not None:
        data_dict["mcp_info"] = safe_dumps(data.mcp_info)

    # Handle env serialization
    if data.env is not None:
        data_dict["env"] = safe_dumps(data.env)

    # Handle tool name override serialization
    if data.tool_name_to_display_name is not None:
        data_dict["tool_name_to_display_name"] = safe_dumps(
            data.tool_name_to_display_name
        )
    if data.tool_name_to_description is not None:
        data_dict["tool_name_to_description"] = safe_dumps(
            data.tool_name_to_description
        )

    # mcp_access_groups is already List[str], no serialization needed

    # Force include is_byok even when False (exclude_none=True would not drop it,
    # but be explicit to ensure a False value is always written to the DB).
    data_dict["is_byok"] = getattr(data, "is_byok", False)

    return data_dict
|
||||
|
||||
|
||||
def encrypt_credentials(
    credentials: MCPCredentials, encryption_key: Optional[str]
) -> MCPCredentials:
    """Encrypt every secret field in ``credentials`` in place and return it.

    Non-secret fields (aws_region_name, aws_service_name) are stored as-is.
    Fields that are absent or None are skipped.
    """
    secret_fields = (
        "auth_value",
        "client_id",
        "client_secret",
        # AWS SigV4 credential fields
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    )
    for secret_field in secret_fields:
        plain_value = credentials.get(secret_field)  # type: ignore[literal-required]
        if plain_value is None:
            continue
        credentials[secret_field] = encrypt_value_helper(  # type: ignore[literal-required]
            value=plain_value,
            new_encryption_key=encryption_key,
        )
    return credentials
|
||||
|
||||
|
||||
def decrypt_credentials(
    credentials: MCPCredentials,
) -> MCPCredentials:
    """Decrypt all secret fields in an MCPCredentials dict using the global salt key."""
    for secret_field in (
        "auth_value",
        "client_id",
        "client_secret",
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    ):
        stored = credentials.get(secret_field)  # type: ignore[literal-required]
        # Only string values can be ciphertext; skip missing/None fields.
        if not isinstance(stored, str):
            continue
        credentials[secret_field] = decrypt_value_helper(  # type: ignore[literal-required]
            value=stored,
            key=secret_field,
            exception_type="debug",
            return_original_value=True,
        )
    return credentials
|
||||
|
||||
|
||||
async def get_all_mcp_servers(
    prisma_client: PrismaClient,
    approval_status: Optional[str] = None,
) -> List[LiteLLM_MCPServerTable]:
    """
    Return MCP servers from the DB, optionally filtered by approval_status.

    Pass approval_status=None to return every server regardless of approval
    state. Returns [] (and logs at debug) when the query fails.
    """
    try:
        filters: Dict[str, Any] = (
            {"approval_status": approval_status}
            if approval_status is not None
            else {}
        )
        rows = await prisma_client.db.litellm_mcpservertable.find_many(
            where=filters
        )
        return [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]
    except Exception as e:
        verbose_proxy_logger.debug(
            "litellm.proxy._experimental.mcp_server.db.py::get_all_mcp_servers - {}".format(
                str(e)
            )
        )
        return []
|
||||
|
||||
|
||||
async def get_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """Fetch a single MCP server row by server_id; None when it does not exist."""
    row: Optional[LiteLLM_MCPServerTable] = (
        await prisma_client.db.litellm_mcpservertable.find_unique(
            where={"server_id": server_id}
        )
    )
    return row
|
||||
|
||||
|
||||
async def get_mcp_servers(
    prisma_client: PrismaClient, server_ids: Iterable[str]
) -> List[LiteLLM_MCPServerTable]:
    """
    Return the MCP server rows matching ``server_ids``.

    Args:
        prisma_client: active Prisma client.
        server_ids: any iterable of server ids (list, set, or generator).

    Returns:
        Parsed ``LiteLLM_MCPServerTable`` objects; ids with no matching row
        are silently omitted.
    """
    # Materialize to a list: Prisma's ``in`` filter needs a JSON-serializable
    # sequence, and callers may pass a set or a one-shot generator.
    id_list = list(server_ids)
    if not id_list:
        # Nothing to match — skip the DB round-trip entirely.
        return []
    rows = await prisma_client.db.litellm_mcpservertable.find_many(
        where={
            "server_id": {"in": id_list},
        }
    )
    return [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]
|
||||
|
||||
|
||||
async def get_mcp_servers_by_verificationtoken(
    prisma_client: PrismaClient, token: str
) -> List[str]:
    """
    Return the MCP server ids attached to a virtual key (verification token).

    Looks up the token row together with its object_permission relation and
    returns ``object_permission.mcp_servers``. Returns [] when the token does
    not exist or carries no object permission.
    """
    # NOTE: this is a verification-token row (with object_permission
    # included), not a team row — the previous ``LiteLLM_TeamTable``
    # annotation was incorrect, so no annotation is applied here.
    verification_token_record = (
        await prisma_client.db.litellm_verificationtoken.find_unique(
            where={
                "token": token,
            },
            include={
                "object_permission": True,
            },
        )
    )

    if (
        verification_token_record is not None
        and verification_token_record.object_permission is not None
    ):
        return verification_token_record.object_permission.mcp_servers or []
    return []
|
||||
|
||||
|
||||
async def get_mcp_servers_by_team(
    prisma_client: PrismaClient, team_id: str
) -> List[str]:
    """Return the MCP server ids a team's object permission grants; [] if none."""
    team_record: LiteLLM_TeamTable = (
        await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id},
            include={"object_permission": True},
        )
    )
    # Team may not exist, or may have no object-permission record attached.
    if team_record is None or team_record.object_permission is None:
        return []
    return team_record.object_permission.mcp_servers or []
|
||||
|
||||
|
||||
async def get_all_mcp_servers_for_user(
    prisma_client: PrismaClient,
    user: UserAPIKeyAuth,
) -> List[LiteLLM_MCPServerTable]:
    """
    Get all the MCP servers the given user has access to.

    Following the least-privilege principle, the requestor only sees servers
    granted to their virtual key — expanded with team-level servers when the
    key carries the special "all team servers" marker and a team is set.
    """
    accessible_ids: Set[str] = set()

    # Servers granted directly to the virtual key
    if user.api_key:
        accessible_ids.update(
            await get_mcp_servers_by_verificationtoken(prisma_client, user.api_key)
        )

    # Special marker: expand with every server the user's team can access
    if (
        SpecialMCPServerName.all_team_servers in accessible_ids
        and user.team_id is not None
    ):
        accessible_ids.update(
            await get_mcp_servers_by_team(prisma_client, user.team_id)
        )

    if not accessible_ids:
        return []
    return await get_mcp_servers(prisma_client, accessible_ids)
|
||||
|
||||
|
||||
async def get_objectpermissions_for_mcp_server(
    prisma_client: PrismaClient, mcp_server_id: str
) -> List[LiteLLM_ObjectPermissionTable]:
    """
    Return every object-permission record that grants access to the MCP
    server, with the associated team and verification-token records included.
    """
    return await prisma_client.db.litellm_objectpermissiontable.find_many(
        where={
            "mcp_servers": {"has": mcp_server_id},
        },
        include={
            "teams": True,
            "verification_tokens": True,
        },
    )
|
||||
|
||||
|
||||
async def get_virtualkeys_for_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> List:
    """Return every virtual key whose permissions include this MCP server."""
    keys = await prisma_client.db.litellm_verificationtoken.find_many(
        where={
            "mcp_servers": {"has": server_id},
        },
    )
    # Normalize a missing result to an empty list for callers.
    return keys if keys is not None else []
|
||||
|
||||
|
||||
async def delete_mcp_server_from_team(prisma_client: PrismaClient, server_id: str):
    """
    Remove the mcp server from the team

    TODO: not yet implemented — currently a no-op placeholder.
    """
    pass
|
||||
|
||||
|
||||
async def delete_mcp_server_from_virtualkey():
    """
    Remove the mcp server from the virtual key

    TODO: not yet implemented — currently a no-op placeholder.
    """
    pass
|
||||
|
||||
|
||||
async def delete_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """
    Delete the MCP server row identified by ``server_id``.

    Returns the deleted record if it existed, otherwise None.
    """
    removed = await prisma_client.db.litellm_mcpservertable.delete(
        where={"server_id": server_id},
    )
    return removed
|
||||
|
||||
|
||||
async def create_mcp_server(
    prisma_client: PrismaClient, data: NewMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Insert a new MCP server record.

    Auto-generates a server_id when the request does not provide one,
    serializes JSON fields via _prepare_mcp_server_data, and stamps the
    created_by/updated_by audit columns with ``touched_by``.
    """
    if data.server_id is None:
        data.server_id = str(uuid.uuid4())

    # Prepare data with proper JSON serialization for the DB
    record = _prepare_mcp_server_data(data)

    # Audit fields
    record["created_by"] = touched_by
    record["updated_by"] = touched_by

    return await prisma_client.db.litellm_mcpservertable.create(
        data=record  # type: ignore
    )
|
||||
|
||||
|
||||
async def update_mcp_server(
    prisma_client: PrismaClient, data: UpdateMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Update an existing MCP server record in the DB.

    Handles two credential subtleties:
    - clears stored credentials when the auth_type changes and no new
      credentials are supplied (avoids stale secrets for the new auth type);
    - merges partially-supplied credentials with the existing ones when the
      auth_type is unchanged, so secrets the UI cannot echo back survive.
    Stamps ``updated_by`` with ``touched_by``.
    """
    import json

    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    # Use helper to prepare data with proper JSON serialization
    data_dict = _prepare_mcp_server_data(data)

    # Pre-fetch existing record once if we need it for auth_type or credential logic
    existing = None
    has_credentials = (
        "credentials" in data_dict and data_dict["credentials"] is not None
    )
    if data.auth_type or has_credentials:
        existing = await prisma_client.db.litellm_mcpservertable.find_unique(
            where={"server_id": data.server_id}
        )

    # Clear stale credentials when auth_type changes but no new credentials provided
    if (
        data.auth_type
        and "credentials" not in data_dict
        and existing
        and existing.auth_type is not None
        and existing.auth_type != data.auth_type
    ):
        data_dict["credentials"] = None

    # Merge credentials: preserve existing fields not present in the update.
    # Without this, a partial credential update (e.g. changing only region)
    # would wipe encrypted secrets that the UI cannot display back.
    if "credentials" in data_dict and data_dict["credentials"] is not None:
        if existing and existing.credentials:
            # Only merge when auth_type is unchanged. Switching auth types
            # (e.g. oauth2 → api_key) should replace credentials entirely
            # to avoid stale secrets from the previous auth type lingering.
            auth_type_unchanged = (
                data.auth_type is None or data.auth_type == existing.auth_type
            )
            if auth_type_unchanged:
                # Both sides may be stored as a JSON string or a dict —
                # normalize to dicts before merging.
                existing_creds = (
                    json.loads(existing.credentials)
                    if isinstance(existing.credentials, str)
                    else dict(existing.credentials)
                )
                new_creds = (
                    json.loads(data_dict["credentials"])
                    if isinstance(data_dict["credentials"], str)
                    else dict(data_dict["credentials"])
                )
                # New values override existing; existing keys not in update are preserved
                merged = {**existing_creds, **new_creds}
                data_dict["credentials"] = safe_dumps(merged)

    # Add audit fields
    data_dict["updated_by"] = touched_by

    updated_mcp_server = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": data.server_id}, data=data_dict  # type: ignore
    )

    return updated_mcp_server
|
||||
|
||||
|
||||
async def rotate_mcp_server_credentials_master_key(
    prisma_client: PrismaClient, touched_by: str, new_master_key: str
):
    """
    Re-encrypt every MCP server's stored credentials under ``new_master_key``.

    For each server with credentials: decrypt with the currently-configured
    salt/master key, re-encrypt with the new key, and persist the result
    (stamping ``updated_by``). Servers without credentials are skipped.
    """
    # Hoisted out of the loop — importing on every iteration was wasted work.
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()

    for mcp_server in mcp_servers:
        credentials = mcp_server.credentials
        if not credentials:
            continue

        # Work on a copy so the fetched row is never mutated in place.
        credentials_copy = dict(credentials)
        # Decrypt with the current key first, then re-encrypt with the new key
        decrypted_credentials = decrypt_credentials(
            credentials=cast(MCPCredentials, credentials_copy),
        )
        encrypted_credentials = encrypt_credentials(
            credentials=decrypted_credentials,
            encryption_key=new_master_key,
        )

        await prisma_client.db.litellm_mcpservertable.update(
            where={"server_id": mcp_server.server_id},
            data={
                "credentials": safe_dumps(encrypted_credentials),
                "updated_by": touched_by,
            },
        )
|
||||
|
||||
|
||||
async def store_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    credential: str,
) -> None:
    """Store a user credential for a BYOK MCP server (upsert on user+server).

    NOTE(review): the credential is only base64-encoded here, not encrypted —
    confirm this is intentional before relying on it for secrets at rest.
    """
    payload = base64.urlsafe_b64encode(credential.encode()).decode()
    row_key = {"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    await prisma_client.db.litellm_mcpusercredentials.upsert(
        where=row_key,
        data={
            "create": {
                "user_id": user_id,
                "server_id": server_id,
                "credential_b64": payload,
            },
            "update": {"credential_b64": payload},
        },
    )
|
||||
|
||||
|
||||
async def get_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Optional[str]:
    """Return the stored credential for a user+server pair, or None."""
    row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if row is None:
        return None
    try:
        # Current format: urlsafe base64 of the raw credential string.
        return base64.urlsafe_b64decode(row.credential_b64).decode()
    except Exception:
        # Legacy rows were written encrypted (nacl) — fall back to decryption.
        return decrypt_value_helper(
            value=row.credential_b64,
            key="byok_credential",
            exception_type="debug",
            return_original_value=False,
        )
|
||||
|
||||
|
||||
async def has_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> bool:
    """True when a credential row exists for this user+server pair."""
    match = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    return match is not None
|
||||
|
||||
|
||||
async def delete_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> None:
    """Delete the user's stored credential for a BYOK MCP server."""
    composite_key = {"user_id": user_id, "server_id": server_id}
    await prisma_client.db.litellm_mcpusercredentials.delete(
        where={"user_id_server_id": composite_key}
    )
|
||||
|
||||
|
||||
# ── OAuth2 user-credential helpers ────────────────────────────────────────────
|
||||
|
||||
|
||||
async def store_user_oauth_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    access_token: str,
    refresh_token: Optional[str] = None,
    expires_in: Optional[int] = None,
    scopes: Optional[List[str]] = None,
) -> None:
    """Persist an OAuth2 access token for a user+server pair.

    The payload is JSON-serialised and stored base64-encoded in the same
    ``credential_b64`` column used by BYOK. A ``"type": "oauth2"`` key
    differentiates it from plain BYOK API keys.

    Raises:
        ValueError: when a non-OAuth2 (BYOK) credential already exists for
            this user+server pair — it is never silently overwritten.
    """

    # Convert the relative TTL to an absolute UTC expiry timestamp.
    expires_at: Optional[str] = None
    if expires_in is not None:
        expires_at = (
            datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        ).isoformat()

    payload: Dict[str, Any] = {
        "type": "oauth2",
        "access_token": access_token,
        "connected_at": datetime.now(timezone.utc).isoformat(),
    }
    # Optional fields are only included when present.
    if refresh_token:
        payload["refresh_token"] = refresh_token
    if expires_at:
        payload["expires_at"] = expires_at
    if scopes:
        payload["scopes"] = scopes

    # Guard against silently overwriting a BYOK credential with an OAuth token.
    # BYOK credentials lack a "type" field (or use a non-"oauth2" type).
    existing = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if existing is not None:
        _byok_error = ValueError(
            f"A non-OAuth2 credential already exists for user {user_id} "
            f"and server {server_id}. Refusing to overwrite."
        )
        try:
            raw = json.loads(base64.urlsafe_b64decode(existing.credential_b64).decode())
        except Exception:
            # Credential is not base64+JSON — it's a plain-text BYOK key.
            raise _byok_error
        if raw.get("type") != "oauth2":
            raise _byok_error

    # Upsert: create the row on first connect, replace the payload afterwards.
    encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
    await prisma_client.db.litellm_mcpusercredentials.upsert(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
        data={
            "create": {
                "user_id": user_id,
                "server_id": server_id,
                "credential_b64": encoded,
            },
            "update": {"credential_b64": encoded},
        },
    )
|
||||
|
||||
|
||||
def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool:
    """Return True if the OAuth2 credential's access_token has expired.

    Compares the ``expires_at`` ISO-8601 timestamp in the payload against
    the current UTC time. Missing or unparseable timestamps are treated as
    non-expired (returns False).
    """
    raw_expiry = cred.get("expires_at")
    if not raw_expiry:
        return False
    try:
        expiry = datetime.fromisoformat(raw_expiry)
    except (ValueError, TypeError):
        return False
    # Naive timestamps are assumed to be UTC.
    if expiry.tzinfo is None:
        expiry = expiry.replace(tzinfo=timezone.utc)
    return datetime.now(timezone.utc) > expiry
|
||||
|
||||
|
||||
async def get_user_oauth_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Optional[Dict[str, Any]]:
    """Return the decoded OAuth2 payload dict for a user+server pair, or None."""
    row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if row is None:
        return None
    try:
        payload = json.loads(base64.urlsafe_b64decode(row.credential_b64).decode())
    except Exception:
        # Plain BYOK strings are not base64+JSON — treat as "no OAuth token".
        return None
    if isinstance(payload, dict) and payload.get("type") == "oauth2":
        return payload
    # Row exists but holds a BYOK credential, not an OAuth token.
    return None
|
||||
|
||||
|
||||
async def list_user_oauth_credentials(
    prisma_client: PrismaClient,
    user_id: str,
) -> List[Dict[str, Any]]:
    """Return all OAuth2 credential payloads for a user, tagged with server_id."""
    rows = await prisma_client.db.litellm_mcpusercredentials.find_many(
        where={"user_id": user_id}
    )
    oauth_payloads: List[Dict[str, Any]] = []
    for row in rows:
        try:
            candidate = json.loads(
                base64.urlsafe_b64decode(row.credential_b64).decode()
            )
        except Exception:
            # BYOK plain strings are not base64+JSON — skip them.
            continue
        if isinstance(candidate, dict) and candidate.get("type") == "oauth2":
            candidate["server_id"] = row.server_id
            oauth_payloads.append(candidate)
    return oauth_payloads
|
||||
|
||||
|
||||
async def approve_mcp_server(
    prisma_client: PrismaClient,
    server_id: str,
    touched_by: str,
) -> LiteLLM_MCPServerTable:
    """Mark the server active (approval_status=active) and stamp reviewed_at."""
    reviewed_at = datetime.now(timezone.utc)
    row = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": server_id},
        data={
            "approval_status": MCPApprovalStatus.active,
            "reviewed_at": reviewed_at,
            "updated_by": touched_by,
        },
    )
    return LiteLLM_MCPServerTable(**row.model_dump())
|
||||
|
||||
|
||||
async def reject_mcp_server(
    prisma_client: PrismaClient,
    server_id: str,
    touched_by: str,
    review_notes: Optional[str] = None,
) -> LiteLLM_MCPServerTable:
    """Set approval_status=rejected, record reviewed_at and optional review notes."""
    update_payload: Dict[str, Any] = {
        "approval_status": MCPApprovalStatus.rejected,
        "reviewed_at": datetime.now(timezone.utc),
        "updated_by": touched_by,
    }
    # Only persist notes when the reviewer supplied them.
    if review_notes is not None:
        update_payload["review_notes"] = review_notes
    row = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": server_id},
        data=update_payload,
    )
    return LiteLLM_MCPServerTable(**row.model_dump())
|
||||
|
||||
|
||||
async def get_mcp_submissions(
    prisma_client: PrismaClient,
) -> MCPSubmissionsSummary:
    """
    Return all MCP servers that were submitted by non-admin users
    (submitted_at IS NOT NULL), along with a summary count breakdown by
    approval_status. Mirrors get_guardrail_submissions() from
    guardrail_endpoints.py.
    """
    rows = await prisma_client.db.litellm_mcpservertable.find_many(
        where={"submitted_at": {"not": None}},
        order={"submitted_at": "desc"},
        take=500,  # safety cap; paginate if needed in a future iteration
    )
    submissions = [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]

    # Tally the approval states in a single pass.
    counts = {
        MCPApprovalStatus.pending_review: 0,
        MCPApprovalStatus.active: 0,
        MCPApprovalStatus.rejected: 0,
    }
    for submission in submissions:
        if submission.approval_status in counts:
            counts[submission.approval_status] += 1

    return MCPSubmissionsSummary(
        total=len(submissions),
        pending_review=counts[MCPApprovalStatus.pending_review],
        active=counts[MCPApprovalStatus.active],
        rejected=counts[MCPApprovalStatus.rejected],
        items=submissions,
    )
|
||||
@@ -0,0 +1,741 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.utils import get_server_root_path
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
router = APIRouter(
|
||||
tags=["mcp"],
|
||||
)
|
||||
|
||||
|
||||
def get_request_base_url(request: Request) -> str:
    """
    Reconstruct the externally-visible base URL for *request*.

    Honors reverse-proxy headers when present (e.g. behind nginx):
    - ``X-Forwarded-Proto``: original scheme (http/https)
    - ``X-Forwarded-Host``: original host, possibly including a port
    - ``X-Forwarded-Port``: original port, when not part of the host

    Args:
        request: FastAPI Request object

    Returns:
        The reconstructed base URL (e.g., "https://proxy.example.com")
    """
    parts = urlparse(str(request.base_url).rstrip("/"))

    fwd_proto = request.headers.get("X-Forwarded-Proto")
    fwd_host = request.headers.get("X-Forwarded-Host")
    fwd_port = request.headers.get("X-Forwarded-Port")

    # Forwarded scheme wins over whatever the local socket saw.
    scheme = fwd_proto or parts.scheme

    if fwd_host:
        # A colon means the port is already embedded in the host — unless the
        # host is a bracketed IPv6 literal like "[::1]".
        host_carries_port = ":" in fwd_host and not fwd_host.startswith("[")
        if host_carries_port or not fwd_port:
            netloc = fwd_host
        else:
            netloc = f"{fwd_host}:{fwd_port}"
    else:
        # No forwarded host: keep the original netloc, optionally appending
        # the forwarded port when none is present yet.
        netloc = parts.netloc
        if fwd_port and ":" not in netloc:
            netloc = f"{netloc}:{fwd_port}"

    return urlunparse((scheme, netloc, parts.path, "", "", ""))
|
||||
|
||||
|
||||
def encode_state_with_base_url(
    base_url: str,
    original_state: str,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    client_redirect_uri: Optional[str] = None,
) -> str:
    """
    Pack the OAuth session context into a single encrypted state string.

    The returned value round-trips through ``decode_state_hash()``; it
    carries the client's base URL, its original state, and any PKCE
    parameters so they survive the redirect through the upstream provider.

    Args:
        base_url: The base URL to encode
        original_state: The original state parameter
        code_challenge: PKCE code challenge from client
        code_challenge_method: PKCE code challenge method from client
        client_redirect_uri: Original redirect_uri from client

    Returns:
        An encrypted string that encodes all values
    """
    # sort_keys keeps the serialized form deterministic for identical inputs.
    payload = json.dumps(
        {
            "base_url": base_url,
            "original_state": original_state,
            "code_challenge": code_challenge,
            "code_challenge_method": code_challenge_method,
            "client_redirect_uri": client_redirect_uri,
        },
        sort_keys=True,
    )
    return encrypt_value_helper(payload)
|
||||
|
||||
|
||||
def decode_state_hash(encrypted_state: str) -> dict:
    """
    Decode an encrypted state to retrieve all OAuth session data.

    Args:
        encrypted_state: The encrypted string to decode

    Returns:
        A dict containing base_url, original_state, and the optional PKCE
        parameters produced by encode_state_with_base_url().

    Raises:
        ValueError: If decryption fails or the decrypted payload is not
            valid JSON.
    """
    decrypted_json = decrypt_value_helper(encrypted_state, "oauth_state")
    if decrypted_json is None:
        raise ValueError("Failed to decrypt state parameter")

    try:
        return json.loads(decrypted_json)
    except json.JSONDecodeError as err:
        # Surface malformed payloads as the same error type as decryption
        # failures so callers only need to handle ValueError.
        raise ValueError("Decrypted state parameter is not valid JSON") from err
|
||||
|
||||
|
||||
def _resolve_oauth2_server_for_root_endpoints(
    client_ip: Optional[str] = None,
) -> Optional[MCPServer]:
    """
    Pick the MCP server for root-level OAuth endpoints (no name in the path).

    Root endpoints such as /register, /authorize, and /token carry no server
    name prefix, so automatic resolution only succeeds when exactly one
    OAuth2-configured server is visible to this client; otherwise None.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    registry = global_mcp_server_manager.get_filtered_registry(client_ip=client_ip)
    candidates = [
        server for server in registry.values() if server.auth_type == MCPAuth.oauth2
    ]
    # Ambiguity (zero or multiple OAuth2 servers) means we cannot resolve.
    return candidates[0] if len(candidates) == 1 else None
|
||||
|
||||
|
||||
async def authorize_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_id: str,
    redirect_uri: str,
    state: str = "",
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    Redirect the caller to *mcp_server*'s upstream OAuth2 authorization URL.

    The client's redirect_uri, state, and PKCE parameters are encrypted into
    the outgoing ``state`` parameter so they can be recovered in /callback;
    the upstream provider is told to redirect back to this proxy's
    ``{base_url}/callback`` endpoint.

    Raises:
        HTTPException: 400 when the server is not OAuth2 or has no
            authorization_url configured.
    """
    if mcp_server.auth_type != "oauth2":
        raise HTTPException(status_code=400, detail="MCP server is not OAuth2")
    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )

    # Strip the query string from the client's redirect_uri before stashing
    # it in the encrypted state; /callback appends code/state itself.
    parsed = urlparse(redirect_uri)
    base_url = urlunparse(parsed._replace(query=""))
    request_base_url = get_request_base_url(request)
    encoded_state = encode_state_with_base_url(
        base_url=base_url,
        original_state=state,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
        client_redirect_uri=redirect_uri,
    )

    # Prefer the client_id stored on the server record over the caller's.
    params = {
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "redirect_uri": f"{request_base_url}/callback",
        "state": encoded_state,
        "response_type": response_type or "code",
    }
    # Caller-supplied scope wins; fall back to the server's configured scopes.
    if scope:
        params["scope"] = scope
    elif mcp_server.scopes:
        params["scope"] = " ".join(mcp_server.scopes)

    if code_challenge:
        params["code_challenge"] = code_challenge
    if code_challenge_method:
        params["code_challenge_method"] = code_challenge_method

    # Merge with any query params already baked into authorization_url;
    # our params win on key collisions.
    parsed_auth_url = urlparse(mcp_server.authorization_url)
    existing_params = dict(parse_qsl(parsed_auth_url.query))
    existing_params.update(params)
    final_url = urlunparse(parsed_auth_url._replace(query=urlencode(existing_params)))
    return RedirectResponse(final_url)
|
||||
|
||||
|
||||
async def exchange_token_with_server(
    request: Request,
    mcp_server: MCPServer,
    grant_type: str,
    code: Optional[str],
    redirect_uri: Optional[str],
    client_id: str,
    client_secret: Optional[str],
    code_verifier: Optional[str],
):
    """
    Exchange an authorization code for an access token at *mcp_server*'s
    upstream token endpoint, forwarding PKCE's code_verifier when present.

    Only the ``authorization_code`` grant is supported. Credentials stored
    on the server record take precedence over caller-supplied ones.

    Returns:
        JSONResponse with access_token, token_type, expires_in and — when
        the upstream provider returned them — refresh_token and scope.

    Raises:
        HTTPException: 400 for an unsupported grant type or missing token_url.
    """
    if grant_type != "authorization_code":
        raise HTTPException(status_code=400, detail="Unsupported grant_type")

    if mcp_server.token_url is None:
        raise HTTPException(status_code=400, detail="MCP server token url is not set")

    # redirect_uri must match what was sent in the authorize step — this
    # proxy's /callback endpoint, not the client's own redirect_uri.
    proxy_base_url = get_request_base_url(request)
    token_data = {
        "grant_type": "authorization_code",
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "client_secret": mcp_server.client_secret
        if mcp_server.client_secret
        else client_secret,
        "code": code,
        "redirect_uri": f"{proxy_base_url}/callback",
    }

    if code_verifier:
        token_data["code_verifier"] = code_verifier

    async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    response = await async_client.post(
        mcp_server.token_url,
        headers={"Accept": "application/json"},
        data=token_data,
    )

    # Propagate upstream HTTP errors to the caller as-is.
    response.raise_for_status()
    token_response = response.json()
    access_token = token_response["access_token"]

    # Normalize the upstream response, defaulting token_type/expires_in when
    # the provider omits them.
    result = {
        "access_token": access_token,
        "token_type": token_response.get("token_type", "Bearer"),
        "expires_in": token_response.get("expires_in", 3600),
    }

    # Only forward optional fields that are present and non-empty.
    if "refresh_token" in token_response and token_response["refresh_token"]:
        result["refresh_token"] = token_response["refresh_token"]
    if "scope" in token_response and token_response["scope"]:
        result["scope"] = token_response["scope"]

    return JSONResponse(result)
|
||||
|
||||
|
||||
async def register_client_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_name: str,
    grant_types: Optional[list],
    response_types: Optional[list],
    token_endpoint_auth_method: Optional[str],
    fallback_client_id: Optional[str] = None,
):
    """
    Handle dynamic client registration against *mcp_server*.

    Behavior, in order:
    1. If the server already has client_id + client_secret configured,
       return a dummy registration payload (no upstream call needed).
    2. If the server has no authorization_url, fail with 400.
    3. If the server has no registration_url, return the dummy payload.
    4. Otherwise forward the registration request to the provider's
       registration endpoint and return its response.

    Raises:
        HTTPException: 400 when authorization_url is not configured.
    """
    request_base_url = get_request_base_url(request)
    # Placeholder registration payload; the client_secret is intentionally
    # a dummy value — real credentials live on the server record.
    dummy_return = {
        "client_id": fallback_client_id or mcp_server.server_name,
        "client_secret": "dummy",
        "redirect_uris": [f"{request_base_url}/callback"],
    }

    if mcp_server.client_id and mcp_server.client_secret:
        return dummy_return

    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )

    # No registration endpoint configured: fall back to the dummy payload
    # so clients that require a registration response still proceed.
    if mcp_server.registration_url is None:
        return dummy_return

    register_data = {
        "client_name": client_name,
        "redirect_uris": [f"{request_base_url}/callback"],
        "grant_types": grant_types or [],
        "response_types": response_types or [],
        "token_endpoint_auth_method": token_endpoint_auth_method or "",
    }
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Oauth2Register
    )
    response = await async_client.post(
        mcp_server.registration_url,
        headers=headers,
        json=register_data,
    )
    response.raise_for_status()

    token_response = response.json()

    return JSONResponse(token_response)
|
||||
|
||||
|
||||
@router.get("/{mcp_server_name}/authorize")
@router.get("/authorize")
async def authorize(
    request: Request,
    redirect_uri: str,
    client_id: Optional[str] = None,
    state: str = "",
    mcp_server_name: Optional[str] = None,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    OAuth2 authorization endpoint: resolve the target MCP server and
    redirect the client to the real provider's authorization URL with PKCE
    parameters forwarded.

    Server resolution order:
    1. ``mcp_server_name`` path parameter (named route)
    2. ``client_id`` query parameter (MCP SDK clients send the server name
       here on the root route)
    3. The single configured OAuth2 server, for the root ``/authorize`` route

    Raises:
        HTTPException: 404 when no server can be resolved; 400 when no
            usable client_id is available.
    """
    # Redirect to real OAuth provider with PKCE support
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    lookup_name: Optional[str] = mcp_server_name or client_id
    client_ip = IPAddressUtils.get_mcp_client_ip(request)
    mcp_server = (
        global_mcp_server_manager.get_mcp_server_by_name(
            lookup_name, client_ip=client_ip
        )
        if lookup_name
        else None
    )
    # Root-route fallback: pick the single configured OAuth2 server, if any.
    if mcp_server is None and mcp_server_name is None:
        mcp_server = _resolve_oauth2_server_for_root_endpoints()
    if mcp_server is None:
        raise HTTPException(status_code=404, detail="MCP server not found")
    # Use server's stored client_id when caller doesn't supply one.
    # Raise a clear error instead of passing an empty string — an empty
    # client_id would silently produce a broken authorization URL.
    resolved_client_id: str = mcp_server.client_id or client_id or ""
    if not resolved_client_id:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "client_id is required but was not supplied and is not "
                "stored on the MCP server record. Provide client_id as a query "
                "parameter or configure it on the server."
            },
        )
    return await authorize_with_server(
        request=request,
        mcp_server=mcp_server,
        client_id=resolved_client_id,
        redirect_uri=redirect_uri,
        state=state,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
        response_type=response_type,
        scope=scope,
    )
|
||||
|
||||
|
||||
@router.post("/{mcp_server_name}/token")
@router.post("/token")
async def token_endpoint(
    request: Request,
    grant_type: str = Form(...),
    code: Optional[str] = Form(None),
    redirect_uri: Optional[str] = Form(None),
    client_id: str = Form(...),
    client_secret: Optional[str] = Form(None),
    code_verifier: Optional[str] = Form(None),
    mcp_server_name: Optional[str] = None,
):
    """
    Accept the authorization code from the client and exchange it for an
    OAuth token at the upstream provider's token endpoint. Supports PKCE by
    forwarding ``code_verifier`` upstream.

    Server resolution order:
    1. ``mcp_server_name`` path parameter (named route)
    2. ``client_id`` form field (MCP SDK clients send the server name here)
    3. The single configured OAuth2 server, for the root ``/token`` route

    Raises:
        HTTPException: 404 when no server can be resolved.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    # client_id may actually carry the server name on the root route.
    lookup_name = mcp_server_name or client_id
    client_ip = IPAddressUtils.get_mcp_client_ip(request)
    mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
        lookup_name, client_ip=client_ip
    )
    # Root-route fallback: pick the single configured OAuth2 server, if any.
    if mcp_server is None and mcp_server_name is None:
        mcp_server = _resolve_oauth2_server_for_root_endpoints()
    if mcp_server is None:
        raise HTTPException(status_code=404, detail="MCP server not found")
    return await exchange_token_with_server(
        request=request,
        mcp_server=mcp_server,
        grant_type=grant_type,
        code=code,
        redirect_uri=redirect_uri,
        client_id=client_id,
        client_secret=client_secret,
        code_verifier=code_verifier,
    )
|
||||
|
||||
|
||||
@router.get("/callback")
async def callback(code: str, state: str):
    """
    Upstream OAuth2 callback: unwrap the encrypted state and bounce the
    authorization code (plus the client's original state) back to the
    client's own redirect URI. Serves a static HTML page when the state
    cannot be decoded.
    """
    try:
        session = decode_state_hash(state)
        query = urlencode({"code": code, "state": session["original_state"]})
        return RedirectResponse(
            url=f"{session['base_url']}?{query}", status_code=302
        )
    except Exception:
        # fallback if state hash not found
        return HTMLResponse(
            "<html><body>Authentication incomplete. You can close this window.</body></html>"
        )
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Optional .well-known endpoints for MCP + OAuth discovery
|
||||
# ------------------------------
|
||||
"""
|
||||
Per SEP-985, the client MUST:
|
||||
1. Try resource_metadata from WWW-Authenticate header (if present)
|
||||
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
|
||||
(
|
||||
If the resource identifier value contains a path or query component, any terminating slash (/)
|
||||
following the host component MUST be removed before inserting /.well-known/ and the well-known
|
||||
URI path suffix between the host component and the path (including the root path) and/or query components.
|
||||
https://datatracker.ietf.org/doc/html/rfc9728#section-3.1)
|
||||
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
|
||||
|
||||
Dual Pattern Support:
|
||||
- Standard MCP pattern: /mcp/{server_name} (recommended, used by mcp-inspector, VSCode Copilot)
|
||||
- LiteLLM legacy pattern: /{server_name}/mcp (backward compatibility)
|
||||
|
||||
The resource URL returned matches the pattern used in the discovery request.
|
||||
"""
|
||||
|
||||
|
||||
def _build_oauth_protected_resource_response(
    request: Request,
    mcp_server_name: Optional[str],
    use_standard_pattern: bool,
) -> dict:
    """
    Build OAuth protected resource response with the appropriate URL pattern.

    Args:
        request: FastAPI Request object
        mcp_server_name: Name of the MCP server; when None, the single
            configured OAuth2 server is used if one can be resolved.
        use_standard_pattern: If True, use /mcp/{server_name} pattern;
            if False, use /{server_name}/mcp pattern

    Returns:
        OAuth protected resource metadata dict with authorization_servers,
        resource, and scopes_supported keys.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    request_base_url = get_request_base_url(request)

    # When no server name provided, try to resolve the single OAuth2 server
    if mcp_server_name is None:
        resolved = _resolve_oauth2_server_for_root_endpoints()
        if resolved:
            mcp_server_name = resolved.server_name or resolved.name

    # Look up the server record (if any) so its scopes can be advertised.
    mcp_server: Optional[MCPServer] = None
    if mcp_server_name:
        client_ip = IPAddressUtils.get_mcp_client_ip(request)
        mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
            mcp_server_name, client_ip=client_ip
        )

    # Build resource URL based on the pattern
    if mcp_server_name:
        if use_standard_pattern:
            # Standard MCP pattern: /mcp/{server_name}
            resource_url = f"{request_base_url}/mcp/{mcp_server_name}"
        else:
            # LiteLLM legacy pattern: /{server_name}/mcp
            resource_url = f"{request_base_url}/{mcp_server_name}/mcp"
    else:
        resource_url = f"{request_base_url}/mcp"

    return {
        # The authorization server is this proxy itself, namespaced under the
        # server name when one is known.
        "authorization_servers": [
            (
                f"{request_base_url}/{mcp_server_name}"
                if mcp_server_name
                else f"{request_base_url}"
            )
        ],
        "resource": resource_url,
        "scopes_supported": mcp_server.scopes
        if mcp_server and mcp_server.scopes
        else [],
    }
|
||||
|
||||
|
||||
# Standard MCP pattern: /.well-known/oauth-protected-resource/mcp/{server_name}
|
||||
# This is the pattern expected by standard MCP clients (mcp-inspector, VSCode Copilot)
|
||||
@router.get(
    f"/.well-known/oauth-protected-resource{'' if get_server_root_path() == '/' else get_server_root_path()}/mcp/{{mcp_server_name}}"
)
async def oauth_protected_resource_mcp_standard(request: Request, mcp_server_name: str):
    """
    Protected-resource discovery for the standard MCP URL layout.

    Resource pattern ``/mcp/{server_name}``; discovery is served at
    ``/.well-known/oauth-protected-resource/mcp/{server_name}``. Compatible
    with standard MCP clients such as mcp-inspector and VSCode Copilot.
    """
    metadata = _build_oauth_protected_resource_response(
        request=request,
        mcp_server_name=mcp_server_name,
        use_standard_pattern=True,
    )
    return metadata
|
||||
|
||||
|
||||
# LiteLLM legacy pattern: /.well-known/oauth-protected-resource/{server_name}/mcp
|
||||
# Kept for backward compatibility with existing deployments
|
||||
@router.get(
    f"/.well-known/oauth-protected-resource{'' if get_server_root_path() == '/' else get_server_root_path()}/{{mcp_server_name}}/mcp"
)
@router.get("/.well-known/oauth-protected-resource")
async def oauth_protected_resource_mcp(
    request: Request, mcp_server_name: Optional[str] = None
):
    """
    Protected-resource discovery for the legacy LiteLLM URL layout.

    Resource pattern ``/{server_name}/mcp``; also serves the root
    ``/.well-known/oauth-protected-resource`` path. Kept for backward
    compatibility — new integrations should use ``/mcp/{server_name}``.
    """
    metadata = _build_oauth_protected_resource_response(
        request=request,
        mcp_server_name=mcp_server_name,
        use_standard_pattern=False,
    )
    return metadata
|
||||
|
||||
|
||||
"""
|
||||
https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
|
||||
RFC 8414: Path-aware OAuth discovery
|
||||
If the issuer identifier value contains a path component, any
|
||||
terminating "/" MUST be removed before inserting "/.well-known/" and
|
||||
the well-known URI suffix between the host component and the path (including the root path)
|
||||
component.
|
||||
"""
|
||||
|
||||
|
||||
def _build_oauth_authorization_server_response(
    request: Request,
    mcp_server_name: Optional[str],
) -> dict:
    """
    Build OAuth authorization server metadata pointing at this proxy's
    /authorize, /token, and /register endpoints.

    Args:
        request: FastAPI Request object
        mcp_server_name: Name of the MCP server; when None, the single
            configured OAuth2 server is used if one can be resolved.

    Returns:
        OAuth authorization server metadata dict
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    request_base_url = get_request_base_url(request)

    # When no server name provided, try to resolve the single OAuth2 server
    if mcp_server_name is None:
        resolved = _resolve_oauth2_server_for_root_endpoints()
        if resolved:
            mcp_server_name = resolved.server_name or resolved.name

    # Endpoints are namespaced under the server name when one is known;
    # otherwise the root-level routes are advertised.
    authorization_endpoint = (
        f"{request_base_url}/{mcp_server_name}/authorize"
        if mcp_server_name
        else f"{request_base_url}/authorize"
    )
    token_endpoint = (
        f"{request_base_url}/{mcp_server_name}/token"
        if mcp_server_name
        else f"{request_base_url}/token"
    )

    # Look up the server record (if any) so its scopes can be advertised.
    mcp_server: Optional[MCPServer] = None
    if mcp_server_name:
        client_ip = IPAddressUtils.get_mcp_client_ip(request)
        mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
            mcp_server_name, client_ip=client_ip
        )

    return {
        "issuer": request_base_url,  # point to your proxy
        "authorization_endpoint": authorization_endpoint,
        "token_endpoint": token_endpoint,
        "response_types_supported": ["code"],
        "scopes_supported": mcp_server.scopes
        if mcp_server and mcp_server.scopes
        else [],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "code_challenge_methods_supported": ["S256"],
        "token_endpoint_auth_methods_supported": ["client_secret_post"],
        # Claude expects a registration endpoint, even if we just fake it
        "registration_endpoint": f"{request_base_url}/{mcp_server_name}/register"
        if mcp_server_name
        else f"{request_base_url}/register",
    }
|
||||
|
||||
|
||||
# Standard MCP pattern: /.well-known/oauth-authorization-server/mcp/{server_name}
|
||||
@router.get(
    f"/.well-known/oauth-authorization-server{'' if get_server_root_path() == '/' else get_server_root_path()}/mcp/{{mcp_server_name}}"
)
async def oauth_authorization_server_mcp_standard(
    request: Request, mcp_server_name: str
):
    """
    Authorization-server discovery for the standard MCP URL layout
    (``/mcp/{server_name}``), served at
    ``/.well-known/oauth-authorization-server/mcp/{server_name}``.
    """
    metadata = _build_oauth_authorization_server_response(
        request=request, mcp_server_name=mcp_server_name
    )
    return metadata
|
||||
|
||||
|
||||
# LiteLLM legacy pattern and root endpoint
|
||||
@router.get(
    f"/.well-known/oauth-authorization-server{'' if get_server_root_path() == '/' else get_server_root_path()}/{{mcp_server_name}}"
)
@router.get("/.well-known/oauth-authorization-server")
async def oauth_authorization_server_mcp(
    request: Request, mcp_server_name: Optional[str] = None
):
    """
    Authorization-server discovery serving both the legacy
    ``/{server_name}`` pattern and the root discovery path.
    """
    metadata = _build_oauth_authorization_server_response(
        request=request, mcp_server_name=mcp_server_name
    )
    return metadata
|
||||
|
||||
|
||||
# Alias for standard OpenID discovery
|
||||
@router.get("/.well-known/openid-configuration")
async def openid_configuration(request: Request):
    """OpenID Connect discovery alias for the OAuth authorization-server metadata."""
    return await oauth_authorization_server_mcp(request)
|
||||
|
||||
|
||||
# Additional legacy pattern support
|
||||
@router.get("/.well-known/oauth-authorization-server/{mcp_server_name}/mcp")
async def oauth_authorization_server_legacy(request: Request, mcp_server_name: str):
    """
    Authorization-server discovery for the legacy ``/{server_name}/mcp``
    resource pattern.
    """
    metadata = _build_oauth_authorization_server_response(
        request=request, mcp_server_name=mcp_server_name
    )
    return metadata
|
||||
|
||||
|
||||
@router.post("/{mcp_server_name}/register")
@router.post("/register")
async def register_client(request: Request, mcp_server_name: Optional[str] = None):
    """
    OAuth2 dynamic client registration endpoint for MCP servers.

    Resolves the target MCP server (by path name, or — for the root
    ``/register`` route — by finding the single configured OAuth2 server)
    and forwards the registration request upstream. Falls back to a dummy
    registration payload when no matching server is found, since some MCP
    clients require a registration response even when the upstream provider
    does not support dynamic registration.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    # Base URL must reflect X-Forwarded-* headers so redirect_uris are
    # externally reachable.
    request_base_url = get_request_base_url(request)

    request_data = await _read_request_body(request=request)
    data: dict = {**request_data}

    async def _forward(server: MCPServer, fallback_client_id: Optional[str]):
        # Single place for the upstream registration call — keeps the
        # root-route and named-route branches identical.
        return await register_client_with_server(
            request=request,
            mcp_server=server,
            client_name=data.get("client_name", ""),
            grant_types=data.get("grant_types", []),
            response_types=data.get("response_types", []),
            token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
            fallback_client_id=fallback_client_id,
        )

    # Returned when no server can be resolved; satisfies clients that insist
    # on a registration response.
    dummy_return = {
        "client_id": mcp_server_name or "dummy_client",
        "client_secret": "dummy",
        "redirect_uris": [f"{request_base_url}/callback"],
    }

    if not mcp_server_name:
        resolved = _resolve_oauth2_server_for_root_endpoints()
        if resolved:
            return await _forward(resolved, resolved.server_name or resolved.name)
        return dummy_return

    client_ip = IPAddressUtils.get_mcp_client_ip(request)
    mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
        mcp_server_name, client_ip=client_ip
    )
    if mcp_server is None:
        return dummy_return
    return await _forward(mcp_server, mcp_server_name)
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Guardrail translation mapping for MCP tool calls."""

from litellm.proxy._experimental.mcp_server.guardrail_translation.handler import (
    MCPGuardrailTranslationHandler,
)
from litellm.types.utils import CallTypes

# This mapping lives alongside the MCP server implementation because MCP
# integrations are managed by the proxy subsystem, not litellm.llms providers.
# Unified guardrails import this module explicitly to register the handler.

# Maps the MCP tool-call type to its guardrail translation handler; consumers
# look up the handler class by CallTypes key.
guardrail_translation_mappings = {
    CallTypes.call_mcp_tool: MCPGuardrailTranslationHandler,
}

__all__ = ["guardrail_translation_mappings", "MCPGuardrailTranslationHandler"]
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
MCP Guardrail Handler for Unified Guardrails.
|
||||
|
||||
Converts an MCP call_tool (name + arguments) into a single OpenAI-compatible
|
||||
tool_call and passes it to apply_guardrail. Works with the synthetic payload
|
||||
from ProxyLogging._convert_mcp_to_llm_format.
|
||||
|
||||
Note: For MCP tool definitions (schema) -> OpenAI tools=[], see
|
||||
litellm.experimental_mcp_client.tools.transform_mcp_tool_to_openai_tool
|
||||
when you have a full MCP Tool from list_tools. Here we only have the call
|
||||
payload (name + arguments) so we just build the tool_call.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.experimental_mcp_client.tools import transform_mcp_tool_to_openai_tool
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
|
||||
|
||||
class MCPGuardrailTranslationHandler(BaseTranslation):
    """Guardrail translation handler for MCP tool calls (passes a single tool_call to guardrail)."""

    async def process_input_messages(
        self,
        data: Dict[str, Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Convert an MCP call_tool payload into an OpenAI-style tool definition
        and run the guardrail over it.

        Args:
            data: Request payload; the tool name is read from
                ``mcp_tool_name`` (falling back to ``name``) and the
                description from ``mcp_tool_description`` (falling back to
                ``description``).
            guardrail_to_apply: Guardrail whose ``apply_guardrail`` is run.
            litellm_logging_obj: Optional logging object forwarded as-is.

        Returns:
            The original ``data`` dict, unchanged. When the tool name is
            missing, the guardrail is skipped entirely.
        """
        mcp_tool_name = data.get("mcp_tool_name") or data.get("name")
        mcp_tool_description = data.get("mcp_tool_description") or data.get(
            "description"
        )
        # NOTE: the tool arguments are intentionally not extracted here —
        # the guardrail reads them straight from request_data (i.e. `data`).

        if not mcp_tool_name:
            verbose_proxy_logger.debug("MCP Guardrail: mcp_tool_name missing")
            return data

        # Convert MCP input via transform_mcp_tool_to_openai_tool, then map to litellm
        # ChatCompletionToolParam (openai SDK type has incompatible strict/cache_control).
        mcp_tool = MCPTool(
            name=mcp_tool_name,
            description=mcp_tool_description or "",
            inputSchema={},  # Call payload has no schema; guardrail gets args from request_data
        )
        openai_tool = transform_mcp_tool_to_openai_tool(mcp_tool)
        fn = openai_tool["function"]
        tool_def: ChatCompletionToolParam = {
            "type": "function",
            "function": ChatCompletionToolParamFunctionChunk(
                name=fn["name"],
                description=fn.get("description") or "",
                parameters=fn.get("parameters")
                or {
                    "type": "object",
                    "properties": {},
                    "additionalProperties": False,
                },
                strict=fn.get("strict", False) or False,  # Default to False if None
            ),
        }
        inputs: GenericGuardrailAPIInputs = GenericGuardrailAPIInputs(
            tools=[tool_def],
        )

        await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        return data

    async def process_output_response(
        self,
        response: "CallToolResult",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """Output-side hook; MCP tool results are currently passed through untouched."""
        verbose_proxy_logger.debug(
            "MCP Guardrail: Output processing not implemented for MCP tools",
        )
        return response
|
||||
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
MCP OAuth2 Debug Headers
|
||||
========================
|
||||
|
||||
Client-side debugging for MCP authentication flows.
|
||||
|
||||
When a client sends the ``x-litellm-mcp-debug: true`` header, LiteLLM
|
||||
returns masked diagnostic headers in the response so operators can
|
||||
troubleshoot OAuth2 issues without SSH access to the gateway.
|
||||
|
||||
Response headers returned (all values are masked for safety):
|
||||
|
||||
x-mcp-debug-inbound-auth
|
||||
Which inbound auth headers were present and how they were classified.
|
||||
Example: ``x-litellm-api-key=Bearer sk-12****1234``
|
||||
|
||||
x-mcp-debug-oauth2-token
|
||||
The OAuth2 token extracted from the Authorization header (masked).
|
||||
Shows ``(none)`` if absent, or flags ``SAME_AS_LITELLM_KEY`` when
|
||||
the LiteLLM API key is accidentally leaking to the MCP server.
|
||||
|
||||
x-mcp-debug-auth-resolution
|
||||
Which auth priority was used for the outbound MCP call:
|
||||
``per-request-header``, ``m2m-client-credentials``, ``static-token``,
|
||||
``oauth2-passthrough``, or ``no-auth``.
|
||||
|
||||
x-mcp-debug-outbound-url
|
||||
The upstream MCP server URL that will receive the request.
|
||||
|
||||
x-mcp-debug-server-auth-type
|
||||
The ``auth_type`` configured on the MCP server (e.g. ``oauth2``,
|
||||
``bearer_token``, ``none``).
|
||||
|
||||
Debugging Guide
|
||||
---------------
|
||||
|
||||
**Common issue: LiteLLM API key leaking to the MCP server**
|
||||
|
||||
Symptom: ``x-mcp-debug-oauth2-token`` shows ``SAME_AS_LITELLM_KEY``.
|
||||
|
||||
This means the ``Authorization`` header carries the LiteLLM API key and
|
||||
it's being forwarded to the upstream MCP server instead of an OAuth2 token.
|
||||
|
||||
Fix: Move the LiteLLM key to ``x-litellm-api-key`` so the ``Authorization``
|
||||
header is free for OAuth2 discovery::
|
||||
|
||||
# WRONG — blocks OAuth2 discovery
|
||||
claude mcp add --transport http my_server http://proxy/mcp/server \\
|
||||
--header "Authorization: Bearer sk-..."
|
||||
|
||||
# CORRECT — LiteLLM key in dedicated header, Authorization free for OAuth2
|
||||
claude mcp add --transport http my_server http://proxy/mcp/server \\
|
||||
--header "x-litellm-api-key: Bearer sk-..." \\
|
||||
--header "x-litellm-mcp-debug: true"
|
||||
|
||||
**Common issue: No OAuth2 token present**
|
||||
|
||||
Symptom: ``x-mcp-debug-oauth2-token`` shows ``(none)`` and
|
||||
``x-mcp-debug-auth-resolution`` shows ``no-auth``.
|
||||
|
||||
This means the client didn't go through the OAuth2 flow. Check that:
|
||||
1. The ``Authorization`` header is NOT set as a static header in the client config.
|
||||
2. The ``.well-known/oauth-protected-resource`` endpoint returns valid metadata.
|
||||
3. The MCP server in LiteLLM config has ``auth_type: oauth2``.
|
||||
|
||||
**Common issue: M2M token used instead of user token**
|
||||
|
||||
Symptom: ``x-mcp-debug-auth-resolution`` shows ``m2m-client-credentials``.
|
||||
|
||||
This means the server has ``client_id``/``client_secret``/``token_url``
|
||||
configured and LiteLLM is fetching a machine-to-machine token instead of
|
||||
using the per-user OAuth2 token. If you want per-user tokens, remove the
|
||||
client credentials from the server config.
|
||||
|
||||
Usage from Claude Code::
|
||||
|
||||
claude mcp add --transport http my_server http://proxy/mcp/server \\
|
||||
--header "x-litellm-api-key: Bearer sk-..." \\
|
||||
--header "x-litellm-mcp-debug: true"
|
||||
|
||||
Usage with curl::
|
||||
|
||||
curl -H "x-litellm-mcp-debug: true" \\
|
||||
-H "x-litellm-api-key: Bearer sk-..." \\
|
||||
http://localhost:4000/mcp/atlassian_mcp
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from starlette.types import Message, Send
|
||||
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
# Header the client sends to opt into debug mode
|
||||
MCP_DEBUG_REQUEST_HEADER = "x-litellm-mcp-debug"
|
||||
|
||||
# Prefix for all debug response headers
|
||||
_RESPONSE_HEADER_PREFIX = "x-mcp-debug"
|
||||
|
||||
|
||||
class MCPDebug:
    """
    Static helper class for MCP OAuth2 debug headers.

    Provides opt-in client-side diagnostics by injecting masked
    authentication info into HTTP response headers.
    """

    # Masker: show first 6 and last 4 chars so you can distinguish token types
    # e.g. "Bearer****ef01" vs "sk-123****cdef"
    _masker = SensitiveDataMasker(
        sensitive_patterns={
            "authorization",
            "token",
            "key",
            "secret",
            "auth",
            "bearer",
        },
        visible_prefix=6,
        visible_suffix=4,
    )

    @staticmethod
    def _mask(value: Optional[str]) -> str:
        """Mask a single value for safe display in headers."""
        if not value:
            # Covers both None and the empty string.
            return "(none)"
        # NOTE(review): relies on SensitiveDataMasker's private _mask_value
        # helper — confirm it is stable across litellm versions.
        return MCPDebug._masker._mask_value(value)

    @staticmethod
    def is_debug_enabled(headers: Dict[str, str]) -> bool:
        """
        Check if the client opted into MCP debug mode.

        Looks for ``x-litellm-mcp-debug: true`` (case-insensitive) in the
        request headers.
        """
        for key, val in headers.items():
            if key.lower() == MCP_DEBUG_REQUEST_HEADER:
                # Accept common truthy spellings; anything else disables debug.
                return val.strip().lower() in ("true", "1", "yes")
        return False

    @staticmethod
    def resolve_auth_resolution(
        server: "MCPServer",
        mcp_auth_header: Optional[str],
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
        oauth2_headers: Optional[Dict[str, str]],
    ) -> str:
        """
        Determine which auth priority will be used for the outbound MCP call.

        Returns one of: ``per-request-header``, ``m2m-client-credentials``,
        ``static-token``, ``oauth2-passthrough``, or ``no-auth``.
        """
        from litellm.types.mcp import MCPAuth

        # Per-request headers may be keyed by either the server alias or its
        # canonical server_name.
        has_server_specific = bool(
            mcp_server_auth_headers
            and (
                mcp_server_auth_headers.get(server.alias or "")
                or mcp_server_auth_headers.get(server.server_name or "")
            )
        )
        # NOTE(review): this priority order should mirror the actual outbound
        # auth resolution (see resolve_mcp_auth) — keep the two in sync.
        if has_server_specific or mcp_auth_header:
            return "per-request-header"
        if server.has_client_credentials:
            return "m2m-client-credentials"
        if server.authentication_token:
            return "static-token"
        if oauth2_headers and server.auth_type == MCPAuth.oauth2:
            return "oauth2-passthrough"
        return "no-auth"

    @staticmethod
    def build_debug_headers(
        *,
        inbound_headers: Dict[str, str],
        oauth2_headers: Optional[Dict[str, str]],
        litellm_api_key: Optional[str],
        auth_resolution: str,
        server_url: Optional[str],
        server_auth_type: Optional[str],
    ) -> Dict[str, str]:
        """
        Build masked debug response headers.

        Parameters
        ----------
        inbound_headers : dict
            Raw headers received from the MCP client.
        oauth2_headers : dict or None
            Extracted OAuth2 headers (``{"Authorization": "Bearer ..."}``).
        litellm_api_key : str or None
            The LiteLLM API key extracted from ``x-litellm-api-key`` or
            ``Authorization`` header.
        auth_resolution : str
            Which auth priority was selected for the outbound call.
        server_url : str or None
            Upstream MCP server URL.
        server_auth_type : str or None
            The ``auth_type`` configured on the server (e.g. ``oauth2``).

        Returns
        -------
        dict
            Headers to include in the response (all values masked).
        """
        debug: Dict[str, str] = {}

        # --- Inbound auth summary ---
        inbound_parts = []
        for hdr_name in ("x-litellm-api-key", "authorization", "x-mcp-auth"):
            for k, v in inbound_headers.items():
                # Case-insensitive match; report at most one value per header name.
                if k.lower() == hdr_name:
                    inbound_parts.append(f"{hdr_name}={MCPDebug._mask(v)}")
                    break
        debug[f"{_RESPONSE_HEADER_PREFIX}-inbound-auth"] = (
            "; ".join(inbound_parts) if inbound_parts else "(none)"
        )

        # --- OAuth2 token ---
        oauth2_token = (oauth2_headers or {}).get("Authorization")
        if oauth2_token and litellm_api_key:
            # Compare the raw token values (ignoring a "Bearer " prefix) to
            # detect the LiteLLM key accidentally leaking to the MCP server.
            oauth2_raw = oauth2_token.removeprefix("Bearer ").strip()
            litellm_raw = litellm_api_key.removeprefix("Bearer ").strip()
            if oauth2_raw == litellm_raw:
                debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = (
                    f"{MCPDebug._mask(oauth2_token)} "
                    f"(SAME_AS_LITELLM_KEY - likely misconfigured)"
                )
            else:
                debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = MCPDebug._mask(
                    oauth2_token
                )
        else:
            # No comparison possible; _mask renders "(none)" for a missing token.
            debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = MCPDebug._mask(
                oauth2_token
            )

        # --- Auth resolution ---
        debug[f"{_RESPONSE_HEADER_PREFIX}-auth-resolution"] = auth_resolution

        # --- Server info ---
        debug[f"{_RESPONSE_HEADER_PREFIX}-outbound-url"] = server_url or "(unknown)"
        debug[f"{_RESPONSE_HEADER_PREFIX}-server-auth-type"] = (
            server_auth_type or "(none)"
        )

        return debug

    @staticmethod
    def wrap_send_with_debug_headers(send: Send, debug_headers: Dict[str, str]) -> Send:
        """
        Return a new ASGI ``send`` callable that injects *debug_headers*
        into the ``http.response.start`` message.
        """

        async def _send_with_debug(message: Message) -> None:
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                for k, v in debug_headers.items():
                    # ASGI headers are (bytes, bytes) pairs.
                    headers.append((k.encode(), v.encode()))
                # Shallow-copy the message rather than mutating the caller's dict.
                message = {**message, "headers": headers}
            await send(message)

        return _send_with_debug

    @staticmethod
    def maybe_build_debug_headers(
        *,
        raw_headers: Optional[Dict[str, str]],
        scope: Dict,
        mcp_servers: Optional[List[str]],
        mcp_auth_header: Optional[str],
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
        oauth2_headers: Optional[Dict[str, str]],
        client_ip: Optional[str],
    ) -> Dict[str, str]:
        """
        Build debug headers if debug mode is enabled, otherwise return empty dict.

        This is the single entry point called from the MCP request handler.
        """
        if not raw_headers or not MCPDebug.is_debug_enabled(raw_headers):
            return {}

        # Imported lazily — presumably to avoid import cycles with the proxy
        # modules; confirm before hoisting to module level.
        from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
            MCPRequestHandler,
        )
        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
            global_mcp_server_manager,
        )

        server_url: Optional[str] = None
        server_auth_type: Optional[str] = None
        auth_resolution = "no-auth"

        # Only the first resolvable server is reported in the debug headers.
        for server_name in mcp_servers or []:
            server = global_mcp_server_manager.get_mcp_server_by_name(
                server_name, client_ip=client_ip
            )
            if server:
                server_url = server.url
                server_auth_type = server.auth_type
                auth_resolution = MCPDebug.resolve_auth_resolution(
                    server, mcp_auth_header, mcp_server_auth_headers, oauth2_headers
                )
                break

        scope_headers = MCPRequestHandler._safe_get_headers_from_scope(scope)
        litellm_key = MCPRequestHandler.get_litellm_api_key_from_headers(scope_headers)

        return MCPDebug.build_debug_headers(
            inbound_headers=raw_headers,
            oauth2_headers=oauth2_headers,
            litellm_api_key=litellm_key,
            auth_resolution=auth_resolution,
            server_url=server_url,
            server_auth_type=server_auth_type,
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
OAuth2 client_credentials token cache for MCP servers.
|
||||
|
||||
Automatically fetches and refreshes access tokens for MCP servers configured
|
||||
with ``client_id``, ``client_secret``, and ``token_url``.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.constants import (
|
||||
MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL,
|
||||
MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE,
|
||||
MCP_OAUTH2_TOKEN_CACHE_MIN_TTL,
|
||||
MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
|
||||
class MCPOAuth2TokenCache(InMemoryCache):
    """
    In-memory cache for OAuth2 client_credentials tokens, keyed by server_id.

    Inherits from ``InMemoryCache`` for TTL-based storage and eviction.
    Adds per-server ``asyncio.Lock`` to prevent duplicate concurrent fetches.
    """

    def __init__(self) -> None:
        super().__init__(
            max_size_in_memory=MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE,
            default_ttl=MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL,
        )
        # One lock per server_id. Entries are never removed, but growth is
        # bounded by the number of distinct MCP servers seen.
        self._locks: Dict[str, asyncio.Lock] = {}

    def _get_lock(self, server_id: str) -> asyncio.Lock:
        # setdefault creates the lock lazily on first use for a server.
        return self._locks.setdefault(server_id, asyncio.Lock())

    async def async_get_token(self, server: "MCPServer") -> Optional[str]:
        """Return a valid access token, fetching or refreshing as needed.

        Returns ``None`` when the server lacks client credentials config.
        """
        if not server.has_client_credentials:
            return None

        server_id = server.server_id

        # Fast path — cached token is still valid
        cached = self.get_cache(server_id)
        if cached is not None:
            return cached

        # Slow path — acquire per-server lock then double-check: a concurrent
        # caller may have populated the cache while we waited on the lock.
        async with self._get_lock(server_id):
            cached = self.get_cache(server_id)
            if cached is not None:
                return cached

            token, ttl = await self._fetch_token(server)
            self.set_cache(server_id, token, ttl=ttl)
            return token

    async def _fetch_token(self, server: "MCPServer") -> Tuple[str, int]:
        """POST to ``token_url`` with ``grant_type=client_credentials``.

        Returns ``(access_token, ttl_seconds)`` where ttl accounts for the
        expiry buffer so the cache entry expires before the real token does.

        Raises:
            ValueError: when credentials are incomplete, the token endpoint
                returns an error status, or the response lacks a usable
                ``access_token``.
        """
        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)

        if not server.client_id or not server.client_secret or not server.token_url:
            raise ValueError(
                f"MCP server '{server.server_id}' missing required OAuth2 fields: "
                f"client_id={bool(server.client_id)}, "
                f"client_secret={bool(server.client_secret)}, "
                f"token_url={bool(server.token_url)}"
            )

        data: Dict[str, str] = {
            "grant_type": "client_credentials",
            "client_id": server.client_id,
            "client_secret": server.client_secret,
        }
        if server.scopes:
            # OAuth2 convention: multiple scopes are space-delimited.
            data["scope"] = " ".join(server.scopes)

        verbose_logger.debug(
            "Fetching OAuth2 client_credentials token for MCP server %s",
            server.server_id,
        )

        try:
            response = await client.post(server.token_url, data=data)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Surface a concise error with the status code only; the response
            # body is deliberately not included in the message.
            raise ValueError(
                f"OAuth2 token request for MCP server '{server.server_id}' "
                f"failed with status {exc.response.status_code}"
            ) from exc

        body = response.json()

        if not isinstance(body, dict):
            raise ValueError(
                f"OAuth2 token response for MCP server '{server.server_id}' "
                f"returned non-object JSON (got {type(body).__name__})"
            )

        access_token = body.get("access_token")
        if not access_token:
            raise ValueError(
                f"OAuth2 token response for MCP server '{server.server_id}' "
                f"missing 'access_token'"
            )

        # Safely parse expires_in — providers may return null or non-numeric values
        raw_expires_in = body.get("expires_in")
        try:
            expires_in = (
                int(raw_expires_in)
                if raw_expires_in is not None
                else MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL
            )
        except (TypeError, ValueError):
            expires_in = MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL

        # Expire the cache entry before the real token does, but never below
        # the configured minimum TTL.
        ttl = max(
            expires_in - MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS,
            MCP_OAUTH2_TOKEN_CACHE_MIN_TTL,
        )

        verbose_logger.info(
            "Fetched OAuth2 token for MCP server %s (expires in %ds)",
            server.server_id,
            expires_in,
        )
        return access_token, ttl

    def invalidate(self, server_id: str) -> None:
        """Remove a cached token (e.g. after a 401)."""
        self.delete_cache(server_id)
||||
|
||||
|
||||
# Process-wide singleton so all callers share one token cache (and its locks).
mcp_oauth2_token_cache = MCPOAuth2TokenCache()
|
||||
|
||||
|
||||
async def resolve_mcp_auth(
    server: "MCPServer",
    mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
) -> Optional[Union[str, Dict[str, str]]]:
    """Pick the auth value to use for an outbound MCP server call.

    Checked in priority order:

    1. the per-request / per-user override passed by the caller,
    2. an auto-fetched (and cached) OAuth2 client_credentials token,
    3. the static ``authentication_token`` from config/DB (may be ``None``).
    """
    # 1. Caller-supplied override always wins.
    if mcp_auth_header:
        return mcp_auth_header

    # 2. M2M client-credentials flow — token fetched lazily and cached.
    if server.has_client_credentials:
        return await mcp_oauth2_token_cache.async_get_token(server)

    # 3. Fall back to whatever static token the server was configured with.
    return server.authentication_token
||||
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
This module is used to generate MCP tools from OpenAPI specs.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import os
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.tool_registry import (
|
||||
global_mcp_tool_registry,
|
||||
)
|
||||
|
||||
# Store the base URL and headers globally
|
||||
BASE_URL = ""
|
||||
HEADERS: Dict[str, str] = {}
|
||||
|
||||
# Per-request auth header override for BYOK servers.
|
||||
# Set this ContextVar before calling a local tool handler to inject the user's
|
||||
# stored credential into the HTTP request made by the tool function closure.
|
||||
_request_auth_header: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
|
||||
"_request_auth_header", default=None
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_path_parameter_value(param_value: Any, param_name: str) -> str:
|
||||
"""Ensure path params cannot introduce directory traversal."""
|
||||
if param_value is None:
|
||||
return ""
|
||||
|
||||
value_str = str(param_value)
|
||||
if value_str == "":
|
||||
return ""
|
||||
|
||||
normalized_value = value_str.replace("\\", "/")
|
||||
if "/" in normalized_value:
|
||||
raise ValueError(
|
||||
f"Path parameter '{param_name}' must not contain path separators"
|
||||
)
|
||||
|
||||
if any(part in {".", ".."} for part in PurePosixPath(normalized_value).parts):
|
||||
raise ValueError(
|
||||
f"Path parameter '{param_name}' cannot include '.' or '..' segments"
|
||||
)
|
||||
|
||||
return quote(value_str, safe="")
|
||||
|
||||
|
||||
def load_openapi_spec(filepath: str) -> Dict[str, Any]:
    """
    Sync wrapper around :func:`load_openapi_spec_async`.

    Must only be called from synchronous code.

    Raises:
        RuntimeError: if called from within a running event loop — async
            callers must ``await load_openapi_spec_async(...)`` instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop — safe to drive the async loader ourselves.
        return asyncio.run(load_openapi_spec_async(filepath))

    # A loop IS running: asyncio.run() would fail anyway, so raise a clear,
    # actionable error. (The previous implementation raised inside a try and
    # filtered its own exception by matching asyncio's "no running event
    # loop" message substring — fragile across versions/wording.)
    raise RuntimeError(
        "load_openapi_spec() was called from within a running event loop. "
        "Use 'await load_openapi_spec_async(...)' instead."
    )
|
||||
|
||||
|
||||
async def load_openapi_spec_async(filepath: str) -> Dict[str, Any]:
    """Load an OpenAPI spec from an http(s) URL or a local JSON file."""
    is_remote = filepath.startswith("http://") or filepath.startswith("https://")

    if is_remote:
        # Shared MCP httpx client. Do not close it here — get_async_httpx_client
        # may hand back a process-wide singleton.
        http_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)
        response = await http_client.get(filepath)
        response.raise_for_status()
        return response.json()

    # Local filesystem path.
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"OpenAPI spec not found at {filepath}")
    with open(filepath, "r", encoding="utf-8") as spec_file:
        return json.load(spec_file)
|
||||
|
||||
|
||||
def get_base_url(spec: Dict[str, Any], spec_path: Optional[str] = None) -> str:
    """Extract the API base URL from an OpenAPI spec.

    Resolution order: OpenAPI 3.x ``servers`` (first entry), then Swagger 2.x
    ``host``/``schemes``/``basePath``, then a best-effort derivation from
    *spec_path* when it is itself an http(s) URL. Returns ``""`` when nothing
    can be determined.
    """
    servers = spec.get("servers")

    # OpenAPI 3.x: the first entry of the `servers` array wins.
    if servers:
        server_url = servers[0]["url"]
        spec_path_is_url = bool(
            spec_path
            and (spec_path.startswith("http://") or spec_path.startswith("https://"))
        )
        if server_url.startswith("/") and spec_path_is_url:
            # Relative server URL: graft it onto the domain of spec_path
            # (e.g. https://petstore3.swagger.io/api/v3/openapi.json).
            from urllib.parse import urlparse

            parsed = urlparse(spec_path)
            full_base_url = f"{parsed.scheme}://{parsed.netloc}" + server_url
            verbose_logger.info(
                f"OpenAPI spec has relative server URL '{server_url}'. "
                f"Deriving base from spec_path: {full_base_url}"
            )
            return full_base_url
        return server_url

    # OpenAPI 2.x (Swagger): host + first scheme + basePath.
    if "host" in spec:
        scheme = spec.get("schemes", ["https"])[0]
        return f"{scheme}://{spec['host']}{spec.get('basePath', '')}"

    # Fallback: derive the base URL from spec_path when it is an http(s) URL.
    if spec_path and (
        spec_path.startswith("http://") or spec_path.startswith("https://")
    ):
        for suffix in (
            "/openapi.json",
            "/openapi.yaml",
            "/swagger.json",
            "/swagger.yaml",
        ):
            if spec_path.endswith(suffix):
                base_url = spec_path[: -len(suffix)]
                verbose_logger.info(
                    f"No server info in OpenAPI spec. Using derived base URL: {base_url}"
                )
                return base_url

        # Any other *.json / *.yaml / *.yml file: strip the final path segment.
        last_segment = spec_path.split("/")[-1]
        if last_segment.endswith((".json", ".yaml", ".yml")):
            base_url = "/".join(spec_path.split("/")[:-1])
            verbose_logger.info(
                f"No server info in OpenAPI spec. Using derived base URL: {base_url}"
            )
            return base_url

    return ""
|
||||
|
||||
|
||||
def _resolve_ref(
|
||||
param: Dict[str, Any], component_params: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Resolve a single parameter, following a $ref if present.
|
||||
|
||||
Returns the resolved param dict, or None if the $ref target is absent from
|
||||
components (so callers can skip/filter it rather than propagating a stub
|
||||
with name=None that would corrupt deduplication).
|
||||
"""
|
||||
ref = param.get("$ref", "")
|
||||
if not ref.startswith("#/components/parameters/"):
|
||||
return param
|
||||
return component_params.get(ref.split("/")[-1])
|
||||
|
||||
|
||||
def _resolve_param_list(
    raw: List[Dict[str, Any]], component_params: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Resolve ``$ref`` entries in a parameter list.

    Unresolvable refs and params without a truthy ``name`` are dropped.
    """
    resolved = (_resolve_ref(p, component_params) for p in raw)
    return [p for p in resolved if p is not None and p.get("name")]
|
||||
|
||||
|
||||
def resolve_operation_params(
    operation: Dict[str, Any],
    path_item: Dict[str, Any],
    components: Dict[str, Any],
) -> Dict[str, Any]:
    """Return a copy of *operation* with fully-resolved, merged parameters.

    Covers two common real-world OpenAPI patterns:

    1. **$ref parameters** — ``{"$ref": "#/components/parameters/per-page"}``
       instead of inline objects. Each ref is resolved against
       ``components["parameters"]``; unresolvable refs are silently dropped
       so they cannot corrupt deduplication with ``(None, None)`` keys.

    2. **Path-level parameters** — params declared on the path item that
       apply to every HTTP method on that path (e.g. ``owner``, ``repo``).
       They are merged with operation-level params; the operation level wins
       whenever the same ``name`` + ``in`` pair appears in both.
    """
    component_params = components.get("parameters", {})

    path_level = _resolve_param_list(path_item.get("parameters", []), component_params)
    op_level = _resolve_param_list(operation.get("parameters", []), component_params)

    # Operation-level params shadow path-level ones with the same (name, in).
    shadowed = {(p["name"], p.get("in")) for p in op_level}
    merged = [p for p in path_level if (p["name"], p.get("in")) not in shadowed]
    merged.extend(op_level)

    resolved_operation = dict(operation)
    resolved_operation["parameters"] = merged
    return resolved_operation
|
||||
|
||||
|
||||
def extract_parameters(operation: Dict[str, Any]) -> tuple:
    """Split an OpenAPI operation's parameter names into (path, query, body).

    Parameters without a ``name`` are skipped. An OpenAPI 3.x ``requestBody``
    is exposed under the synthetic name ``"body"``.
    """
    buckets: Dict[str, List[str]] = {"path": [], "query": [], "body": []}

    for param in operation.get("parameters", []):
        if "name" not in param:
            continue
        location = param.get("in")
        if location in buckets:
            buckets[location].append(param["name"])

    # OpenAPI 3.x moves the request body out of `parameters`.
    if "requestBody" in operation:
        buckets["body"].append("body")

    return buckets["path"], buckets["query"], buckets["body"]
|
||||
|
||||
|
||||
def build_input_schema(operation: Dict[str, Any]) -> Dict[str, Any]:
    """Build an MCP input schema (JSON-schema object) from an OpenAPI operation.

    Each named parameter becomes a property (type defaults to ``"string"``);
    an ``application/json`` requestBody is surfaced as an object property
    named ``"body"``. ``required`` lists the mandatory property names.
    """
    properties: Dict[str, Any] = {}
    required: List[str] = []

    # Named parameters → top-level schema properties.
    for param in operation.get("parameters", []):
        if "name" not in param:
            continue
        name = param["name"]
        properties[name] = {
            "type": param.get("schema", {}).get("type", "string"),
            "description": param.get("description", ""),
        }
        if param.get("required", False):
            required.append(name)

    # OpenAPI 3.x requestBody → synthetic "body" property (JSON bodies only).
    if "requestBody" in operation:
        request_body = operation["requestBody"]
        content = request_body.get("content", {})
        if "application/json" in content:
            body_schema = content["application/json"].get("schema", {})
            properties["body"] = {
                "type": "object",
                "description": request_body.get("description", "Request body"),
                "properties": body_schema.get("properties", {}),
            }
            if request_body.get("required", False):
                required.append("body")

    return {
        "type": "object",
        "properties": properties,
        "required": required if required else [],
    }
|
||||
|
||||
|
||||
def create_tool_function(
    path: str,
    method: str,
    operation: Dict[str, Any],
    base_url: str,
    headers: Optional[Dict[str, str]] = None,
):
    """Create a tool function for an OpenAPI operation.

    This function creates an async tool function that can be called with
    keyword arguments. Parameter names from the OpenAPI spec are accessed
    directly via **kwargs, avoiding syntax errors from invalid Python identifiers.

    Args:
        path: API endpoint path
        method: HTTP method (get, post, put, delete, patch)
        operation: OpenAPI operation object
        base_url: Base URL for the API
        headers: Optional headers to include in requests (e.g., authentication)

    Returns:
        An async function that accepts **kwargs and makes the HTTP request
    """
    if headers is None:
        headers = {}

    # Classify the operation's parameters once, at closure-creation time.
    path_params, query_params, body_params = extract_parameters(operation)
    original_method = method.lower()

    async def tool_function(**kwargs: Any) -> str:
        """
        Dynamically generated tool function.

        Accepts keyword arguments where keys are the original OpenAPI parameter names.
        The function safely handles parameter names that aren't valid Python identifiers
        by using **kwargs instead of named parameters.
        """
        # Allow per-request auth override (e.g. BYOK credential set via ContextVar).
        # The ContextVar holds the full Authorization header value, including the
        # correct prefix (Bearer / ApiKey / Basic) formatted by the caller in
        # server.py based on the server's configured auth_type.
        effective_headers = dict(headers)
        override_auth = _request_auth_header.get()
        if override_auth:
            effective_headers["Authorization"] = override_auth

        # Build URL from base_url and path
        url = base_url + path

        # Replace path parameters using original names from OpenAPI spec.
        # Values are validated against path traversal and URL-encoded.
        for param_name in path_params:
            param_value = kwargs.get(param_name, "")
            if param_value:
                try:
                    safe_value = _sanitize_path_parameter_value(param_value, param_name)
                except ValueError as exc:
                    return "Invalid path parameter: " + str(exc)
                # Replace {param_name} or {{param_name}} in URL
                url = url.replace("{" + param_name + "}", safe_value)
                url = url.replace("{{" + param_name + "}}", safe_value)

        # Build query params using original parameter names.
        # BUGFIX: the previous `if param_value:` check silently dropped
        # legitimate falsy values such as 0 or False (e.g. ?page=0). Only
        # missing, None, or empty-string values are omitted now.
        params: Dict[str, Any] = {}
        for param_name in query_params:
            param_value = kwargs.get(param_name)
            if param_value is not None and param_value != "":
                params[param_name] = param_value

        # Build request body
        json_body: Optional[Dict[str, Any]] = None
        if body_params:
            # Try "body" first (most common), then check all body param names
            body_value = kwargs.get("body", {})
            if not body_value:
                for param_name in body_params:
                    body_value = kwargs.get(param_name, {})
                    if body_value:
                        break

            if isinstance(body_value, dict):
                json_body = body_value
            elif body_value:
                # If it's a string, try to parse as JSON; otherwise wrap it.
                try:
                    json_body = (
                        json.loads(body_value)
                        if isinstance(body_value, str)
                        else {"data": body_value}
                    )
                except (json.JSONDecodeError, TypeError):
                    json_body = {"data": body_value}

        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)

        if original_method == "get":
            response = await client.get(url, params=params, headers=effective_headers)
        elif original_method == "post":
            response = await client.post(
                url, params=params, json=json_body, headers=effective_headers
            )
        elif original_method == "put":
            response = await client.put(
                url, params=params, json=json_body, headers=effective_headers
            )
        elif original_method == "delete":
            response = await client.delete(
                url, params=params, headers=effective_headers
            )
        elif original_method == "patch":
            response = await client.patch(
                url, params=params, json=json_body, headers=effective_headers
            )
        else:
            return f"Unsupported HTTP method: {original_method}"

        return response.text

    return tool_function
|
||||
|
||||
|
||||
def register_tools_from_openapi(spec: Dict[str, Any], base_url: str):
    """Register MCP tools from OpenAPI specification.

    Walks every path/method pair in the spec and registers one MCP tool per
    operation with the global tool registry.
    """
    http_methods = ("get", "post", "put", "delete", "patch")

    for route, route_spec in spec.get("paths", {}).items():
        for http_method in http_methods:
            if http_method not in route_spec:
                continue
            operation = route_spec[http_method]

            # Derive a registry-safe tool name from the operationId, or a
            # synthetic "<method>_<path>" id when the spec omits one.
            fallback_id = f"{http_method}_{route.replace('/', '_')}"
            tool_name = operation.get("operationId", fallback_id).replace(" ", "_").lower()

            # Prefer the short summary; fall back to description, then method+path.
            description = operation.get(
                "summary", operation.get("description", f"{http_method.upper()} {route}")
            )

            # Create the async handler and make it introspectable.
            handler = create_tool_function(route, http_method, operation, base_url)
            handler.__name__ = tool_name
            handler.__doc__ = description

            # Register tool with the local registry.
            global_mcp_tool_registry.register_tool(
                name=tool_name,
                description=description,
                input_schema=build_input_schema(operation),
                handler=handler,
            )
            verbose_logger.debug(f"Registered tool: {tool_name}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
Semantic MCP Tool Filtering using semantic-router
|
||||
|
||||
Filters MCP tools semantically for /chat/completions and /responses endpoints.
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from semantic_router.routers import SemanticRouter
|
||||
|
||||
from litellm.router import Router
|
||||
|
||||
|
||||
class SemanticMCPToolFilter:
    """Filters MCP tools using semantic similarity to reduce context window size."""

    def __init__(
        self,
        embedding_model: str,
        litellm_router_instance: "Router",
        top_k: int = 10,
        similarity_threshold: float = 0.3,
        enabled: bool = True,
    ):
        """
        Initialize the semantic tool filter.

        Args:
            embedding_model: Model to use for embeddings (e.g., "text-embedding-3-small")
            litellm_router_instance: Router instance for embedding generation
            top_k: Maximum number of tools to return
            similarity_threshold: Minimum similarity score for filtering
            enabled: Whether filtering is enabled
        """
        self.enabled = enabled
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        self.router_instance = litellm_router_instance
        # Built later by build_router_from_mcp_registry(); None means "not built".
        self.tool_router: Optional["SemanticRouter"] = None
        self._tool_map: Dict[str, Any] = {}  # MCPTool objects or OpenAI function dicts

    async def build_router_from_mcp_registry(self) -> None:
        """Build semantic router from all MCP tools in the registry (no auth checks)."""
        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
            global_mcp_server_manager,
        )

        try:
            # Get all servers from registry without auth checks
            registry = global_mcp_server_manager.get_registry()
            if not registry:
                verbose_logger.warning("MCP registry is empty")
                self.tool_router = None
                return

            # Fetch tools from each server sequentially; a failing server is
            # logged and skipped so one bad server does not block the rest.
            all_tools = []
            for server_id, server in registry.items():
                try:
                    tools = await global_mcp_server_manager.get_tools_for_server(
                        server_id
                    )
                    all_tools.extend(tools)
                except Exception as e:
                    verbose_logger.warning(
                        f"Failed to fetch tools from server {server_id}: {e}"
                    )
                    continue

            if not all_tools:
                verbose_logger.warning("No MCP tools found in registry")
                self.tool_router = None
                return

            verbose_logger.info(
                f"Fetched {len(all_tools)} tools from {len(registry)} MCP servers"
            )
            self._build_router(all_tools)

        except Exception as e:
            # Leave the router unset and propagate so startup can fail loudly.
            verbose_logger.error(f"Failed to build router from MCP registry: {e}")
            self.tool_router = None
            raise

    def _extract_tool_info(self, tool) -> tuple[str, str]:
        """Extract name and description from MCP tool or OpenAI function dict."""
        name: str
        description: str

        if isinstance(tool, dict):
            # OpenAI function format
            name = tool.get("name", "")
            description = tool.get("description", name)
        else:
            # MCPTool object
            name = str(tool.name)
            description = str(tool.description) if tool.description else str(tool.name)

        return name, description

    def _build_router(self, tools: List) -> None:
        """Build semantic router with tools (MCPTool objects or OpenAI function dicts)."""
        from semantic_router.routers import SemanticRouter
        from semantic_router.routers.base import Route

        from litellm.router_strategy.auto_router.litellm_encoder import (
            LiteLLMRouterEncoder,
        )

        if not tools:
            self.tool_router = None
            return

        try:
            # Convert tools to routes: one Route per tool, with the tool's
            # description used as the single utterance to embed and match.
            routes = []
            self._tool_map = {}

            for tool in tools:
                name, description = self._extract_tool_info(tool)
                self._tool_map[name] = tool

                routes.append(
                    Route(
                        name=name,
                        description=description,
                        utterances=[description],
                        score_threshold=self.similarity_threshold,
                    )
                )

            self.tool_router = SemanticRouter(
                routes=routes,
                encoder=LiteLLMRouterEncoder(
                    litellm_router_instance=self.router_instance,
                    model_name=self.embedding_model,
                    score_threshold=self.similarity_threshold,
                ),
                auto_sync="local",
            )

            verbose_logger.info(f"Built semantic router with {len(routes)} tools")

        except Exception as e:
            verbose_logger.error(f"Failed to build semantic router: {e}")
            self.tool_router = None
            raise

    async def filter_tools(
        self,
        query: str,
        available_tools: List[Any],
        top_k: Optional[int] = None,
    ) -> List[Any]:
        """
        Filter tools semantically based on query.

        Args:
            query: User query to match against tools
            available_tools: Full list of available MCP tools
            top_k: Override default top_k (optional)

        Returns:
            Filtered and ordered list of tools (up to top_k)
        """
        # Early returns for cases where we can't/shouldn't filter —
        # every fallback path returns the unfiltered tool list.
        if not self.enabled:
            return available_tools

        if not available_tools:
            return available_tools

        if not query or not query.strip():
            return available_tools

        # Router should be built on startup - if not, something went wrong
        if self.tool_router is None:
            verbose_logger.warning(
                "Router not initialized - was build_router_from_mcp_registry() called on startup?"
            )
            return available_tools

        # Run semantic filtering
        try:
            limit = top_k or self.top_k
            # NOTE(review): call signature assumes semantic-router accepts a
            # `limit` kwarg on __call__ — verify against the installed version.
            matches = self.tool_router(text=query, limit=limit)
            matched_tool_names = self._extract_tool_names_from_matches(matches)

            if not matched_tool_names:
                return available_tools

            return self._get_tools_by_names(matched_tool_names, available_tools)

        except Exception as e:
            # Filtering is best-effort: on any failure, fall back to all tools.
            verbose_logger.error(f"Semantic tool filter failed: {e}", exc_info=True)
            return available_tools

    def _extract_tool_names_from_matches(self, matches) -> List[str]:
        """Extract tool names from semantic router match results."""
        if not matches:
            return []

        # Handle single match
        if hasattr(matches, "name") and matches.name:
            return [matches.name]

        # Handle list of matches
        if isinstance(matches, list):
            return [m.name for m in matches if hasattr(m, "name") and m.name]

        return []

    def _get_tools_by_names(
        self, tool_names: List[str], available_tools: List[Any]
    ) -> List[Any]:
        """Get tools from available_tools by their names, preserving order."""
        # Match tools from available_tools (preserves format - dict or MCPTool)
        matched_tools = []
        for tool in available_tools:
            tool_name, _ = self._extract_tool_info(tool)
            if tool_name in tool_names:
                matched_tools.append(tool)

        # Reorder to match semantic router's ordering
        tool_map = {self._extract_tool_info(t)[0]: t for t in matched_tools}
        return [tool_map[name] for name in tool_names if name in tool_map]

    def extract_user_query(self, messages: List[Dict[str, Any]]) -> str:
        """
        Extract user query from messages for /chat/completions or /responses.

        Args:
            messages: List of message dictionaries (from 'messages' or 'input' field)

        Returns:
            Extracted query string
        """
        # Walk messages newest-first and use the most recent user message.
        for msg in reversed(messages):
            if msg.get("role") == "user":
                content = msg.get("content", "")

                if isinstance(content, str):
                    return content

                # Multi-part content: concatenate the textual parts.
                if isinstance(content, list):
                    texts = [
                        block.get("text", "") if isinstance(block, dict) else str(block)
                        for block in content
                    if isinstance(block, (dict, str))
                    ]
                    return " ".join(texts)

        return ""
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
This is a modification of code from: https://github.com/SecretiveShell/MCP-Bridge/blob/master/mcp_bridge/mcp_server/sse_transport.py
|
||||
|
||||
Credit to the maintainers of SecretiveShell for their SSE Transport implementation
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import anyio
|
||||
import mcp.types as types
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import ValidationError
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class SseServerTransport:
    """
    SSE server transport for MCP. This class provides _two_ ASGI applications,
    suitable to be used with a framework like Starlette and a server like Hypercorn:

    1. connect_sse() is an ASGI application which receives incoming GET requests,
    and sets up a new SSE stream to send server messages to the client.
    2. handle_post_message() is an ASGI application which receives incoming POST
    requests, which should contain client messages that link to a
    previously-established SSE session.
    """

    # URL clients must POST follow-up messages to (relative or absolute).
    _endpoint: str
    # Maps each live session id to the stream used to forward incoming client
    # messages into that session's reader.
    _read_stream_writers: dict[
        UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
    ]

    def __init__(self, endpoint: str) -> None:
        """
        Creates a new SSE server transport, which will direct the client to POST
        messages to the relative or absolute URL given.
        """

        super().__init__()
        self._endpoint = endpoint
        self._read_stream_writers = {}
        verbose_logger.debug(
            f"SseServerTransport initialized with endpoint: {endpoint}"
        )

    @asynccontextmanager
    async def connect_sse(self, request: Request):
        # Only plain HTTP requests can carry an SSE stream.
        if request.scope["type"] != "http":
            verbose_logger.error("connect_sse received non-HTTP request")
            raise ValueError("connect_sse can only handle HTTP requests")

        verbose_logger.debug("Setting up SSE connection")
        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]

        write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]

        # Zero-capacity streams: each send blocks until a receiver is ready.
        read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
        write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

        session_id = uuid4()
        # Clients POST their messages back to this session-scoped URI.
        session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
        self._read_stream_writers[session_id] = read_stream_writer
        verbose_logger.debug(f"Created new session with ID: {session_id}")

        sse_stream_writer: MemoryObjectSendStream[dict[str, Any]]
        sse_stream_reader: MemoryObjectReceiveStream[dict[str, Any]]
        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(
            0, dict[str, Any]
        )

        async def sse_writer():
            # Pumps outgoing JSON-RPC messages into the SSE event stream.
            verbose_logger.debug("Starting SSE writer")
            async with sse_stream_writer, write_stream_reader:
                # First event tells the client where to POST its messages.
                await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
                verbose_logger.debug(f"Sent endpoint event: {session_uri}")

                async for message in write_stream_reader:
                    verbose_logger.debug(f"Sending message via SSE: {message}")
                    await sse_stream_writer.send(
                        {
                            "event": "message",
                            "data": message.model_dump_json(
                                by_alias=True, exclude_none=True
                            ),
                        }
                    )

        async with anyio.create_task_group() as tg:
            response = EventSourceResponse(
                content=sse_stream_reader, data_sender_callable=sse_writer
            )
            verbose_logger.debug("Starting SSE response task")
            # NOTE(review): request._send is a private Starlette attribute —
            # this depends on Starlette internals staying stable.
            tg.start_soon(response, request.scope, request.receive, request._send)

            verbose_logger.debug("Yielding read and write streams")
            yield (read_stream, write_stream)

    async def handle_post_message(
        self, scope: Scope, receive: Receive, send: Send
    ) -> Response:
        """Accept a POSTed client message and forward it to its SSE session.

        Returns 400 for a missing/invalid session id or unparseable message,
        404 when the session does not exist, 202 on success.
        """
        verbose_logger.debug("Handling POST message")
        request = Request(scope, receive)

        session_id_param = request.query_params.get("session_id")
        if session_id_param is None:
            verbose_logger.warning("Received request without session_id")
            response = Response("session_id is required", status_code=400)
            return response

        try:
            session_id = UUID(hex=session_id_param)
            verbose_logger.debug(f"Parsed session ID: {session_id}")
        except ValueError:
            verbose_logger.warning(f"Received invalid session ID: {session_id_param}")
            response = Response("Invalid session ID", status_code=400)
            return response

        writer = self._read_stream_writers.get(session_id)
        if not writer:
            verbose_logger.warning(f"Could not find session for ID: {session_id}")
            response = Response("Could not find session", status_code=404)
            return response

        json = await request.json()
        verbose_logger.debug(f"Received JSON: {json}")

        try:
            message = types.JSONRPCMessage.model_validate(json)
            verbose_logger.debug(f"Validated client message: {message}")
        except ValidationError as err:
            verbose_logger.error(f"Failed to parse message: {err}")
            response = Response("Could not parse message", status_code=400)
            # Forward the validation error into the session so the server side
            # of the stream sees the failure too.
            await writer.send(err)
            return response

        verbose_logger.debug(f"Sending message to writer: {message}")
        response = Response("Accepted", status_code=202)
        await writer.send(message)
        return response
|
||||
@@ -0,0 +1,133 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy.types_utils.utils import get_instance_fn
|
||||
from litellm.types.mcp_server.tool_registry import MCPTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.types import Tool as MCPToolSDKTool
|
||||
else:
|
||||
try:
|
||||
from mcp.types import Tool as MCPToolSDKTool
|
||||
except ImportError:
|
||||
MCPToolSDKTool = None # type: ignore
|
||||
|
||||
|
||||
class MCPToolRegistry:
    """
    A registry for managing MCP tools
    """

    def __init__(self):
        # Registry to store all registered tools, keyed by tool name.
        self.tools: Dict[str, MCPTool] = {}

    def register_tool(
        self,
        name: str,
        description: str,
        input_schema: Dict[str, Any],
        handler: Callable,
    ) -> None:
        """
        Register a new tool in the registry.

        An existing tool with the same name is silently replaced.
        """
        self.tools[name] = MCPTool(
            name=name,
            description=description,
            input_schema=input_schema,
            handler=handler,
        )
        verbose_logger.debug(f"Registered tool: {name}")

    def get_tool(self, name: str) -> Optional[MCPTool]:
        """
        Get a tool from the registry by name. Returns None when not found.
        """
        return self.tools.get(name)

    def list_tools(self, tool_prefix: Optional[str] = None) -> List[MCPTool]:
        """
        List all registered tools, optionally restricted to names starting
        with ``tool_prefix``.
        """
        if tool_prefix:
            return [
                tool
                for tool in self.tools.values()
                if tool.name.startswith(tool_prefix)
            ]
        return list(self.tools.values())

    def convert_tools_to_mcp_sdk_tool_type(
        self, tools: List[MCPTool]
    ) -> List["MCPToolSDKTool"]:
        """Convert internal MCPTool records to MCP SDK ``Tool`` objects."""
        # The SDK import is optional at module level; fail loudly here if absent.
        if MCPToolSDKTool is None:
            raise ImportError(
                "MCP SDK is not installed. Please install it with: pip install 'litellm[proxy]'"
            )
        return [
            MCPToolSDKTool(
                name=tool.name,
                description=tool.description,
                inputSchema=tool.input_schema,
            )
            for tool in tools
        ]

    def load_tools_from_config(
        self, mcp_tools_config: Optional[List[Dict[str, Any]]] = None
    ) -> None:
        """
        Load and register tools from the proxy config

        Args:
            mcp_tools_config: The mcp_tools config from the proxy config —
                a list of dicts with ``name``, ``description``,
                ``input_schema`` (optional) and ``handler`` keys.

        Raises:
            ValueError: if the config is missing or malformed.
        """
        if mcp_tools_config is None:
            raise ValueError(
                "mcp_tools_config is required, please set `mcp_tools` in your proxy config"
            )

        for tool_config in mcp_tools_config:
            if not isinstance(tool_config, dict):
                raise ValueError("mcp_tools_config must be a list of dictionaries")

            name = tool_config.get("name")
            description = tool_config.get("description")
            input_schema = tool_config.get("input_schema", {})
            handler_name = tool_config.get("handler")

            # Entries missing any required field are skipped silently.
            if not all([name, description, handler_name]):
                continue

            # Try to resolve the handler
            # First check if it's a module path (e.g., "module.submodule.function")
            # NOTE(review): unreachable — handler_name falsiness was already
            # filtered by the all([...]) check above.
            if handler_name is None:
                raise ValueError(f"handler is required for tool {name}")
            handler = get_instance_fn(handler_name)

            if handler is None:
                verbose_logger.warning(
                    f"Warning: Could not find handler {handler_name} for tool {name}"
                )
                continue

            # Register the tool
            # NOTE(review): these two guards are also unreachable for the same
            # reason as above; kept for defensive parity.
            if name is None:
                raise ValueError(f"name is required for tool {name}")
            if description is None:
                raise ValueError(f"description is required for tool {name}")

            self.register_tool(
                name=name,
                description=description,
                input_schema=input_schema,
                handler=handler,
            )
        verbose_logger.debug(
            "all registered tools: %s", json.dumps(self.tools, indent=4, default=str)
        )
|
||||
|
||||
|
||||
global_mcp_tool_registry = MCPToolRegistry()
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Helpers to resolve real team contexts for UI session tokens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
def clone_user_api_key_auth_with_team(
    user_api_key_auth: UserAPIKeyAuth,
    team_id: str,
) -> UserAPIKeyAuth:
    """Duplicate the auth context, swapping in a different team id."""

    try:
        duplicated = user_api_key_auth.model_copy()
    except AttributeError:
        # Older pydantic versions expose ``copy`` rather than ``model_copy``.
        duplicated = user_api_key_auth.copy()  # type: ignore[attr-defined]
    duplicated.team_id = team_id
    return duplicated
|
||||
|
||||
|
||||
async def resolve_ui_session_team_ids(
    user_api_key_auth: UserAPIKeyAuth,
) -> List[str]:
    """Resolve the real team ids backing a UI session token.

    Returns an ordered, de-duplicated list of the user's team ids, or an
    empty list when the token is not a UI session token, the DB is not
    configured, or the user/teams cannot be loaded.
    """

    # Only UI session tokens (sentinel team id) with a known user qualify.
    if (
        user_api_key_auth.team_id != UI_SESSION_TOKEN_TEAM_ID
        or not user_api_key_auth.user_id
    ):
        return []

    # Imported lazily to avoid circular imports with proxy_server.
    from litellm.proxy.auth.auth_checks import get_user_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    if prisma_client is None:
        verbose_logger.debug("Cannot resolve UI session team ids without DB access")
        return []

    try:
        user_obj = await get_user_object(
            user_id=user_api_key_auth.user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            user_id_upsert=False,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
    except Exception as exc:  # pragma: no cover - defensive logging
        # BUGFIX: the previous message had no %s placeholder, so passing
        # ``exc`` as a format arg raised a logging formatting error and the
        # exception detail was never emitted.
        verbose_logger.warning(
            "Failed to load teams for UI session token user: %s", exc
        )
        return []

    if user_obj is None or not user_obj.teams:
        return []

    # De-duplicate while preserving the order teams appear on the user object.
    resolved_team_ids: List[str] = []
    for team_id in user_obj.teams:
        if team_id and team_id not in resolved_team_ids:
            resolved_team_ids.append(team_id)
    return resolved_team_ids
|
||||
|
||||
|
||||
async def build_effective_auth_contexts(
    user_api_key_auth: UserAPIKeyAuth,
) -> List[UserAPIKeyAuth]:
    """Return auth contexts that reflect the actual teams for UI session tokens."""

    team_ids = await resolve_ui_session_team_ids(user_api_key_auth)
    if not team_ids:
        # Not a UI session token (or no teams found): keep the original context.
        return [user_api_key_auth]

    contexts: List[UserAPIKeyAuth] = []
    for team_id in team_ids:
        contexts.append(clone_user_api_key_auth_with_team(user_api_key_auth, team_id))
    return contexts
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
MCP Server Utilities
|
||||
"""
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple
|
||||
|
||||
import os
|
||||
import importlib
|
||||
|
||||
# Constants
|
||||
LITELLM_MCP_SERVER_NAME = "litellm-mcp-server"
|
||||
LITELLM_MCP_SERVER_VERSION = "1.0.0"
|
||||
LITELLM_MCP_SERVER_DESCRIPTION = "MCP Server for LiteLLM"
|
||||
MCP_TOOL_PREFIX_SEPARATOR = os.environ.get("MCP_TOOL_PREFIX_SEPARATOR", "-")
|
||||
MCP_TOOL_PREFIX_FORMAT = "{server_name}{separator}{tool_name}"
|
||||
|
||||
|
||||
def is_mcp_available() -> bool:
    """Return True when the optional ``mcp`` package can be imported."""
    try:
        importlib.import_module("mcp")
    except ImportError:
        return False
    return True
|
||||
|
||||
|
||||
def normalize_server_name(server_name: str) -> str:
    """Return *server_name* with every space character turned into an underscore."""
    return "_".join(server_name.split(" "))
|
||||
|
||||
|
||||
def validate_and_normalize_mcp_server_payload(payload: Any) -> None:
    """
    Validate and normalize MCP server payload fields (server_name and alias).

    Steps:
    1. Reject a server_name or alias containing MCP_TOOL_PREFIX_SEPARATOR.
    2. Normalize the alias (spaces -> underscores).
    3. Default the alias from server_name when no alias was supplied.

    Args:
        payload: Object carrying optional ``server_name`` / ``alias`` attributes.

    Raises:
        HTTPException: If validation fails.
    """
    server_name = getattr(payload, "server_name", None)
    alias = getattr(payload, "alias", None)

    # Separator is reserved for tool-name prefixing, so neither field may use it.
    if server_name:
        validate_mcp_server_name(server_name, raise_http_exception=True)
    if alias:
        validate_mcp_server_name(alias, raise_http_exception=True)

    # Normalize the alias, falling back to the server name when absent.
    if alias:
        alias = normalize_server_name(alias)
    elif server_name:
        alias = normalize_server_name(server_name)

    # Write the normalized alias back only when the payload has the field.
    if hasattr(payload, "alias"):
        payload.alias = alias
|
||||
|
||||
|
||||
def add_server_prefix_to_name(name: str, server_name: str) -> str:
    """Add server name prefix to any MCP resource name."""
    # Normalize the server name first so prefixes never contain spaces.
    return MCP_TOOL_PREFIX_FORMAT.format(
        server_name=normalize_server_name(server_name),
        separator=MCP_TOOL_PREFIX_SEPARATOR,
        tool_name=name,
    )
|
||||
|
||||
|
||||
def get_server_prefix(server: Any) -> str:
    """Return the prefix for a server: alias if present, else server_name, else server_id"""
    # Prefer the first truthy of alias / server_name.
    for attribute in ("alias", "server_name"):
        value = getattr(server, attribute, None)
        if value:
            return value
    # server_id is returned whenever the attribute exists, even if falsy,
    # matching the original hasattr-based behavior.
    if hasattr(server, "server_id"):
        return server.server_id
    return ""
|
||||
|
||||
|
||||
def split_server_prefix_from_name(prefixed_name: str) -> Tuple[str, str]:
    """Return the unprefixed name plus the server name used as prefix."""
    # partition() splits on the first separator only; any later separators
    # remain part of the tool name, matching split(sep, 1) semantics.
    server_part, found_separator, tool_part = prefixed_name.partition(
        MCP_TOOL_PREFIX_SEPARATOR
    )
    if found_separator:
        return tool_part, server_part
    return prefixed_name, ""
|
||||
|
||||
|
||||
def is_tool_name_prefixed(tool_name: str) -> bool:
    """
    Check if tool name has server prefix

    Args:
        tool_name: Tool name to check

    Returns:
        True if tool name is prefixed, False otherwise
    """
    return tool_name.find(MCP_TOOL_PREFIX_SEPARATOR) != -1
|
||||
|
||||
|
||||
def validate_mcp_server_name(
    server_name: str, raise_http_exception: bool = False
) -> None:
    """
    Validate that MCP server name does not contain 'MCP_TOOL_PREFIX_SEPARATOR'.

    Args:
        server_name: The server name to validate
        raise_http_exception: If True, raises HTTPException instead of generic Exception

    Raises:
        Exception or HTTPException: If server name contains 'MCP_TOOL_PREFIX_SEPARATOR'
    """
    # Empty names and names without the separator are fine.
    if not server_name or MCP_TOOL_PREFIX_SEPARATOR not in server_name:
        return

    error_message = f"Server name cannot contain '{MCP_TOOL_PREFIX_SEPARATOR}'. Use an alternative character instead Found: {server_name}"
    if not raise_http_exception:
        raise Exception(error_message)

    # Imported lazily so non-HTTP callers never need fastapi.
    from fastapi import HTTPException
    from starlette import status

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST, detail={"error": error_message}
    )
|
||||
|
||||
|
||||
def merge_mcp_headers(
    *,
    extra_headers: Optional[Mapping[str, str]] = None,
    static_headers: Optional[Mapping[str, str]] = None,
) -> Optional[Dict[str, str]]:
    """Merge outbound HTTP headers for MCP calls.

    This is used when calling out to external MCP servers (or OpenAPI-based
    MCP tools). ``extra_headers`` (typically OAuth2-derived) are applied
    first, then ``static_headers`` (user-configured per MCP server) overlay
    them — so on a key collision ``static_headers`` wins, matching the
    ``MCPServerManager`` behavior of applying ``server.static_headers`` last.
    Keys and values are coerced to ``str``. Returns None when nothing merged.
    """
    combined: Dict[str, str] = {}
    for source in (extra_headers, static_headers):
        if not source:
            continue
        for key, value in source.items():
            combined[str(key)] = str(value)

    if combined:
        return combined
    return None
|
||||
@@ -0,0 +1,4 @@
|
||||
def my_custom_rule(input):  # receives the model response
    # Example post-call rule template. As shipped it unconditionally returns
    # False; the trailing ``return True`` only becomes reachable once the
    # length check below is uncommented (reject short responses, pass others).
    # if len(input) < 5: # trigger fallback if the model response is too short
    return False
    return True
|
||||
@@ -0,0 +1,40 @@
|
||||
### DEPRECATED ###
|
||||
## unused file. initially written for json logging on proxy.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from logging import Formatter
|
||||
|
||||
from litellm import json_logs
|
||||
|
||||
# Set default log level to INFO
log_level = os.getenv("LITELLM_LOG", "INFO")
# getattr(logging, "INFO") etc. returns the *numeric* level constant, so
# this is an int (the original ``str`` annotation was incorrect).
numeric_level: int = getattr(logging, log_level.upper())
|
||||
|
||||
|
||||
class JsonFormatter(Formatter):
    """Render each log record as a single-line JSON object with
    ``message``, ``level`` and ``timestamp`` fields."""

    def __init__(self):
        super().__init__()

    def format(self, record):
        # Build the payload from the fully-formatted message, the level name,
        # and the formatter's (default) timestamp rendering.
        payload = {
            "message": record.getMessage(),
            "level": record.levelname,
            "timestamp": self.formatTime(record, self.datefmt),
        }
        return json.dumps(payload)
|
||||
|
||||
|
||||
# Attach a single stream handler to the root logger so every module's logs
# share one output format.
logger = logging.root
handler = logging.StreamHandler()
if json_logs:
    # Structured one-line JSON output (e.g. for log aggregators).
    handler.setFormatter(JsonFormatter())
else:
    # Human-readable, ANSI-colored output for terminals.
    formatter = logging.Formatter(
        "\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
        datefmt="%H:%M:%S",
    )

    handler.setFormatter(formatter)
# Replace (not append to) any existing handlers to avoid duplicate lines.
logger.handlers = [handler]
logger.setLevel(numeric_level)
|
||||
@@ -0,0 +1,14 @@
|
||||
model_list:
|
||||
- model_name: bedrock-claude
|
||||
litellm_params:
|
||||
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
|
||||
aws_region_name: us-east-1
|
||||
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
|
||||
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["datadog"] # logs llm success + failure logs on datadog
|
||||
service_callback: ["datadog"] # logs redis, postgres failures on datadog
|
||||
|
||||
general_settings:
|
||||
store_prompts_in_spend_logs: true
|
||||
@@ -0,0 +1,41 @@
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: openai/gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: claude-sonnet-4-5-20250929
|
||||
litellm_params:
|
||||
model: anthropic/claude-sonnet-4-5-20250929
|
||||
- model_name: gpt-4.1-mini
|
||||
litellm_params:
|
||||
model: openai/gpt-4.1-mini
|
||||
- model_name: gpt-5-mini
|
||||
litellm_params:
|
||||
model: openai/gpt-5-mini
|
||||
- model_name: custom_litellm_model
|
||||
litellm_params:
|
||||
model: litellm_agent/claude-sonnet-4-5-20250929
|
||||
litellm_system_prompt: "Be a helpful assistant."
|
||||
|
||||
|
||||
guardrails:
|
||||
- guardrail_name: "tool_policy"
|
||||
litellm_params:
|
||||
guardrail: tool_policy
|
||||
mode: [pre_call, post_call]
|
||||
default_on: true
|
||||
|
||||
mcp_servers:
|
||||
my_http_server:
|
||||
url: "http://0.0.0.0:8001/mcp"
|
||||
transport: "http"
|
||||
description: "My custom MCP server"
|
||||
available_on_public_internet: true
|
||||
|
||||
general_settings:
|
||||
store_model_in_db: true
|
||||
store_prompts_in_spend_logs: true
|
||||
@@ -0,0 +1,110 @@
|
||||
model_list:
|
||||
- model_name: claude-3-5-sonnet
|
||||
litellm_params:
|
||||
model: claude-3-haiku-20240307
|
||||
# - model_name: gemini-1.5-flash-gemini
|
||||
# litellm_params:
|
||||
# model: vertex_ai_beta/gemini-1.5-flash
|
||||
# api_base: https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash
|
||||
- litellm_params:
|
||||
api_base: http://0.0.0.0:8080
|
||||
api_key: ''
|
||||
model: gpt-4o
|
||||
rpm: 800
|
||||
input_cost_per_token: 300
|
||||
model_name: gpt-4o
|
||||
- model_name: llama3-70b-8192
|
||||
litellm_params:
|
||||
model: groq/llama3-70b-8192
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: predibase/llama-3-8b-instruct
|
||||
api_key: os.environ/PREDIBASE_API_KEY
|
||||
tenant_id: os.environ/PREDIBASE_TENANT_ID
|
||||
max_new_tokens: 256
|
||||
# - litellm_params:
|
||||
# api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
# api_key: os.environ/AZURE_EUROPE_API_KEY
|
||||
# model: azure/gpt-35-turbo
|
||||
# rpm: 10
|
||||
# model_name: gpt-3.5-turbo-fake-model
|
||||
- litellm_params:
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2024-02-15-preview
|
||||
model: azure/chatgpt-v-2
|
||||
tpm: 100
|
||||
model_name: gpt-3.5-turbo
|
||||
- litellm_params:
|
||||
model: anthropic.claude-3-sonnet-20240229-v1:0
|
||||
model_name: bedrock-anthropic-claude-3
|
||||
- litellm_params:
|
||||
model: claude-3-haiku-20240307
|
||||
model_name: anthropic-claude-3
|
||||
- litellm_params:
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2024-02-15-preview
|
||||
model: azure/chatgpt-v-2
|
||||
drop_params: True
|
||||
tpm: 100
|
||||
model_name: gpt-3.5-turbo
|
||||
- model_name: tts
|
||||
litellm_params:
|
||||
model: openai/tts-1
|
||||
- model_name: gpt-4-turbo-preview
|
||||
litellm_params:
|
||||
api_base: https://openai-france-1234.openai.azure.com
|
||||
api_key: os.environ/AZURE_FRANCE_API_KEY
|
||||
api_version: 2024-02-15-preview
|
||||
model: azure/gpt-turbo
|
||||
- model_name: text-embedding
|
||||
litellm_params:
|
||||
model: textembedding-gecko-multilingual@001
|
||||
vertex_project: my-project-9d5c
|
||||
vertex_location: us-central1
|
||||
- model_name: lbl/command-r-plus
|
||||
litellm_params:
|
||||
model: openai/lbl/command-r-plus
|
||||
api_key: "os.environ/VLLM_API_KEY"
|
||||
api_base: http://vllm-command:8000/v1
|
||||
rpm: 1000
|
||||
input_cost_per_token: 0
|
||||
output_cost_per_token: 0
|
||||
model_info:
|
||||
max_input_tokens: 80920
|
||||
|
||||
# litellm_settings:
|
||||
# callbacks: ["dynamic_rate_limiter"]
|
||||
# # success_callback: ["langfuse"]
|
||||
# # failure_callback: ["langfuse"]
|
||||
# # default_team_settings:
|
||||
# # - team_id: proj1
|
||||
# # success_callback: ["langfuse"]
|
||||
# # langfuse_public_key: os.environ/LANGFUSE_PUBLIC_KEY
|
||||
# # langfuse_secret: os.environ/LANGFUSE_SECRET
|
||||
# # langfuse_host: https://us.cloud.langfuse.com
|
||||
# # - team_id: proj2
|
||||
# # success_callback: ["langfuse"]
|
||||
# # langfuse_public_key: os.environ/LANGFUSE_PUBLIC_KEY
|
||||
# # langfuse_secret: os.environ/LANGFUSE_SECRET
|
||||
# # langfuse_host: https://us.cloud.langfuse.com
|
||||
|
||||
assistant_settings:
|
||||
custom_llm_provider: openai
|
||||
litellm_params:
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
|
||||
router_settings:
|
||||
enable_pre_call_checks: true
|
||||
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["s3"]
|
||||
|
||||
# general_settings:
|
||||
# # alerting: ["slack"]
|
||||
# enable_jwt_auth: True
|
||||
# litellm_jwtauth:
|
||||
# team_id_jwt_field: "client_id"
|
||||
4459
llm-gateway-competitors/litellm-wheel-src/litellm/proxy/_types.py
Normal file
4459
llm-gateway-competitors/litellm-wheel-src/litellm/proxy/_types.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
A2A Protocol endpoints for LiteLLM Proxy.
|
||||
|
||||
Allows clients to invoke agents through LiteLLM using the A2A protocol.
|
||||
The A2A SDK can point to LiteLLM's URL and invoke agents registered with LiteLLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.agent_endpoints.utils import merge_agent_headers
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.utils import all_litellm_params
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _jsonrpc_error(
    request_id: Optional[str],
    code: int,
    message: str,
    status_code: int = 400,
) -> JSONResponse:
    """Build a JSON-RPC 2.0 error envelope and wrap it in a FastAPI JSONResponse."""
    payload = {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {"code": code, "message": message},
    }
    return JSONResponse(content=payload, status_code=status_code)
|
||||
|
||||
|
||||
def _get_agent(agent_id: str):
    """Resolve an agent by ID, falling back to name lookup; None if neither matches."""
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry

    match = global_agent_registry.get_agent_by_id(agent_id=agent_id)
    if match is not None:
        return match
    # The path parameter may be an agent *name* rather than an ID.
    return global_agent_registry.get_agent_by_name(agent_name=agent_id)
|
||||
|
||||
|
||||
def _enforce_inbound_trace_id(agent: Any, request: Request) -> None:
    """Reject (HTTP 400) an inbound call that lacks x-litellm-trace-id when the agent requires it."""
    params_for_agent = agent.litellm_params or {}
    if not params_for_agent.get("require_trace_id_on_calls_to_agent"):
        # This agent does not enforce trace IDs — nothing to check.
        return

    from litellm.proxy.litellm_pre_call_utils import get_chain_id_from_headers

    inbound_headers = dict(request.headers)
    if not get_chain_id_from_headers(inbound_headers):
        raise HTTPException(
            status_code=400,
            detail=(
                f"Agent '{agent.agent_id}' requires x-litellm-trace-id header "
                "on all inbound requests."
            ),
        )
|
||||
|
||||
|
||||
async def _handle_stream_message(
    api_base: Optional[str],
    request_id: str,
    params: dict,
    litellm_params: Optional[dict] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[dict] = None,
    proxy_server_request: Optional[dict] = None,
    *,
    agent_extra_headers: Optional[Dict[str, str]] = None,
    user_api_key_dict: Optional[UserAPIKeyAuth] = None,
    request_data: Optional[dict] = None,
    proxy_logging_obj: Optional[Any] = None,
) -> StreamingResponse:
    """Handle the A2A message/stream method via SDK functions.

    Streams the backend agent's response as NDJSON. When user_api_key_dict,
    request_data, and proxy_logging_obj are all provided, chunks are routed
    through common_request_processing.async_streaming_data_generator with
    NDJSON serializers so proxy hooks and cost injection apply; otherwise
    chunks are serialized and yielded directly.

    Errors (missing SDK, streaming failures) are emitted as JSON-RPC 2.0
    error objects (code -32603) on the NDJSON stream rather than raised,
    except HTTPException which propagates.
    """
    # Deferred imports: keep module import cheap and avoid hard dependency
    # on the optional 'a2a' package at import time.
    from litellm.a2a_protocol import asend_message_streaming
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE

    if not A2A_SDK_AVAILABLE:

        # SDK missing: emit a single JSON-RPC error line, still as a stream,
        # so the client's streaming code path behaves uniformly.
        async def _error_stream():
            yield json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32603,
                        "message": "Server error: 'a2a' package not installed",
                    },
                }
            ) + "\n"

        return StreamingResponse(_error_stream(), media_type="application/x-ndjson")

    from a2a.types import MessageSendParams, SendStreamingMessageRequest

    # Proxy hooks (cost tracking, failure callbacks) apply only when the
    # caller supplied all three proxy context objects.
    use_proxy_hooks = (
        user_api_key_dict is not None
        and request_data is not None
        and proxy_logging_obj is not None
    )

    async def stream_response():
        try:
            a2a_request = SendStreamingMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            a2a_stream = asend_message_streaming(
                request=a2a_request,
                api_base=api_base,
                litellm_params=litellm_params,
                agent_id=agent_id,
                metadata=metadata,
                proxy_server_request=proxy_server_request,
                agent_extra_headers=agent_extra_headers,
            )

            # Re-checking each object (not just use_proxy_hooks) keeps the
            # type-checker's narrowing happy inside this branch.
            if (
                use_proxy_hooks
                and user_api_key_dict is not None
                and request_data is not None
                and proxy_logging_obj is not None
            ):
                from litellm.proxy.common_request_processing import (
                    ProxyBaseLLMRequestProcessing,
                )

                # Serialize one stream chunk as a single NDJSON line.
                def _ndjson_chunk(chunk: Any) -> str:
                    if hasattr(chunk, "model_dump"):
                        obj = chunk.model_dump(mode="json", exclude_none=True)
                    else:
                        obj = chunk
                    return json.dumps(obj) + "\n"

                # Serialize a proxy-raised exception as a JSON-RPC error line.
                def _ndjson_error(proxy_exc: Any) -> str:
                    return (
                        json.dumps(
                            {
                                "jsonrpc": "2.0",
                                "id": request_id,
                                "error": {
                                    "code": -32603,
                                    "message": getattr(
                                        proxy_exc,
                                        "message",
                                        f"Streaming error: {proxy_exc!s}",
                                    ),
                                },
                            }
                        )
                        + "\n"
                    )

                async for (
                    line
                ) in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                    response=a2a_stream,
                    user_api_key_dict=user_api_key_dict,
                    request_data=request_data,
                    proxy_logging_obj=proxy_logging_obj,
                    serialize_chunk=_ndjson_chunk,
                    serialize_error=_ndjson_error,
                ):
                    yield line
            else:
                # No proxy context: bypass hooks and serialize chunks directly.
                async for chunk in a2a_stream:
                    if hasattr(chunk, "model_dump"):
                        yield json.dumps(
                            chunk.model_dump(mode="json", exclude_none=True)
                        ) + "\n"
                    else:
                        yield json.dumps(chunk) + "\n"
        except Exception as e:
            verbose_proxy_logger.exception(f"Error streaming A2A response: {e}")
            # Give the proxy's failure hook a chance to transform the
            # exception (e.g. into a user-facing error) before emitting it.
            if (
                use_proxy_hooks
                and proxy_logging_obj is not None
                and user_api_key_dict is not None
                and request_data is not None
            ):
                transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                    user_api_key_dict=user_api_key_dict,
                    original_exception=e,
                    request_data=request_data,
                )
                if transformed_exception is not None:
                    e = transformed_exception
            if isinstance(e, HTTPException):
                raise
            # All other failures become a JSON-RPC error line on the stream.
            yield json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32603, "message": f"Streaming error: {str(e)}"},
                }
            ) + "\n"

    return StreamingResponse(stream_response(), media_type="application/x-ndjson")
|
||||
|
||||
|
||||
@router.get(
    "/a2a/{agent_id}/.well-known/agent-card.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.get(
    "/a2a/{agent_id}/.well-known/agent.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_agent_card(
    agent_id: str,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Serve the A2A discovery document ("agent card") for a registered agent.

    Both standard well-known paths are exposed:
    - /.well-known/agent-card.json
    - /.well-known/agent.json

    The card's URL field is rewritten to point at the LiteLLM proxy itself so
    every follow-up A2A call flows through LiteLLM for logging and cost
    tracking instead of hitting the backend agent directly.
    """
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )

    try:
        resolved_agent = _get_agent(agent_id)
        if resolved_agent is None:
            raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")

        # Permission gate (the handler skips the check for admin users).
        allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=resolved_agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )

        # Shallow-copy the stored card, then point its URL at this proxy.
        proxied_card = dict(resolved_agent.agent_card_params)
        proxy_base = str(request.base_url).rstrip("/")
        proxied_card["url"] = f"{proxy_base}/a2a/{agent_id}"

        verbose_proxy_logger.debug(
            f"Returning agent card for '{agent_id}' with proxy URL: {proxied_card['url']}"
        )
        return JSONResponse(content=proxied_card)

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting agent card: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
    "/a2a/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/v1/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def invoke_agent_a2a(  # noqa: PLR0915
    agent_id: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Invoke an agent using the A2A protocol (JSON-RPC 2.0).

    Supported methods:
    - message/send: Send a message and get a response
    - message/stream: Send a message and stream the response (NDJSON)

    Pipeline: validate JSON-RPC envelope -> lift litellm params out of
    `params` -> resolve agent and check key/team permission -> enforce
    trace-id policy -> run proxy pre-call logic -> merge backend headers ->
    dispatch via the A2A SDK. Protocol-level failures are returned as
    JSON-RPC error objects; permission failures raise HTTPException.
    """
    from litellm.a2a_protocol import asend_message
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )
    from litellm.proxy.proxy_server import (
        general_settings,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    # Initialized before the try so the final error handler can safely read
    # body.get("id") even when request.json() itself fails.
    body = {}
    try:
        body = await request.json()

        verbose_proxy_logger.debug(f"A2A request for agent '{agent_id}': {body}")

        # Validate JSON-RPC format
        if body.get("jsonrpc") != "2.0":
            return _jsonrpc_error(
                body.get("id"), -32600, "Invalid Request: jsonrpc must be '2.0'"
            )

        request_id = body.get("id")
        method = body.get("method")
        params = body.get("params", {})

        if params:
            # extract any litellm params from the params - eg. 'guardrails' -
            # they are moved to the top-level body so pre-call logic sees them,
            # and removed from params before the A2A SDK validates them.
            params_to_remove = []
            for key, value in params.items():
                if key in all_litellm_params:
                    params_to_remove.append(key)
                    body[key] = value
            for key in params_to_remove:
                params.pop(key)

        if not A2A_SDK_AVAILABLE:
            return _jsonrpc_error(
                request_id,
                -32603,
                "Server error: 'a2a' package not installed. Please install 'a2a-sdk'.",
                500,
            )

        # Find the agent (by ID, falling back to name)
        agent = _get_agent(agent_id)
        if agent is None:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' not found", 404
            )

        is_allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )

        _enforce_inbound_trace_id(agent, request)

        # Get backend URL and agent name
        agent_url = agent.agent_card_params.get("url")
        agent_name = agent.agent_card_params.get("name", agent_id)

        # Get litellm_params (may include custom_llm_provider for completion bridge)
        litellm_params = agent.litellm_params or {}
        custom_llm_provider = litellm_params.get("custom_llm_provider")

        # URL is required unless using completion bridge with a provider that derives endpoint from model
        # (e.g., bedrock/agentcore derives endpoint from ARN in model string)
        if not agent_url and not custom_llm_provider:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500
            )

        verbose_proxy_logger.info(
            f"Proxying A2A request to agent '{agent_id}' at {agent_url or 'completion-bridge'}"
        )

        # Set up data dict for litellm processing
        if "metadata" not in body:
            body["metadata"] = {}
        body["metadata"]["agent_id"] = agent.agent_id

        body.update(
            {
                "model": f"a2a_agent/{agent_name}",
                "custom_llm_provider": "a2a_agent",
            }
        )

        # Add litellm data (user_api_key, user_id, team_id, etc.)
        from litellm.proxy.common_request_processing import (
            ProxyBaseLLMRequestProcessing,
        )

        processor = ProxyBaseLLMRequestProcessing(data=body)
        data, logging_obj = await processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="asend_message",
            version=version,
        )

        # Build merged headers for the backend agent
        static_headers: Dict[str, str] = dict(agent.static_headers or {})

        # HTTP header names are case-insensitive; normalize for lookups.
        raw_headers = dict(request.headers)
        normalized = {k.lower(): v for k, v in raw_headers.items()}

        dynamic_headers: Dict[str, str] = {}

        # 1. Admin-configured extra_headers: forward named headers from client request
        if agent.extra_headers:
            for header_name in agent.extra_headers:
                val = normalized.get(header_name.lower())
                if val is not None:
                    dynamic_headers[header_name] = val

        # 2. Convention-based forwarding: x-a2a-{agent_id_or_name}-{header_name}
        # Matches both agent_id (UUID) and agent_name (alias), case-insensitive.
        for alias in (agent.agent_id.lower(), agent.agent_name.lower()):
            prefix = f"x-a2a-{alias}-"
            for key, val in normalized.items():
                if key.startswith(prefix):
                    header_name = key[len(prefix) :]
                    if header_name:
                        dynamic_headers[header_name] = val

        agent_extra_headers = merge_agent_headers(
            dynamic_headers=dynamic_headers or None,
            static_headers=static_headers or None,
        )

        # Route through SDK functions
        if method == "message/send":
            from a2a.types import MessageSendParams, SendMessageRequest

            a2a_request = SendMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            response = await asend_message(
                request=a2a_request,
                api_base=agent_url,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                litellm_logging_obj=logging_obj,
                agent_extra_headers=agent_extra_headers,
            )

            # Post-call hook may transform the response (e.g. guardrails).
            response = await proxy_logging_obj.post_call_success_hook(
                user_api_key_dict=user_api_key_dict,
                data=data,
                response=response,
            )
            return JSONResponse(
                content=(
                    response.model_dump(mode="json", exclude_none=True)  # type: ignore
                    if hasattr(response, "model_dump")
                    else response
                )
            )

        elif method == "message/stream":
            return await _handle_stream_message(
                api_base=agent_url,
                request_id=request_id,
                params=params,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                agent_extra_headers=agent_extra_headers,
                user_api_key_dict=user_api_key_dict,
                request_data=data,
                proxy_logging_obj=proxy_logging_obj,
            )
        else:
            return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error invoking agent: {e}")
        return _jsonrpc_error(body.get("id"), -32603, f"Internal error: {str(e)}", 500)
|
||||
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
A2A Agent Routing
|
||||
|
||||
Handles routing for A2A agents (models with "a2a/<agent-name>" prefix).
|
||||
Looks up agents in the registry and injects their API base URL.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
def route_a2a_agent_request(data: dict, route_type: str) -> Optional[Any]:
    """
    Route A2A agent requests directly to litellm with injected API base.

    Args:
        data: Request payload; ``data["model"]`` selects the agent via an
            ``"a2a/<agent-name>"`` prefix. Mutated in place: ``api_base`` is
            injected before dispatch.
        route_type: Name of the litellm entrypoint to invoke; also used to
            resolve a human-readable route name for error reporting.

    Returns:
        The result of ``getattr(litellm, route_type)(**data)``, or ``None``
        if this is not an A2A request (normal routing continues).

    Raises:
        ProxyModelNotFoundError: If the agent is unknown or has no URL configured.
    """
    # Import here to avoid circular imports
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
    from litellm.proxy.route_llm_request import (
        ROUTE_ENDPOINT_MAPPING,
        ProxyModelNotFoundError,
    )

    a2a_prefix = "a2a/"
    model_name = data.get("model", "")

    # Check if this is an A2A agent request
    if not isinstance(model_name, str) or not model_name.startswith(a2a_prefix):
        return None

    # Extract agent name (e.g., "a2a/my-agent" -> "my-agent").
    # Slice by prefix length instead of the magic constant 4.
    agent_name = model_name[len(a2a_prefix) :]

    # Look up agent in registry
    agent = global_agent_registry.get_agent_by_name(agent_name)
    if agent is None:
        verbose_proxy_logger.error(f"[A2A] Agent '{agent_name}' not found in registry")
        route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
        raise ProxyModelNotFoundError(route=route_name, model_name=model_name)

    # Get API base URL from agent config
    if not agent.agent_card_params or "url" not in agent.agent_card_params:
        verbose_proxy_logger.error(f"[A2A] Agent '{agent_name}' has no URL configured")
        route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
        raise ProxyModelNotFoundError(route=route_name, model_name=model_name)

    # Inject API base and route to litellm
    data["api_base"] = agent.agent_card_params["url"]
    verbose_proxy_logger.debug(f"[A2A] Routing {model_name} to {data['api_base']}")

    # route_type already names the litellm attribute; the f-string wrapper was redundant.
    return getattr(litellm, route_type)(**data)
|
||||
@@ -0,0 +1,458 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.management_helpers.object_permission_utils import (
|
||||
handle_update_object_permission_common,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
    def __init__(self):
        """Create an empty in-memory registry of agents."""
        # Registered agents, in registration order.
        self.agent_list: List[AgentResponse] = []
|
||||
|
||||
    def reset_agent_list(self):
        """Drop all registered agents (used before a full reload)."""
        # Rebinds rather than clearing in place, so callers holding the old
        # list reference keep a stale snapshot — presumably intentional.
        self.agent_list = []
|
||||
|
||||
    def register_agent(self, agent_config: AgentResponse):
        """Append an agent to the registry (no de-duplication is performed)."""
        self.agent_list.append(agent_config)
|
||||
|
||||
def deregister_agent(self, agent_name: str):
|
||||
self.agent_list = [
|
||||
agent for agent in self.agent_list if agent.agent_name != agent_name
|
||||
]
|
||||
|
||||
def get_agent_list(self, agent_names: Optional[List[str]] = None):
|
||||
if agent_names is not None:
|
||||
return [
|
||||
agent for agent in self.agent_list if agent.agent_name in agent_names
|
||||
]
|
||||
return self.agent_list
|
||||
|
||||
def get_public_agent_list(self) -> List[AgentResponse]:
|
||||
public_agent_list: List[AgentResponse] = []
|
||||
if litellm.public_agent_groups is None:
|
||||
return public_agent_list
|
||||
for agent in self.agent_list:
|
||||
if agent.agent_id in litellm.public_agent_groups:
|
||||
public_agent_list.append(agent)
|
||||
return public_agent_list
|
||||
|
||||
def _create_agent_id(self, agent_config: AgentConfig) -> str:
|
||||
return hashlib.sha256(
|
||||
json.dumps(agent_config, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
def load_agents_from_config(self, agent_config: Optional[List[AgentConfig]] = None):
|
||||
if agent_config is None:
|
||||
return None
|
||||
|
||||
for agent_config_item in agent_config:
|
||||
if not isinstance(agent_config_item, dict):
|
||||
raise ValueError("agent_config must be a list of dictionaries")
|
||||
|
||||
agent_name = agent_config_item.get("agent_name")
|
||||
agent_card_params = agent_config_item.get("agent_card_params")
|
||||
if not all([agent_name, agent_card_params]):
|
||||
continue
|
||||
|
||||
# create a stable hash id for config item
|
||||
config_hash = self._create_agent_id(agent_config_item)
|
||||
|
||||
self.register_agent(agent_config=AgentResponse(agent_id=config_hash, **agent_config_item)) # type: ignore
|
||||
|
||||
def load_agents_from_db_and_config(
|
||||
self,
|
||||
agent_config: Optional[List[AgentConfig]] = None,
|
||||
db_agents: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
self.reset_agent_list()
|
||||
|
||||
if agent_config:
|
||||
for agent_config_item in agent_config:
|
||||
if not isinstance(agent_config_item, dict):
|
||||
raise ValueError("agent_config must be a list of dictionaries")
|
||||
|
||||
self.register_agent(agent_config=AgentResponse(agent_id=self._create_agent_id(agent_config_item), **agent_config_item)) # type: ignore
|
||||
|
||||
if db_agents:
|
||||
for db_agent in db_agents:
|
||||
if not isinstance(db_agent, dict):
|
||||
raise ValueError("db_agents must be a list of dictionaries")
|
||||
|
||||
self.register_agent(agent_config=AgentResponse(**db_agent)) # type: ignore
|
||||
return self.agent_list
|
||||
|
||||
###########################################################
|
||||
########### DB management helpers for agents ###########
|
||||
############################################################
|
||||
    async def add_agent_to_db(
        self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str
    ) -> AgentResponse:
        """
        Add an agent to the database.

        Serializes the nested param objects to JSON strings (the DB columns
        store text), resolves any object_permission into a permission row,
        and inserts the record.

        Args:
            agent: Agent config dict (agent_name, litellm_params,
                agent_card_params, optional headers / limits / permissions).
            prisma_client: DB client used for the insert.
            created_by: User id recorded as both creator and last updater.

        Returns:
            The created agent as an AgentResponse.

        Raises:
            Exception: Wrapping any underlying serialization/DB error.
        """
        try:
            agent_name = agent.get("agent_name")

            # Serialize litellm_params (pydantic model or plain mapping) to JSON text
            litellm_params_obj: Any = agent.get("litellm_params", {})
            if hasattr(litellm_params_obj, "model_dump"):
                litellm_params_dict = litellm_params_obj.model_dump()
            else:
                litellm_params_dict = (
                    dict(litellm_params_obj) if litellm_params_obj else {}
                )
            litellm_params: str = safe_dumps(litellm_params_dict)

            # Serialize agent_card_params the same way
            agent_card_params_obj: Any = agent.get("agent_card_params", {})
            if hasattr(agent_card_params_obj, "model_dump"):
                agent_card_params_dict = agent_card_params_obj.model_dump()
            else:
                agent_card_params_dict = (
                    dict(agent_card_params_obj) if agent_card_params_obj else {}
                )
            agent_card_params: str = safe_dumps(agent_card_params_dict)

            # Handle object_permission (MCP tool access for agent).
            # Works on a copy so the caller's dict is not mutated by the helper.
            object_permission_id: Optional[str] = None
            if agent.get("object_permission") is not None:
                agent_copy = dict(agent)
                object_permission_id = await handle_update_object_permission_common(
                    agent_copy, None, prisma_client
                )

            # Serialize static_headers (None/empty -> stored as NULL)
            static_headers_obj = agent.get("static_headers")
            static_headers_val: Optional[str] = (
                safe_dumps(dict(static_headers_obj)) if static_headers_obj else None
            )

            extra_headers_val: Optional[List[str]] = agent.get("extra_headers")

            create_data: Dict[str, Any] = {
                "agent_name": agent_name,
                "litellm_params": litellm_params,
                "agent_card_params": agent_card_params,
                "created_by": created_by,
                "updated_by": created_by,
                "created_at": datetime.now(timezone.utc),
                "updated_at": datetime.now(timezone.utc),
            }
            # Optional columns are only set when present, so DB defaults apply otherwise.
            if static_headers_val is not None:
                create_data["static_headers"] = static_headers_val
            if extra_headers_val is not None:
                create_data["extra_headers"] = extra_headers_val
            if object_permission_id is not None:
                create_data["object_permission_id"] = object_permission_id

            # Copy any rate-limit fields supplied by the caller
            for rate_field in (
                "tpm_limit",
                "rpm_limit",
                "session_tpm_limit",
                "session_rpm_limit",
            ):
                _val = agent.get(rate_field)
                if _val is not None:
                    create_data[rate_field] = _val

            # Create agent in DB (include the permission relation in the result)
            created_agent = await prisma_client.db.litellm_agentstable.create(
                data=create_data,
                include={"object_permission": True},
            )

            created_agent_dict = created_agent.model_dump()
            if created_agent.object_permission is not None:
                # Prefer pydantic v2 model_dump(); fall back to v1 .dict()
                try:
                    created_agent_dict[
                        "object_permission"
                    ] = created_agent.object_permission.model_dump()
                except Exception:
                    created_agent_dict[
                        "object_permission"
                    ] = created_agent.object_permission.dict()
            return AgentResponse(**created_agent_dict)  # type: ignore
        except Exception as e:
            raise Exception(f"Error adding agent to DB: {str(e)}")
|
||||
|
||||
async def delete_agent_from_db(
|
||||
self, agent_id: str, prisma_client: PrismaClient
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Delete an agent from the database
|
||||
"""
|
||||
try:
|
||||
deleted_agent = await prisma_client.db.litellm_agentstable.delete(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
return dict(deleted_agent)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error deleting agent from DB: {str(e)}")
|
||||
|
||||
async def patch_agent_in_db(
    self,
    agent_id: str,
    agent: PatchAgentRequest,
    prisma_client: PrismaClient,
    updated_by: str,
) -> AgentResponse:
    """
    Patch an agent in the database.

    Loads the existing row, overlays the caller-supplied fields on top of
    it, and writes only the resulting columns back.

    Args:
        agent_id: The ID of the agent to patch
        agent: The new agent values to patch (mapping-style access is used,
            so `agent` is expected to behave like a dict — TODO confirm
            PatchAgentRequest supports `in` / `.get`)
        prisma_client: The Prisma client to use
        updated_by: The user ID of the user who is patching the agent

    Returns:
        The patched agent

    Raises:
        Exception: wrapped error when the agent is missing or a DB call fails.
    """
    try:
        existing_row = await prisma_client.db.litellm_agentstable.find_unique(
            where={"agent_id": agent_id}
        )
        # Single not-found check (the original checked None twice in a row).
        if existing_row is None:
            raise Exception(f"Agent with ID {agent_id} not found")
        existing_agent: Dict[str, Any] = dict(existing_row)

        # Overlay patch values on top of the stored row.
        augment_agent = {**existing_agent, **agent}

        update_data: Dict[str, Any] = {}
        if augment_agent.get("agent_name"):
            update_data["agent_name"] = augment_agent.get("agent_name")
        if augment_agent.get("litellm_params"):
            update_data["litellm_params"] = safe_dumps(
                augment_agent.get("litellm_params")
            )
        if augment_agent.get("agent_card_params"):
            update_data["agent_card_params"] = safe_dumps(
                augment_agent.get("agent_card_params")
            )

        # Rate-limit columns are only written when explicitly present in the
        # patch payload (so an absent field does not clobber the stored value).
        for rate_field in (
            "tpm_limit",
            "rpm_limit",
            "session_tpm_limit",
            "session_rpm_limit",
        ):
            if rate_field in agent:
                update_data[rate_field] = agent.get(rate_field)

        if "static_headers" in agent:
            headers_value = agent.get("static_headers")
            update_data["static_headers"] = safe_dumps(
                dict(headers_value) if headers_value is not None else {}
            )
        if "extra_headers" in agent:
            extra_headers_value = agent.get("extra_headers")
            update_data["extra_headers"] = (
                extra_headers_value if extra_headers_value is not None else []
            )

        if agent.get("object_permission") is not None:
            agent_copy = dict(augment_agent)
            existing_object_permission_id = existing_agent.get(
                "object_permission_id"
            )
            object_permission_id = await handle_update_object_permission_common(
                agent_copy,
                existing_object_permission_id,
                prisma_client,
            )
            if object_permission_id is not None:
                update_data["object_permission_id"] = object_permission_id

        # Patch agent in DB, always stamping the audit columns.
        patched_agent = await prisma_client.db.litellm_agentstable.update(
            where={"agent_id": agent_id},
            data={
                **update_data,
                "updated_by": updated_by,
                "updated_at": datetime.now(timezone.utc),
            },
            include={"object_permission": True},
        )
        patched_agent_dict = patched_agent.model_dump()
        if patched_agent.object_permission is not None:
            # Prefer pydantic v2 model_dump(); fall back to v1 .dict().
            try:
                patched_agent_dict[
                    "object_permission"
                ] = patched_agent.object_permission.model_dump()
            except Exception:
                patched_agent_dict[
                    "object_permission"
                ] = patched_agent.object_permission.dict()
        return AgentResponse(**patched_agent_dict)  # type: ignore
    except Exception as e:
        raise Exception(f"Error patching agent in DB: {str(e)}")
|
||||
|
||||
async def update_agent_in_db(
    self,
    agent_id: str,
    agent: AgentConfig,
    prisma_client: PrismaClient,
    updated_by: str,
) -> AgentResponse:
    """
    Update (full overwrite) an agent in the database.

    Args:
        agent_id: ID of the agent row to update
        agent: Replacement agent config (mapping-style access is used)
        prisma_client: DB client
        updated_by: User ID recorded in the audit columns

    Returns:
        The updated agent

    Raises:
        Exception: wrapped error on any DB/serialization failure.
    """

    def _params_to_json(params_obj: Any) -> str:
        # Shared serializer for litellm_params / agent_card_params — the
        # original duplicated this hasattr/model_dump branching inline twice.
        if hasattr(params_obj, "model_dump"):
            return safe_dumps(params_obj.model_dump())
        return safe_dumps(dict(params_obj) if params_obj else {})

    try:
        agent_name = agent.get("agent_name")

        litellm_params: str = _params_to_json(agent.get("litellm_params", {}))
        agent_card_params: str = _params_to_json(
            agent.get("agent_card_params", {})
        )

        # Serialize static_headers for update (absent -> empty JSON object).
        static_headers_obj = agent.get("static_headers")
        static_headers_val: str = safe_dumps(
            dict(static_headers_obj) if static_headers_obj is not None else {}
        )
        extra_headers_val: List[str] = agent.get("extra_headers") or []

        update_data: Dict[str, Any] = {
            "agent_name": agent_name,
            "litellm_params": litellm_params,
            "agent_card_params": agent_card_params,
            "static_headers": static_headers_val,
            "extra_headers": extra_headers_val,
            "updated_by": updated_by,
            "updated_at": datetime.now(timezone.utc),
        }

        # Only overwrite rate limits that are explicitly provided.
        for rate_field in (
            "tpm_limit",
            "rpm_limit",
            "session_tpm_limit",
            "session_rpm_limit",
        ):
            _val = agent.get(rate_field)
            if _val is not None:
                update_data[rate_field] = _val

        if agent.get("object_permission") is not None:
            existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id}
            )
            existing_object_permission_id = (
                existing_agent.object_permission_id
                if existing_agent is not None
                else None
            )
            object_permission_id = await handle_update_object_permission_common(
                dict(agent),
                existing_object_permission_id,
                prisma_client,
            )
            if object_permission_id is not None:
                update_data["object_permission_id"] = object_permission_id

        # Update agent in DB
        updated_agent = await prisma_client.db.litellm_agentstable.update(
            where={"agent_id": agent_id},
            data=update_data,
            include={"object_permission": True},
        )

        updated_agent_dict = updated_agent.model_dump()
        if updated_agent.object_permission is not None:
            # Prefer pydantic v2 model_dump(); fall back to v1 .dict().
            try:
                updated_agent_dict[
                    "object_permission"
                ] = updated_agent.object_permission.model_dump()
            except Exception:
                updated_agent_dict[
                    "object_permission"
                ] = updated_agent.object_permission.dict()
        return AgentResponse(**updated_agent_dict)  # type: ignore
    except Exception as e:
        raise Exception(f"Error updating agent in DB: {str(e)}")
|
||||
|
||||
@staticmethod
async def get_all_agents_from_db(
    prisma_client: PrismaClient,
) -> List[Dict[str, Any]]:
    """
    Fetch every agent row (newest first) with its object_permission
    eagerly loaded, returning plain dicts.
    """
    try:
        rows = await prisma_client.db.litellm_agentstable.find_many(
            order={"created_at": "desc"},
            include={"object_permission": True},
        )

        serialized: List[Dict[str, Any]] = []
        for row in rows:
            row_dict = dict(row)
            # object_permission is eagerly loaded via include above.
            if row.object_permission is not None:
                # Prefer pydantic v2 model_dump(); fall back to v1 .dict().
                try:
                    row_dict[
                        "object_permission"
                    ] = row.object_permission.model_dump()
                except Exception:
                    row_dict["object_permission"] = row.object_permission.dict()
            serialized.append(row_dict)

        return serialized
    except Exception as e:
        raise Exception(f"Error getting agents from DB: {str(e)}")
|
||||
|
||||
def get_agent_by_id(
    self,
    agent_id: str,
) -> Optional[AgentResponse]:
    """
    Return the registered agent with the given ID, or None.

    NOTE: despite the original docstring, this does NOT query the
    database — it only scans the in-memory ``self.agent_list``.
    """
    try:
        return next(
            (agent for agent in self.agent_list if agent.agent_id == agent_id),
            None,
        )
    except Exception as e:
        # Error message kept as-is for backward compatibility with callers
        # that match on it, even though the lookup is in-memory.
        raise Exception(f"Error getting agent from DB: {str(e)}")
|
||||
|
||||
def get_agent_by_name(self, agent_name: str) -> Optional[AgentResponse]:
    """
    Return the registered agent whose name matches ``agent_name``, or None
    when no in-memory agent carries that name.
    """
    try:
        matches = (a for a in self.agent_list if a.agent_name == agent_name)
        return next(matches, None)
    except Exception as e:
        raise Exception(f"Error getting agent from DB: {str(e)}")
|
||||
|
||||
|
||||
# Module-level singleton registry shared across the proxy.
global_agent_registry = AgentRegistry()
|
||||
@@ -0,0 +1,451 @@
|
||||
"""
|
||||
Agent Permission Handler for LiteLLM Proxy.
|
||||
|
||||
Handles agent permission checking for keys and teams using object_permission_id.
|
||||
Follows the same pattern as MCP permission handling.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_TeamTable,
|
||||
UI_TEAM_ID,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
|
||||
|
||||
class AgentRequestHandler:
|
||||
"""
|
||||
Class to handle agent permission checking, including:
|
||||
1. Key-level agent permissions
|
||||
2. Team-level agent permissions
|
||||
3. Agent access group resolution
|
||||
|
||||
Follows the same inheritance logic as MCP:
|
||||
- If team has restrictions and key has restrictions: use intersection
|
||||
- If team has restrictions and key has none: inherit from team
|
||||
- If team has no restrictions: use key restrictions
|
||||
- If no restrictions: allow all agents
|
||||
"""
|
||||
|
||||
@staticmethod
async def get_allowed_agents(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
    """
    Get list of allowed agent IDs for the given user/key based on permissions.

    Inheritance rules (same as MCP):
    - Team restricted + key restricted: intersection of both.
    - Team restricted + key unrestricted: inherit team's list.
    - Team unrestricted: use the key's list.

    Returns:
        List[str]: Allowed agent IDs. Empty list means no restrictions
        (allow all). Order is unspecified (was ``list(set(...))`` before).
    """
    try:
        key_agents = await AgentRequestHandler._get_allowed_agents_for_key(
            user_api_key_auth
        )
        team_agents = await AgentRequestHandler._get_allowed_agents_for_team(
            user_api_key_auth
        )

        if team_agents:
            if key_agents:
                # Set intersection instead of the original O(n*m)
                # per-element membership loop over two lists.
                allowed = set(key_agents) & set(team_agents)
            else:
                # Key has no agent permissions - inherit from team.
                allowed = set(team_agents)
        else:
            allowed = set(key_agents)

        return list(allowed)
    except Exception as e:
        verbose_logger.warning(f"Failed to get allowed agents: {str(e)}")
        return []
|
||||
|
||||
@staticmethod
async def is_agent_allowed(
    agent_id: str,
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> bool:
    """
    Return True when ``agent_id`` is permitted for the caller.

    Args:
        agent_id: The agent ID to check
        user_api_key_auth: User authentication info

    An empty allow-list means "no restrictions", so every agent passes.
    """
    permitted = await AgentRequestHandler.get_allowed_agents(user_api_key_auth)
    return (not permitted) or (agent_id in permitted)
|
||||
|
||||
@staticmethod
def _get_key_object_permission(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> Optional[LiteLLM_ObjectPermissionTable]:
    """
    Return the key's object_permission, or None when no auth object was given.

    Note: object_permission is populated when the key is fetched via
    get_key_object() in litellm/proxy/auth/auth_checks.py — no DB call here.
    """
    return user_api_key_auth.object_permission if user_api_key_auth else None
|
||||
|
||||
@staticmethod
async def _get_team_object_permission(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> Optional[LiteLLM_ObjectPermissionTable]:
    """
    Return the object_permission of the caller's team, or None.

    Note: object_permission is populated when the team is fetched via
    get_team_object() in litellm/proxy/auth/auth_checks.py
    """
    from litellm.proxy.auth.auth_checks import get_team_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # No caller, no team membership, or no DB client -> nothing to look up.
    if not user_api_key_auth or not user_api_key_auth.team_id or not prisma_client:
        return None

    team_row: Optional[LiteLLM_TeamTable] = await get_team_object(
        team_id=user_api_key_auth.team_id,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=user_api_key_auth.parent_otel_span,
        proxy_logging_obj=proxy_logging_obj,
    )
    return team_row.object_permission if team_row else None
|
||||
|
||||
@staticmethod
async def _get_allowed_agents_for_key(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
    """
    Collect agent IDs a key may use (deduplicated):

    1. Native key-level agent permissions (object_permission, already
       loaded by get_key_object() in the main auth flow).
    2. Agents resolved from the key's access_group_ids (unified groups).
    """
    if user_api_key_auth is None:
        return []

    try:
        collected: List[str] = []

        # 1. Native permissions attached to the key.
        perm = AgentRequestHandler._get_key_object_permission(
            user_api_key_auth
        )
        if perm is not None:
            via_groups = await AgentRequestHandler._get_agents_from_access_groups(
                perm.agent_access_groups or []
            )
            collected = (perm.agents or []) + via_groups

        # 2. Unified access groups referenced directly on the key.
        unified_group_ids = user_api_key_auth.access_group_ids or []
        if unified_group_ids:
            from litellm.proxy.auth.auth_checks import (
                _get_agent_ids_from_access_groups,
            )

            collected.extend(
                await _get_agent_ids_from_access_groups(
                    access_group_ids=unified_group_ids,
                )
            )

        return list(set(collected))
    except Exception as e:
        verbose_logger.warning(f"Failed to get allowed agents for key: {str(e)}")
        return []
|
||||
|
||||
@staticmethod
async def _get_allowed_agents_for_team(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
    """
    Collect agent IDs a team may use (deduplicated):

    1. Native team-level agent permissions (object_permission).
    2. Agents resolved from the team's access_group_ids (unified groups).

    Fetches the team object once and reuses it for both permission sources.
    """
    if user_api_key_auth is None or user_api_key_auth.team_id is None:
        return []

    try:
        from litellm.proxy.auth.auth_checks import get_team_object
        from litellm.proxy.proxy_server import (
            prisma_client,
            proxy_logging_obj,
            user_api_key_cache,
        )

        if not prisma_client:
            return []

        # Fetch the team object once for both permission sources.
        team_row = await get_team_object(
            team_id=user_api_key_auth.team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
        if team_row is None:
            return []

        collected: List[str] = []

        # 1. Native permissions from object_permission.
        perms = team_row.object_permission
        if perms is not None:
            via_groups = await AgentRequestHandler._get_agents_from_access_groups(
                perms.agent_access_groups or []
            )
            collected = (perms.agents or []) + via_groups

        # 2. Unified access groups referenced directly on the team.
        unified_group_ids = team_row.access_group_ids or []
        if unified_group_ids:
            from litellm.proxy.auth.auth_checks import (
                _get_agent_ids_from_access_groups,
            )

            collected.extend(
                await _get_agent_ids_from_access_groups(
                    access_group_ids=unified_group_ids,
                )
            )

        return list(set(collected))
    except Exception as e:
        # litellm-dashboard is the default UI team and will never have agents;
        # skip noisy warnings for it.
        if user_api_key_auth.team_id != UI_TEAM_ID:
            verbose_logger.warning(
                f"Failed to get allowed agents for team: {str(e)}"
            )
        return []
|
||||
|
||||
@staticmethod
|
||||
def _get_config_agent_ids_for_access_groups(
|
||||
config_agents: List, access_groups: List[str]
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Helper to get agent_ids from config-loaded agents that match any of the given access groups.
|
||||
"""
|
||||
server_ids: Set[str] = set()
|
||||
for agent in config_agents:
|
||||
agent_access_groups = getattr(agent, "agent_access_groups", None)
|
||||
if agent_access_groups:
|
||||
if any(group in agent_access_groups for group in access_groups):
|
||||
server_ids.add(agent.agent_id)
|
||||
return server_ids
|
||||
|
||||
@staticmethod
async def _get_db_agent_ids_for_access_groups(
    prisma_client, access_groups: List[str]
) -> Set[str]:
    """
    Return agent_ids of DB-stored agents whose access groups overlap any of
    the given ``access_groups``. DB errors are logged and yield an empty set.
    """
    matched: Set[str] = set()
    if not access_groups or prisma_client is None:
        return matched
    try:
        rows = await prisma_client.db.litellm_agentstable.find_many(
            where={"agent_access_groups": {"hasSome": access_groups}}
        )
        matched.update(row.agent_id for row in rows)
    except Exception as e:
        verbose_logger.debug(f"Error getting agents from access groups: {e}")
    return matched
|
||||
|
||||
@staticmethod
async def _get_agents_from_access_groups(
    access_groups: List[str],
) -> List[str]:
    """
    Resolve agent access groups to agent IDs by combining config-loaded
    agents with agents stored in the DB.
    """
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
    from litellm.proxy.proxy_server import prisma_client

    try:
        # Config-loaded agents first...
        resolved = AgentRequestHandler._get_config_agent_ids_for_access_groups(
            global_agent_registry.agent_list, access_groups
        )
        # ...then merge in DB-backed agents.
        resolved |= await AgentRequestHandler._get_db_agent_ids_for_access_groups(
            prisma_client, access_groups
        )
        return list(resolved)
    except Exception as e:
        verbose_logger.warning(f"Failed to get agents from access groups: {str(e)}")
        return []
|
||||
|
||||
@staticmethod
async def get_agent_access_groups(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
    """
    Get the agent access groups for the given user/key based on permissions.

    If the team defines access groups, the key's groups are intersected
    with them (key must hold a subset); otherwise the key's groups win.
    Returned order is unspecified (was ``list(set(...))`` before).
    """
    key_groups = await AgentRequestHandler._get_agent_access_groups_for_key(
        user_api_key_auth
    )
    team_groups = await AgentRequestHandler._get_agent_access_groups_for_team(
        user_api_key_auth
    )

    if team_groups:
        # Set intersection instead of the original O(n*m) membership loop.
        result = set(key_groups) & set(team_groups)
    else:
        result = set(key_groups)

    return list(result)
|
||||
|
||||
@staticmethod
async def _get_agent_access_groups_for_key(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
    """Return the agent access groups attached to the key's object permission."""
    from litellm.proxy.auth.auth_checks import get_object_permission
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Guard clauses: no auth object, no permission id, or no DB client.
    if user_api_key_auth is None or user_api_key_auth.object_permission_id is None:
        return []
    if prisma_client is None:
        verbose_logger.debug("prisma_client is None")
        return []

    try:
        perm = await get_object_permission(
            object_permission_id=user_api_key_auth.object_permission_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
        if perm is None:
            return []
        return perm.agent_access_groups or []
    except Exception as e:
        verbose_logger.warning(
            f"Failed to get agent access groups for key: {str(e)}"
        )
        return []
|
||||
|
||||
@staticmethod
async def _get_agent_access_groups_for_team(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
) -> List[str]:
    """Return the agent access groups attached to the team's object permission."""
    from litellm.proxy.auth.auth_checks import get_team_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Guard clauses: no auth object, no team, or no DB client.
    if user_api_key_auth is None or user_api_key_auth.team_id is None:
        return []
    if prisma_client is None:
        verbose_logger.debug("prisma_client is None")
        return []

    try:
        team_row: Optional[LiteLLM_TeamTable] = await get_team_object(
            team_id=user_api_key_auth.team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
        if team_row is None:
            verbose_logger.debug("team_obj is None")
            return []

        perms = team_row.object_permission
        if perms is None:
            return []
        return perms.agent_access_groups or []
    except Exception as e:
        verbose_logger.warning(
            f"Failed to get agent access groups for team: {str(e)}"
        )
        return []
|
||||
@@ -0,0 +1,944 @@
|
||||
"""
|
||||
Agent endpoints for registering + discovering agents via LiteLLM.
|
||||
|
||||
Follows the A2A Spec.
|
||||
|
||||
1. Register an agent via POST `/v1/agents`
|
||||
2. Discover agents via GET `/v1/agents`
|
||||
3. Get specific agent via GET `/v1/agents/{agent_id}`
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.rbac_utils import check_feature_access_for_user
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
|
||||
from litellm.types.agents import (
|
||||
AgentConfig,
|
||||
AgentMakePublicResponse,
|
||||
AgentResponse,
|
||||
MakeAgentsPublicRequest,
|
||||
PatchAgentRequest,
|
||||
)
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _check_agent_management_permission(user_api_key_dict: UserAPIKeyAuth) -> None:
    """
    Raise HTTP 403 unless the caller may create, update, or delete agents.
    Only PROXY_ADMIN users are allowed to perform these write operations.
    """
    # Guard clause: admins pass straight through.
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return
    raise HTTPException(
        status_code=403,
        detail={
            "error": "Only proxy admins can create, update, or delete agents. Your role={}".format(
                user_api_key_dict.user_role
            )
        },
    )
|
||||
|
||||
|
||||
# Per-request timeout (seconds) for a single agent-URL health probe.
AGENT_HEALTH_CHECK_TIMEOUT_SECONDS = float(
    os.environ.get("LITELLM_AGENT_HEALTH_CHECK_TIMEOUT", "5.0")
)
# Overall timeout (seconds) for gathering all health probes together.
AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS = float(
    os.environ.get("LITELLM_AGENT_HEALTH_CHECK_GATHER_TIMEOUT", "30.0")
)
|
||||
|
||||
|
||||
async def _check_agent_url_health(
    agent: AgentResponse,
) -> Dict[str, Any]:
    """
    GET the agent's URL and report reachability.

    Returns a dict with ``agent_id``, ``healthy`` (bool), and an optional
    ``error`` message. Agents without a URL are considered healthy.
    """
    card = agent.agent_card_params or {}
    url = card.get("url")
    if not url:
        return {"agent_id": agent.agent_id, "healthy": True}

    try:
        http_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.AgentHealthCheck,
            params={"timeout": AGENT_HEALTH_CHECK_TIMEOUT_SECONDS},
        )
        probe = await http_client.get(url)
        # Only 5xx responses count as unhealthy; 4xx still proves the
        # endpoint is reachable.
        if probe.status_code >= 500:
            return {
                "agent_id": agent.agent_id,
                "healthy": False,
                "error": f"HTTP {probe.status_code}",
            }
        return {"agent_id": agent.agent_id, "healthy": True}
    except Exception as exc:
        return {
            "agent_id": agent.agent_id,
            "healthy": False,
            "error": str(exc),
        }
|
||||
|
||||
|
||||
@router.get(
    "/v1/agents",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[AgentResponse],
)
async def get_agents(
    request: Request,
    health_check: bool = Query(
        False,
        description="When true, performs a GET request to each agent's URL. Agents with reachable URLs (HTTP status < 500) and agents without a URL are returned; unreachable agents are filtered out.",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),  # Used for auth
):
    """
    Example usage:
    ```
    curl -X GET "http://localhost:4000/v1/agents" \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer your-key" \
    ```

    Pass `?health_check=true` to filter out agents whose URL is unreachable:
    ```
    curl -X GET "http://localhost:4000/v1/agents?health_check=true" \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer your-key" \
    ```

    Returns: List[AgentResponse]

    """
    await check_feature_access_for_user(user_api_key_dict, "agents")

    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
        AgentRequestHandler,
    )

    try:
        # Admins see every agent; everyone else is filtered by permissions.
        is_admin = user_api_key_dict.user_role in (
            LitellmUserRoles.PROXY_ADMIN,
            LitellmUserRoles.PROXY_ADMIN.value,
        )
        if is_admin:
            visible_agents: List[AgentResponse] = (
                global_agent_registry.get_agent_list()
            )
        else:
            # Allowed agents come from object_permission (key/team level).
            allowed_ids = await AgentRequestHandler.get_allowed_agents(
                user_api_key_auth=user_api_key_dict
            )
            if not allowed_ids:
                # Empty list means "no restrictions" -> return everything.
                visible_agents = global_agent_registry.get_agent_list()
            else:
                visible_agents = [
                    a
                    for a in global_agent_registry.get_agent_list()
                    if a.agent_id in allowed_ids
                ]

        # Hydrate current spend values from the DB for the returned agents.
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is not None:
            visible_ids = [a.agent_id for a in visible_agents]
            if visible_ids:
                db_rows = await prisma_client.db.litellm_agentstable.find_many(
                    where={"agent_id": {"in": visible_ids}},
                )
                spend_by_id = {row.agent_id: row.spend for row in db_rows}
                for a in visible_agents:
                    if a.agent_id in spend_by_id:
                        a.spend = spend_by_id[a.agent_id]

        # add is_public field to each agent - we do it this way, to allow
        # setting config agents as public
        for a in visible_agents:
            if a.litellm_params is None:
                a.litellm_params = {}
            a.litellm_params["is_public"] = (
                litellm.public_agent_groups is not None
                and a.agent_id in litellm.public_agent_groups
            )

        if health_check:
            with_url = [
                a for a in visible_agents if (a.agent_card_params or {}).get("url")
            ]
            without_url = [
                a
                for a in visible_agents
                if not (a.agent_card_params or {}).get("url")
            ]
            try:
                probe_results = await asyncio.wait_for(
                    asyncio.gather(
                        *[_check_agent_url_health(a) for a in with_url]
                    ),
                    timeout=AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                verbose_proxy_logger.warning(
                    "Agent health check gather timed out after %s seconds",
                    AGENT_HEALTH_CHECK_GATHER_TIMEOUT_SECONDS,
                )
                probe_results = [
                    {
                        "agent_id": a.agent_id,
                        "healthy": False,
                        "error": "Health check timed out",
                    }
                    for a in with_url
                ]
            healthy_ids = {
                result["agent_id"] for result in probe_results if result["healthy"]
            }
            # Unreachable agents are dropped; URL-less agents always pass.
            visible_agents = [
                a for a in with_url if a.agent_id in healthy_ids
            ] + without_url

        return visible_agents
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.agent_endpoints.get_agents(): Exception occurred - {}".format(
                str(e)
            )
        )
        raise HTTPException(
            status_code=500, detail={"error": f"Internal server error: {str(e)}"}
        )
|
||||
|
||||
|
||||
#### CRUD ENDPOINTS FOR AGENTS ####
|
||||
|
||||
from litellm.proxy.agent_endpoints.agent_registry import (
|
||||
global_agent_registry as AGENT_REGISTRY,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/v1/agents",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentResponse,
)
async def create_agent(
    request: AgentConfig,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new agent.

    Persists the agent to the DB, then best-effort registers it in the
    in-memory registry. Agent names must be unique; a duplicate name is a 400.

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/v1/agents" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "agent": {
                "agent_name": "my-custom-agent",
                "agent_card_params": {
                    "protocolVersion": "1.0",
                    "name": "Hello World Agent",
                    "description": "Just a hello world agent",
                    "url": "http://localhost:9999/",
                    "version": "1.0.0",
                    "defaultInputModes": ["text"],
                    "defaultOutputModes": ["text"],
                    "capabilities": {
                        "streaming": true
                    },
                    "skills": [
                        {
                            "id": "hello_world",
                            "name": "Returns hello world",
                            "description": "just returns hello world",
                            "tags": ["hello world"],
                            "examples": ["hi", "hello world"]
                        }
                    ]
                },
                "litellm_params": {
                    "make_public": true
                }
            }
        }'
    ```
    """
    # Gate behind the "agents" feature flag for this key/user.
    await check_feature_access_for_user(user_api_key_dict, "agents")

    from litellm.proxy.proxy_server import prisma_client

    # Only callers with agent-management permission may create agents.
    _check_agent_management_permission(user_api_key_dict)

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        # Get the user ID from the API key auth
        created_by = user_api_key_dict.user_id or "unknown"

        # check for naming conflicts
        existing_agent = AGENT_REGISTRY.get_agent_by_name(
            agent_name=request.get("agent_name")  # type: ignore
        )
        if existing_agent is not None:
            raise HTTPException(
                status_code=400,
                detail=f"Agent with name {request.get('agent_name')} already exists",
            )

        # DB write is the source of truth; `result` carries the generated agent_id.
        result = await AGENT_REGISTRY.add_agent_to_db(
            agent=request, prisma_client=prisma_client, created_by=created_by
        )

        agent_name = result.agent_name
        agent_id = result.agent_id

        # Also register in memory. Best-effort: a registry failure is logged
        # but does not fail the request, since the DB row already exists.
        try:
            AGENT_REGISTRY.register_agent(agent_config=result)
            verbose_proxy_logger.info(
                f"Successfully registered agent '{agent_name}' (ID: {agent_id}) in memory"
            )
        except Exception as reg_error:
            verbose_proxy_logger.warning(
                f"Failed to register agent '{agent_name}' (ID: {agent_id}) in memory: {reg_error}"
            )

        return result

    except HTTPException:
        # Re-raise intentional HTTP errors unchanged (e.g. the 400 above).
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error adding agent to db: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
    "/v1/agents/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentResponse,
)
async def get_agent_by_id(
    agent_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get a specific agent by ID.

    Checks the in-memory registry first and falls back to the DB for agents
    not registered in memory. For in-memory hits, the `spend` field is
    refreshed from the DB row so it is not stale.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/v1/agents/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>"
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")

    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
        if agent is None:
            # Not in memory — look the agent up in the DB, including its
            # object permissions relation.
            agent_row = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id},
                include={"object_permission": True},
            )
            if agent_row is not None:
                agent_dict = agent_row.model_dump()
                if agent_row.object_permission is not None:
                    # Prefer model_dump(); fall back to the older .dict()
                    # interface if model_dump() is unavailable/raises.
                    try:
                        agent_dict[
                            "object_permission"
                        ] = agent_row.object_permission.model_dump()
                    except Exception:
                        agent_dict[
                            "object_permission"
                        ] = agent_row.object_permission.dict()
                agent = AgentResponse(**agent_dict)  # type: ignore
        else:
            # Agent found in memory — refresh spend from DB
            db_row = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id}
            )
            if db_row is not None:
                agent.spend = db_row.spend

        if agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )

        return agent
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting agent from db: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put(
    "/v1/agents/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentResponse,
)
async def update_agent(
    agent_id: str,
    request: AgentConfig,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing agent (full replacement of its config).

    Updates the DB row, then re-syncs the in-memory registry by
    deregistering the old entry (keyed by agent name) and registering the
    updated config.

    Example Request:
    ```bash
    curl -X PUT "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "agent": {
                "agent_name": "updated-agent",
                "agent_card_params": {
                    "protocolVersion": "1.0",
                    "name": "Updated Agent",
                    "description": "Updated description",
                    "url": "http://localhost:9999/",
                    "version": "1.1.0",
                    "defaultInputModes": ["text"],
                    "defaultOutputModes": ["text"],
                    "capabilities": {
                        "streaming": true
                    },
                    "skills": []
                },
                "litellm_params": {
                    "make_public": false
                }
            }
        }'
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")

    from litellm.proxy.proxy_server import prisma_client

    _check_agent_management_permission(user_api_key_dict)

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        # Check if agent exists
        existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
            where={"agent_id": agent_id}
        )
        if existing_agent is not None:
            # Snapshot as a plain dict; the old agent_name is needed below to
            # deregister the in-memory entry.
            existing_agent = dict(existing_agent)

        if existing_agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )

        # Get the user ID from the API key auth
        updated_by = user_api_key_dict.user_id or "unknown"

        result = await AGENT_REGISTRY.update_agent_in_db(
            agent_id=agent_id,
            agent=request,
            prisma_client=prisma_client,
            updated_by=updated_by,
        )

        # deregister in memory (by the *old* name — the update may rename it)
        AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name"))  # type: ignore
        # register in memory
        AGENT_REGISTRY.register_agent(agent_config=result)

        verbose_proxy_logger.info(
            f"Successfully updated agent '{existing_agent.get('agent_name')}' (ID: {agent_id}) in memory"
        )

        return result
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating agent: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch(
    "/v1/agents/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentResponse,
)
async def patch_agent(
    agent_id: str,
    request: PatchAgentRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Partially update an existing agent (PATCH semantics — only the fields
    present in the request body are changed).

    Updates the DB row, then re-syncs the in-memory registry by
    deregistering the old entry (keyed by agent name) and registering the
    updated config.

    Example Request:
    ```bash
    curl -X PATCH "http://localhost:4000/v1/agents/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "agent": {
                "agent_name": "updated-agent",
                "agent_card_params": {
                    "protocolVersion": "1.0",
                    "name": "Updated Agent",
                    "description": "Updated description",
                    "url": "http://localhost:9999/",
                    "version": "1.1.0",
                    "defaultInputModes": ["text"],
                    "defaultOutputModes": ["text"],
                    "capabilities": {
                        "streaming": true
                    },
                    "skills": []
                },
                "litellm_params": {
                    "make_public": false
                }
            }
        }'
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")

    from litellm.proxy.proxy_server import prisma_client

    _check_agent_management_permission(user_api_key_dict)

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        # Check if agent exists
        existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
            where={"agent_id": agent_id}
        )
        if existing_agent is not None:
            # Snapshot as a plain dict; the old agent_name is needed below to
            # deregister the in-memory entry.
            existing_agent = dict(existing_agent)

        if existing_agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )

        # Get the user ID from the API key auth
        updated_by = user_api_key_dict.user_id or "unknown"

        result = await AGENT_REGISTRY.patch_agent_in_db(
            agent_id=agent_id,
            agent=request,
            prisma_client=prisma_client,
            updated_by=updated_by,
        )

        # deregister in memory (by the *old* name — the patch may rename it)
        AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name"))  # type: ignore
        # register in memory
        AGENT_REGISTRY.register_agent(agent_config=result)

        verbose_proxy_logger.info(
            f"Successfully updated agent '{existing_agent.get('agent_name')}' (ID: {agent_id}) in memory"
        )

        return result
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating agent: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
    "/v1/agents/{agent_id}",
    # Tag aligned with the other agent CRUD routes for consistent API docs.
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_agent(
    agent_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete an agent.

    Removes the agent from the DB, then deregisters it from the in-memory
    registry (keyed by agent name).

    Example Request:
    ```bash
    curl -X DELETE "http://localhost:4000/v1/agents/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "message": "Agent 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
    }
    ```
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")

    from litellm.proxy.proxy_server import prisma_client

    _check_agent_management_permission(user_api_key_dict)

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        # Check if agent exists
        existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
            where={"agent_id": agent_id}
        )
        if existing_agent is not None:
            # Snapshot as a plain dict (matches the sibling update/patch
            # endpoints; the generic-alias call dict[Any, Any](...) was
            # equivalent but non-idiomatic).
            existing_agent = dict(existing_agent)

        if existing_agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found in DB."
            )

        await AGENT_REGISTRY.delete_agent_from_db(
            agent_id=agent_id, prisma_client=prisma_client
        )

        # In-memory deregistration is keyed by agent name, not ID.
        AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name"))  # type: ignore

        return {"message": f"Agent {agent_id} deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error deleting agent: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
    "/v1/agents/{agent_id}/make_public",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentMakePublicResponse,
)
async def make_agent_public(
    agent_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Make an agent publicly discoverable.

    Appends the agent's ID to `litellm.public_agent_groups` and persists the
    updated list into the proxy config under
    `litellm_settings.public_agent_groups`. Proxy-admin only.

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/v1/agents/123e4567-e89b-12d3-a456-426614174000/make_public" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json"
    ```

    Example Response:
    ```json
    {
        "message": "Successfully updated public agent groups",
        "public_agent_groups": ["123e4567-e89b-12d3-a456-426614174000"],
        "updated_by": "user123"
    }
    ```
    """
    # NOTE(review): unlike the other agent endpoints, this route does not call
    # check_feature_access_for_user(..., "agents") — confirm this is intentional.
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        # Update the public model groups
        import litellm
        from litellm.proxy.agent_endpoints.agent_registry import (
            global_agent_registry as AGENT_REGISTRY,
        )
        from litellm.proxy.proxy_server import proxy_config

        # Check if user has admin permissions
        if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Only proxy admins can update public model groups. Your role={}".format(
                        user_api_key_dict.user_role
                    )
                },
            )

        # Resolve the agent from memory, falling back to the DB.
        agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
        if agent is None:
            # check if agent exists in DB
            agent = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id}
            )
            if agent is not None:
                agent = AgentResponse(**agent.model_dump())  # type: ignore

        if agent is None:
            raise HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )

        if litellm.public_agent_groups is None:
            litellm.public_agent_groups = []
        # handle duplicates
        if agent.agent_id in litellm.public_agent_groups:
            raise HTTPException(
                status_code=400,
                detail=f"Agent with name {agent.agent_name} already in public agent groups",
            )
        litellm.public_agent_groups.append(agent.agent_id)

        # Load existing config
        config = await proxy_config.get_config()

        # Update config with new settings
        if "litellm_settings" not in config or config["litellm_settings"] is None:
            config["litellm_settings"] = {}

        config["litellm_settings"]["public_agent_groups"] = litellm.public_agent_groups

        # Save the updated config
        await proxy_config.save_config(new_config=config)

        verbose_proxy_logger.debug(
            f"Updated public agent groups to: {litellm.public_agent_groups} by user: {user_api_key_dict.user_id}"
        )

        return {
            "message": "Successfully updated public agent groups",
            "public_agent_groups": litellm.public_agent_groups,
            "updated_by": user_api_key_dict.user_id,
        }
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error making agent public: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
    "/v1/agents/make_public",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AgentMakePublicResponse,
)
async def make_agents_public(
    request: MakeAgentsPublicRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Make multiple agents publicly discoverable.

    Validates every requested agent ID, then *replaces* the full
    `litellm.public_agent_groups` list with `request.agent_ids` (unlike the
    single-agent endpoint, which appends) and persists it into the proxy
    config. Proxy-admin only.

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/v1/agents/make_public" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "agent_ids": ["123e4567-e89b-12d3-a456-426614174000", "123e4567-e89b-12d3-a456-426614174001"]
        }'
    ```

    Example Response:
    ```json
    {
        "message": "Successfully updated public agent groups",
        "public_agent_groups": ["123e4567-e89b-12d3-a456-426614174000", "123e4567-e89b-12d3-a456-426614174001"],
        "updated_by": "user123"
    }
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        # Update the public model groups
        import litellm
        from litellm.proxy.agent_endpoints.agent_registry import (
            global_agent_registry as AGENT_REGISTRY,
        )
        from litellm.proxy.proxy_server import proxy_config

        # Load existing config
        # NOTE(review): config is loaded *before* the admin-role check below,
        # unlike make_agent_public which checks the role first — confirm the
        # ordering is intentional.
        config = await proxy_config.get_config()
        # Check if user has admin permissions
        if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Only proxy admins can update public model groups. Your role={}".format(
                        user_api_key_dict.user_role
                    )
                },
            )

        if litellm.public_agent_groups is None:
            litellm.public_agent_groups = []

        # Validate every requested ID before mutating any global state.
        for agent_id in request.agent_ids:
            agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
            if agent is None:
                # check if agent exists in DB
                agent = await prisma_client.db.litellm_agentstable.find_unique(
                    where={"agent_id": agent_id}
                )
                if agent is not None:
                    agent = AgentResponse(**agent.model_dump())  # type: ignore

            if agent is None:
                raise HTTPException(
                    status_code=404, detail=f"Agent with ID {agent_id} not found"
                )

        # Replace (not extend) the public agent list with the requested IDs.
        litellm.public_agent_groups = request.agent_ids

        # Update config with new settings
        if "litellm_settings" not in config or config["litellm_settings"] is None:
            config["litellm_settings"] = {}

        config["litellm_settings"]["public_agent_groups"] = litellm.public_agent_groups

        # Save the updated config
        await proxy_config.save_config(new_config=config)

        verbose_proxy_logger.debug(
            f"Updated public agent groups to: {litellm.public_agent_groups} by user: {user_api_key_dict.user_id}"
        )

        return {
            "message": "Successfully updated public agent groups",
            "public_agent_groups": litellm.public_agent_groups,
            "updated_by": user_api_key_dict.user_id,
        }
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error making agent public: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
    "/agent/daily/activity",
    tags=["Agent Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=SpendAnalyticsPaginatedResponse,
)
async def get_agent_daily_activity(
    agent_ids: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    exclude_agent_ids: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get daily activity for specific agents or all accessible agents.

    `agent_ids` / `exclude_agent_ids` are comma-separated lists of agent IDs;
    when omitted, activity for all agents is returned. Results are paginated
    via `page` / `page_size` and may be filtered by `model` and `api_key`.
    """
    await check_feature_access_for_user(user_api_key_dict, "agents")

    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Parse the comma-separated query params. (The original double-checked
    # `exclude_agent_ids` truthiness twice; a single conditional suffices.)
    agent_ids_list = agent_ids.split(",") if agent_ids else None
    exclude_agent_ids_list: Optional[List[str]] = (
        exclude_agent_ids.split(",") if exclude_agent_ids else None
    )

    # Fetch agent names so rows can be annotated with human-readable metadata.
    where_condition = {}
    if agent_ids_list:
        where_condition["agent_id"] = {"in": agent_ids_list}

    agent_records = await prisma_client.db.litellm_agentstable.find_many(
        where=where_condition
    )
    agent_metadata = {
        agent.agent_id: {"agent_name": agent.agent_name} for agent in agent_records
    }

    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name="litellm_dailyagentspend",
        entity_id_field="agent_id",
        entity_id=agent_ids_list,
        entity_metadata_field=agent_metadata,
        exclude_entity_ids=exclude_agent_ids_list,
        start_date=start_date,
        end_date=end_date,
        model=model,
        api_key=api_key,
        page=page,
        page_size=page_size,
    )
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Helper functions for appending A2A agents to model lists.
|
||||
|
||||
Used by proxy model endpoints to make agents appear in UI alongside models.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
|
||||
ModelGroupInfoProxy,
|
||||
)
|
||||
|
||||
|
||||
async def append_agents_to_model_group(
    model_groups: List[ModelGroupInfoProxy],
    user_api_key_dict: UserAPIKeyAuth,
) -> List[ModelGroupInfoProxy]:
    """
    Append A2A agents to the model-groups list for UI display.

    Each agent the caller may access is exposed as a pseudo model group named
    "a2a/<agent-name>" so it appears in the playground and routes through
    LiteLLM. Any failure is logged at debug level and the input list is
    returned unchanged — agents must never break the models response.
    """
    try:
        from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
        from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
            AgentRequestHandler,
        )

        permitted_ids = await AgentRequestHandler.get_allowed_agents(
            user_api_key_auth=user_api_key_dict
        )

        for permitted_id in permitted_ids:
            agent_config = global_agent_registry.get_agent_by_id(permitted_id)
            if agent_config is None:
                # Permission entry with no matching registry entry — skip it.
                continue
            model_groups.append(
                ModelGroupInfoProxy(
                    model_group=f"a2a/{agent_config.agent_name}",
                    mode="chat",
                    providers=["a2a"],
                )
            )
    except Exception as e:
        verbose_proxy_logger.debug(f"Error appending agents to model_group/info: {e}")

    return model_groups
|
||||
|
||||
|
||||
async def append_agents_to_model_info(
    models: List[dict],
    user_api_key_dict: UserAPIKeyAuth,
) -> List[dict]:
    """
    Append A2A agents to the model-info list for UI display.

    Each agent the caller may access is exposed as a model entry named
    "a2a/<agent-name>" so it appears on the models page and routes through
    LiteLLM. Any failure is logged at debug level and the input list is
    returned unchanged — agents must never break the models response.
    """
    try:
        from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
        from litellm.proxy.agent_endpoints.auth.agent_permission_handler import (
            AgentRequestHandler,
        )

        permitted_ids = await AgentRequestHandler.get_allowed_agents(
            user_api_key_auth=user_api_key_dict
        )

        for permitted_id in permitted_ids:
            agent_config = global_agent_registry.get_agent_by_id(permitted_id)
            if agent_config is None:
                # Permission entry with no matching registry entry — skip it.
                continue
            display_name = f"a2a/{agent_config.agent_name}"
            models.append(
                {
                    "model_name": display_name,
                    "litellm_params": {
                        "model": display_name,
                        "custom_llm_provider": "a2a",
                    },
                    "model_info": {
                        "id": agent_config.agent_id,
                        "mode": "chat",
                        "db_model": True,
                        "created_by": agent_config.created_by,
                        "created_at": agent_config.created_at,
                        "updated_at": agent_config.updated_at,
                    },
                }
            )
    except Exception as e:
        verbose_proxy_logger.debug(f"Error appending agents to v2/model/info: {e}")

    return models
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Utility helpers for A2A agent endpoints."""
|
||||
|
||||
from typing import Dict, Mapping, Optional
|
||||
|
||||
|
||||
def merge_agent_headers(
    *,
    dynamic_headers: Optional[Mapping[str, str]] = None,
    static_headers: Optional[Mapping[str, str]] = None,
) -> Optional[Dict[str, str]]:
    """Combine outbound HTTP headers for an A2A agent call.

    ``dynamic_headers`` (extracted from the incoming client request) form the
    base; ``static_headers`` (admin-configured per agent) are layered on top,
    so a key present in both resolves to the static value. All keys and
    values are coerced to ``str``. Returns ``None`` when nothing is set.
    """
    combined: Dict[str, str] = {}
    # Later sources win: static (admin) headers override dynamic ones.
    for source in (dynamic_headers, static_headers):
        if not source:
            continue
        for key, value in source.items():
            combined[str(key)] = str(value)
    return combined if combined else None
|
||||
@@ -0,0 +1,106 @@
|
||||
#### Analytics Endpoints #####
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
    "/global/activity/cache_hits",
    tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
    responses={
        200: {"model": List[LiteLLM_SpendLogs]},
    },
    include_in_schema=False,
)
async def get_global_activity(
    start_date: Optional[str] = fastapi.Query(
        default=None,
        description="Time from which to start viewing spend",
    ),
    end_date: Optional[str] = fastapi.Query(
        default=None,
        description="Time till which to view spend",
    ),
):
    """
    Get number of cache hits, vs misses

    Dates are `YYYY-MM-DD`; both are required. Rows are grouped per
    (key alias, call_type, model) with total/cached/generated token counts.

    {
        "daily_data": [
            const chartdata = [
            {
                date: 'Jan 22',
                cache_hits: 10,
                llm_api_calls: 2000
            },
            {
                date: 'Jan 23',
                cache_hits: 10,
                llm_api_calls: 12
            },
        ],
        "sum_cache_hits": 20,
        "sum_llm_api_calls": 2012
    }
    """

    if start_date is None or end_date is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Please provide start_date and end_date"},
        )

    # Validate the date format up-front; strptime raises on malformed input.
    start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
    end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")

    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise ValueError(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )

        # Parameterized query ($1/$2) — never interpolate the dates directly.
        # The end date is inclusive: rows up to (end_date + 1 day) exclusive.
        sql_query = """
        SELECT
        CASE
            WHEN vt."key_alias" IS NOT NULL THEN vt."key_alias"
            ELSE 'Unnamed Key'
        END AS api_key,
        sl."call_type",
        sl."model",
        COUNT(*) AS total_rows,
        SUM(CASE WHEN sl."cache_hit" = 'True' THEN 1 ELSE 0 END) AS cache_hit_true_rows,
        SUM(CASE WHEN sl."cache_hit" = 'True' THEN sl."completion_tokens" ELSE 0 END) AS cached_completion_tokens,
        SUM(CASE WHEN sl."cache_hit" != 'True' THEN sl."completion_tokens" ELSE 0 END) AS generated_completion_tokens
        FROM "LiteLLM_SpendLogs" sl
        LEFT JOIN "LiteLLM_VerificationToken" vt ON sl."api_key" = vt."token"
        WHERE
        sl."startTime" >= $1::timestamptz AND sl."startTime" < ($2::timestamptz + INTERVAL '1 day')
        GROUP BY
        vt."key_alias",
        sl."call_type",
        sl."model"
        """
        db_response = await prisma_client.db.query_raw(
            sql_query, start_date_obj, end_date_obj
        )

        if db_response is None:
            return []

        return db_response

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(e)},
        )
|
||||
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Claude Code Endpoints
|
||||
|
||||
Provides endpoints for Claude Code plugin marketplace integration.
|
||||
"""
|
||||
|
||||
from litellm.proxy.anthropic_endpoints.claude_code_endpoints.claude_code_marketplace import (
|
||||
router as claude_code_marketplace_router,
|
||||
)
|
||||
|
||||
__all__ = ["claude_code_marketplace_router"]
|
||||
@@ -0,0 +1,546 @@
|
||||
"""
|
||||
CLAUDE CODE MARKETPLACE
|
||||
|
||||
Provides a registry/discovery layer for Claude Code plugins.
|
||||
Plugins are stored as metadata + git source references in LiteLLM database.
|
||||
Actual plugin files are hosted on GitHub/GitLab/Bitbucket.
|
||||
|
||||
Endpoints:
|
||||
/claude-code/marketplace.json - GET - List plugins for Claude Code discovery
|
||||
/claude-code/plugins - POST - Register a plugin
|
||||
/claude-code/plugins - GET - List plugins (admin)
|
||||
/claude-code/plugins/{name} - GET - Get plugin details
|
||||
/claude-code/plugins/{name}/enable - POST - Enable a plugin
|
||||
/claude-code/plugins/{name}/disable - POST - Disable a plugin
|
||||
/claude-code/plugins/{name} - DELETE - Delete a plugin
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.proxy.claude_code_endpoints import (
|
||||
ListPluginsResponse,
|
||||
PluginListItem,
|
||||
RegisterPluginRequest,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_prisma_client():
    """Return the shared prisma client, raising 500 when the DB is not connected."""
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is not None:
        return prisma_client
    raise HTTPException(
        status_code=500,
        detail={"error": CommonProxyErrors.db_not_connected_error.value},
    )
|
||||
|
||||
|
||||
@router.get(
    "/claude-code/marketplace.json",
    tags=["Claude Code Marketplace"],
)
async def get_marketplace():
    """
    Serve marketplace.json for Claude Code plugin discovery.

    This endpoint is accessed by Claude Code CLI when users run:
    - claude plugin marketplace add <url>
    - claude plugin install <name>@<marketplace>

    Returns:
        Marketplace catalog with list of available plugins and their git sources.

    Example:
        ```bash
        claude plugin marketplace add http://localhost:4000/claude-code/marketplace.json
        claude plugin install my-plugin@litellm
        ```
    """
    try:
        prisma_client = await _get_prisma_client()

        rows = await prisma_client.db.litellm_claudecodeplugintable.find_many(
            where={"enabled": True}
        )

        catalog_entries = []
        for row in rows:
            try:
                manifest = json.loads(row.manifest_json)
            except json.JSONDecodeError:
                verbose_proxy_logger.warning(
                    f"Plugin {row.name} has invalid manifest JSON, skipping"
                )
                continue

            # Source must be specified for URL-based marketplaces
            if "source" not in manifest:
                verbose_proxy_logger.warning(
                    f"Plugin {row.name} has no source field, skipping"
                )
                continue

            # Required fields first; dict insertion order is preserved in the
            # serialized JSON output.
            entry: Dict[str, Any] = {
                "name": row.name,
                "source": manifest["source"],
            }

            # Optional metadata: version/description come from the DB columns,
            # everything else straight from the stored manifest.
            if row.version:
                entry["version"] = row.version
            if row.description:
                entry["description"] = row.description
            for optional_key in ("author", "homepage", "keywords", "category"):
                if optional_key in manifest:
                    entry[optional_key] = manifest[optional_key]

            catalog_entries.append(entry)

        return JSONResponse(
            content={
                "name": "litellm",
                "owner": {"name": "LiteLLM", "email": "support@litellm.ai"},
                "plugins": catalog_entries,
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error generating marketplace: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to generate marketplace: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.post(
    "/claude-code/plugins",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def register_plugin(
    request: RegisterPluginRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Register a plugin in the LiteLLM marketplace.

    LiteLLM acts as a registry/discovery layer. Plugins are hosted on
    GitHub/GitLab/Bitbucket. Claude Code will clone from the git source
    when users install.

    Note: this endpoint has upsert semantics - registering a name that
    already exists updates the existing row in place instead of failing.

    Parameters:
    - name: Plugin name (kebab-case)
    - source: Git source reference (github or url format)
    - version: Semantic version (optional)
    - description: Plugin description (optional)
    - author: Author information (optional)
    - homepage: Plugin homepage URL (optional)
    - keywords: Search keywords (optional)
    - category: Plugin category (optional)

    Returns:
        Registration status ("created" or "updated") and plugin information.

    Example:
    ```bash
    curl -X POST http://localhost:4000/claude-code/plugins \\
      -H "Authorization: Bearer sk-..." \\
      -H "Content-Type: application/json" \\
      -d '{
        "name": "my-plugin",
        "source": {"source": "github", "repo": "org/my-plugin"},
        "version": "1.0.0",
        "description": "My awesome plugin"
      }'
    ```
    """
    try:
        prisma_client = await _get_prisma_client()

        # Validate name format (kebab-case, matching Claude Code plugin naming)
        if not re.match(r"^[a-z0-9-]+$", request.name):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Plugin name must be kebab-case (lowercase letters, numbers, hyphens)"
                },
            )

        # Validate source format.
        # NOTE(review): `.get` assumes request.source is a plain dict field on
        # RegisterPluginRequest - confirm against the pydantic model definition.
        source = request.source
        source_type = source.get("source")

        if source_type == "github":
            if "repo" not in source:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "GitHub source must include 'repo' field (e.g., 'org/repo')"
                    },
                )
        elif source_type == "url":
            if "url" not in source:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "URL source must include 'url' field (e.g., 'https://github.com/org/repo.git')"
                    },
                )
        else:
            raise HTTPException(
                status_code=400,
                detail={"error": "source.source must be 'github' or 'url'"},
            )

        # Build manifest for storage; only fields actually supplied by the
        # caller are persisted, so absent optional keys stay absent in the
        # served marketplace.json rather than appearing as null.
        manifest: Dict[str, Any] = {
            "name": request.name,
            "source": request.source,
        }
        if request.version:
            manifest["version"] = request.version
        if request.description:
            manifest["description"] = request.description
        if request.author:
            manifest["author"] = request.author.model_dump(exclude_none=True)
        if request.homepage:
            manifest["homepage"] = request.homepage
        if request.keywords:
            manifest["keywords"] = request.keywords
        if request.category:
            manifest["category"] = request.category

        # Check if plugin exists (upsert: update in place when it does)
        existing = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
            where={"name": request.name}
        )

        if existing:
            # files_json is reset to "{}" on every registration: plugin files
            # live in the external git source, not in the LiteLLM database.
            plugin = await prisma_client.db.litellm_claudecodeplugintable.update(
                where={"name": request.name},
                data={
                    "version": request.version,
                    "description": request.description,
                    "manifest_json": json.dumps(manifest),
                    "files_json": "{}",
                    "updated_at": datetime.now(timezone.utc),
                },
            )
            action = "updated"
        else:
            plugin = await prisma_client.db.litellm_claudecodeplugintable.create(
                data={
                    "name": request.name,
                    "version": request.version,
                    "description": request.description,
                    "manifest_json": json.dumps(manifest),
                    "files_json": "{}",
                    "enabled": True,
                    "created_at": datetime.now(timezone.utc),
                    "updated_at": datetime.now(timezone.utc),
                    "created_by": user_api_key_dict.user_id,
                }
            )
            action = "created"

        verbose_proxy_logger.info(f"Plugin {request.name} {action} successfully")

        return {
            "status": "success",
            "action": action,
            "plugin": {
                "id": plugin.id,
                "name": plugin.name,
                "version": plugin.version,
                "description": plugin.description,
                "source": request.source,
                "enabled": plugin.enabled,
            },
        }

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error registering plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Registration failed: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.get(
    "/claude-code/plugins",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListPluginsResponse,
)
async def list_plugins(
    enabled_only: bool = False,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List all plugins in the marketplace.

    Parameters:
    - enabled_only: If true, only return enabled plugins

    Returns:
        List of plugins with their metadata, sorted newest first.
    """
    try:
        prisma_client = await _get_prisma_client()

        where = {"enabled": True} if enabled_only else {}
        plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
            where=where
        )

        plugin_list = []
        for p in plugins:
            # Parse manifest to get additional fields. A single corrupted
            # manifest must not 500 the whole listing (get_marketplace already
            # tolerates such rows), so fall back to an empty manifest.
            try:
                manifest = json.loads(p.manifest_json) if p.manifest_json else {}
            except json.JSONDecodeError:
                verbose_proxy_logger.warning(
                    f"Plugin {p.name} has invalid manifest JSON, listing without manifest fields"
                )
                manifest = {}

            plugin_list.append(
                PluginListItem(
                    id=p.id,
                    name=p.name,
                    version=p.version,
                    description=p.description,
                    source=manifest.get("source", {}),
                    author=manifest.get("author"),
                    homepage=manifest.get("homepage"),
                    keywords=manifest.get("keywords"),
                    category=manifest.get("category"),
                    enabled=p.enabled,
                    created_at=p.created_at.isoformat() if p.created_at else None,
                    updated_at=p.updated_at.isoformat() if p.updated_at else None,
                )
            )

        # Sort by created_at descending (newest first); rows with no timestamp
        # sort last via the "" fallback.
        plugin_list.sort(key=lambda x: x.created_at or "", reverse=True)

        return ListPluginsResponse(
            plugins=plugin_list,
            count=len(plugin_list),
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error listing plugins: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
|
||||
|
||||
|
||||
@router.get(
    "/claude-code/plugins/{plugin_name}",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_plugin(
    plugin_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get details of a specific plugin.

    Parameters:
    - plugin_name: The name of the plugin

    Returns:
        Plugin details including source and metadata.
    """
    try:
        prisma_client = await _get_prisma_client()

        record = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
            where={"name": plugin_name}
        )

        if record is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Plugin '{plugin_name}' not found"},
            )

        stored_manifest = (
            json.loads(record.manifest_json) if record.manifest_json else {}
        )

        # DB columns first, then manifest-derived fields, then bookkeeping -
        # insertion order is what the serialized response will show.
        details = {
            "id": record.id,
            "name": record.name,
            "version": record.version,
            "description": record.description,
        }
        for manifest_key in ("source", "author", "homepage", "keywords", "category"):
            details[manifest_key] = stored_manifest.get(manifest_key)
        details.update(
            {
                "enabled": record.enabled,
                "created_at": record.created_at.isoformat()
                if record.created_at
                else None,
                "updated_at": record.updated_at.isoformat()
                if record.updated_at
                else None,
                "created_by": record.created_by,
            }
        )
        return details

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
|
||||
|
||||
|
||||
@router.post(
    "/claude-code/plugins/{plugin_name}/enable",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def enable_plugin(
    plugin_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Enable a disabled plugin.

    An enabled plugin reappears in the public marketplace.json catalog.

    Parameters:
    - plugin_name: The name of the plugin to enable
    """
    try:
        prisma_client = await _get_prisma_client()

        table = prisma_client.db.litellm_claudecodeplugintable
        record = await table.find_unique(where={"name": plugin_name})
        if record is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Plugin '{plugin_name}' not found"},
            )

        await table.update(
            where={"name": plugin_name},
            data={"enabled": True, "updated_at": datetime.now(timezone.utc)},
        )

        verbose_proxy_logger.info(f"Plugin {plugin_name} enabled")
        return {"status": "success", "message": f"Plugin '{plugin_name}' enabled"}

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error enabling plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
|
||||
|
||||
|
||||
@router.post(
    "/claude-code/plugins/{plugin_name}/disable",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def disable_plugin(
    plugin_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Disable a plugin without deleting it.

    The row stays in the database (so it can be re-enabled later) but the
    plugin no longer appears in the public marketplace.json catalog.

    Parameters:
    - plugin_name: The name of the plugin to disable
    """
    try:
        prisma_client = await _get_prisma_client()

        table = prisma_client.db.litellm_claudecodeplugintable
        record = await table.find_unique(where={"name": plugin_name})
        if record is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Plugin '{plugin_name}' not found"},
            )

        await table.update(
            where={"name": plugin_name},
            data={"enabled": False, "updated_at": datetime.now(timezone.utc)},
        )

        verbose_proxy_logger.info(f"Plugin {plugin_name} disabled")
        return {"status": "success", "message": f"Plugin '{plugin_name}' disabled"}

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error disabling plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
|
||||
|
||||
|
||||
@router.delete(
    "/claude-code/plugins/{plugin_name}",
    tags=["Claude Code Marketplace"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_plugin(
    plugin_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a plugin from the marketplace.

    This removes the registry row permanently; the plugin's actual files on
    its git host are untouched.

    Parameters:
    - plugin_name: The name of the plugin to delete
    """
    try:
        prisma_client = await _get_prisma_client()

        table = prisma_client.db.litellm_claudecodeplugintable
        record = await table.find_unique(where={"name": plugin_name})
        if record is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Plugin '{plugin_name}' not found"},
            )

        await table.delete(where={"name": plugin_name})

        verbose_proxy_logger.info(f"Plugin {plugin_name} deleted")
        return {"status": "success", "message": f"Plugin '{plugin_name}' deleted"}

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error deleting plugin: {e}")
        raise HTTPException(
            status_code=500,
            detail={"error": str(e)},
        )
|
||||
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Unified /v1/messages endpoint - (Anthropic Spec)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.anthropic_interface.exceptions import AnthropicExceptionMapping
|
||||
from litellm.integrations.custom_guardrail import ModifyResponseException
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import (
|
||||
ProxyBaseLLMRequestProcessing,
|
||||
create_response,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.types.utils import TokenCountResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/v1/messages",
    tags=["[beta] Anthropic `/v1/messages`"],
    dependencies=[Depends(user_api_key_auth)],
)
async def anthropic_response(  # noqa: PLR0915
    fastapi_response: Response,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/pass_through/anthropic_completion).

    This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
    """
    # Imported lazily from proxy_server to pick up the current runtime state
    # (router, settings, CLI overrides) rather than import-time values.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    data = await _read_request_body(request=request)
    base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        # Delegate the full request lifecycle (routing, hooks, logging) to the
        # shared processor; route_type selects the anthropic-messages path.
        result = await base_llm_response_processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="anthropic_messages",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=None,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
        return result
    except ModifyResponseException as e:
        # Guardrail flagged content in passthrough mode - return 200 with violation message
        # (deliberately NOT an error response: the violation text is delivered
        # to the client as a normal assistant message).
        _data = e.request_data
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=_data,
        )

        # Create Anthropic-formatted response with violation message
        import uuid

        from litellm.types.utils import AnthropicMessagesResponse

        _anthropic_response = AnthropicMessagesResponse(
            id=f"msg_{str(uuid.uuid4())}",
            type="message",
            role="assistant",
            content=[{"type": "text", "text": e.message}],
            model=e.model,
            stop_reason="end_turn",
            usage={"input_tokens": 0, "output_tokens": 0},
        )

        # Only stream when the caller explicitly sent stream=True.
        if data.get("stream", None) is not None and data["stream"] is True:
            # For streaming, use the standard SSE data generator; the single
            # violation message is yielded as a one-chunk stream.
            async def _passthrough_stream_generator():
                yield _anthropic_response

            selected_data_generator = (
                ProxyBaseLLMRequestProcessing.async_sse_data_generator(
                    response=_passthrough_stream_generator(),
                    user_api_key_dict=user_api_key_dict,
                    request_data=_data,
                    proxy_logging_obj=proxy_logging_obj,
                )
            )

            return await create_response(
                generator=selected_data_generator,
                media_type="text/event-stream",
                headers={},
            )

        return _anthropic_response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format(
                str(e)
            )
        )

        # Extract model_id from request metadata (same as success path)
        litellm_metadata = data.get("litellm_metadata", {}) or {}
        model_info = litellm_metadata.get("model_info", {}) or {}
        model_id = model_info.get("id", "") or ""

        # Get headers so the error response still carries the standard
        # x-litellm-* diagnostic headers.
        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=user_api_key_dict,
            call_id=data.get("litellm_call_id", ""),
            model_id=model_id,
            version=version,
            response_cost=0,
            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            request_data=data,
            timeout=getattr(e, "timeout", None),
            litellm_logging_obj=None,
        )

        error_msg = f"{str(e)}"
        # Re-raise in the proxy's canonical error shape; getattr fallbacks
        # cover exceptions that lack message/type/param/status_code attributes.
        raise ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
            headers=headers,
        )
|
||||
|
||||
|
||||
@router.post(
    "/v1/messages/count_tokens",
    tags=["[beta] Anthropic Messages Token Counting"],
    dependencies=[Depends(user_api_key_auth)],
)
async def count_tokens(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),  # Used for auth
):
    """
    Count tokens for Anthropic Messages API format.

    This endpoint follows the Anthropic Messages API token counting specification.
    It accepts the same parameters as the /v1/messages endpoint but returns
    token counts instead of generating a response.

    Example usage:
    ```
    curl -X POST "http://localhost:4000/v1/messages/count_tokens?beta=true" \
         -H "Content-Type: application/json" \
         -H "Authorization: Bearer your-key" \
         -d '{
           "model": "claude-3-sonnet-20240229",
           "messages": [{"role": "user", "content": "Hello Claude!"}]
         }'
    ```

    Returns: {"input_tokens": <number>}

    Raises:
        HTTPException: 400 when model/messages are missing, mapped provider
            errors for upstream failures, 500 for unexpected errors.
    """
    from litellm.proxy.proxy_server import token_counter as internal_token_counter

    try:
        request_data = await _read_request_body(request=request)
        data: dict = {**request_data}

        # Extract required fields
        model_name = data.get("model")
        messages = data.get("messages", [])

        if not model_name:
            raise HTTPException(
                status_code=400, detail={"error": "model parameter is required"}
            )

        if not messages:
            raise HTTPException(
                status_code=400, detail={"error": "messages parameter is required"}
            )

        # Create TokenCountRequest for the internal endpoint
        from litellm.proxy._types import TokenCountRequest

        token_request = TokenCountRequest(
            model=model_name,
            messages=messages,
            tools=data.get("tools"),
            system=data.get("system"),
        )

        # Delegate to the proxy's internal token counter (call_endpoint=True
        # marks this as coming through a public endpoint rather than a direct
        # internal call).
        token_response = await internal_token_counter(
            request=token_request,
            call_endpoint=True,
        )
        _token_response_dict: dict = {}
        if isinstance(token_response, TokenCountResponse):
            _token_response_dict = token_response.model_dump()
        elif isinstance(token_response, dict):
            _token_response_dict = token_response

        # Convert the internal response to Anthropic API format
        return {"input_tokens": _token_response_dict.get("total_tokens", 0)}

    except HTTPException:
        raise
    except ProxyException as e:
        # ProxyException.code may be an int (e.g. anthropic_response raises
        # ProxyException(code=getattr(e, "status_code", 500))) or a numeric
        # string; calling .isdigit() on an int would itself raise
        # AttributeError, so normalize both forms defensively.
        if isinstance(e.code, int):
            status_code = e.code
        elif isinstance(e.code, str) and e.code.isdigit():
            status_code = int(e.code)
        else:
            status_code = 500
        detail = AnthropicExceptionMapping.transform_to_anthropic_error(
            status_code=status_code,
            raw_message=e.message,
        )
        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(
                str(e)
            )
        )
        raise HTTPException(
            status_code=500, detail={"error": f"Internal server error: {str(e)}"}
        )
|
||||
|
||||
|
||||
@router.post(
    "/api/event_logging/batch",
    tags=["[beta] Anthropic Event Logging"],
)
async def event_logging_batch(
    request: Request,
):
    """
    No-op sink for Anthropic event-logging batch requests.

    Claude Code clients POST telemetry batches here; the payload is
    intentionally ignored. The route exists only so those clients do not
    receive 404 errors when pointed at the proxy.
    """
    # Acknowledge without reading the body - nothing is stored or forwarded.
    return {"status": "ok"}
|
||||
@@ -0,0 +1,437 @@
|
||||
"""
|
||||
Anthropic Skills API endpoints - /v1/skills
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
convert_upload_files_to_file_data,
|
||||
get_form_data,
|
||||
)
|
||||
from litellm.types.llms.anthropic_skills import (
|
||||
DeleteSkillResponse,
|
||||
ListSkillsResponse,
|
||||
Skill,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/v1/skills",
    tags=["[beta] Anthropic Skills API"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=Skill,
)
async def create_skill(
    fastapi_response: Response,
    request: Request,
    custom_llm_provider: Optional[str] = "anthropic",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new skill on Anthropic.

    Requires `?beta=true` query parameter.

    Model-based routing (for multi-account support):
    - Pass model via header: `x-litellm-model: claude-account-1`
    - Pass model via query: `?model=claude-account-1`
    - Pass model via form field: `model=claude-account-1`

    Example usage:
    ```bash
    # Basic usage
    curl -X POST "http://localhost:4000/v1/skills?beta=true" \
         -H "Content-Type: multipart/form-data" \
         -H "Authorization: Bearer your-key" \
         -F "display_title=My Skill" \
         -F "files[]=@skill.zip"

    # With model-based routing
    curl -X POST "http://localhost:4000/v1/skills?beta=true" \
         -H "Content-Type: multipart/form-data" \
         -H "Authorization: Bearer your-key" \
         -H "x-litellm-model: claude-account-1" \
         -F "display_title=My Skill" \
         -F "files[]=@skill.zip"
    ```

    Returns: Skill object with id, display_title, etc.
    """
    # Imported lazily from proxy_server to pick up the current runtime state
    # (router, settings, CLI overrides) rather than import-time values.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    # Read form data and convert UploadFile objects to file data tuples
    form_data = await get_form_data(request)
    data = await convert_upload_files_to_file_data(form_data)

    # Extract model for routing; precedence is form body > query param >
    # x-litellm-model header (first non-empty value in that order wins).
    model = (
        data.get("model")
        or request.query_params.get("model")
        or request.headers.get("x-litellm-model")
    )
    if model:
        data["model"] = model

    # Fall back to the query-param/default provider only when the form did
    # not carry one explicitly.
    if "custom_llm_provider" not in data:
        data["custom_llm_provider"] = custom_llm_provider

    # Process request using ProxyBaseLLMRequestProcessing
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="acreate_skill",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=data.get("model"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # Convert provider/processing errors into the proxy's standard
        # HTTP error response.
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
|
||||
|
||||
|
||||
@router.get(
    "/v1/skills",
    tags=["[beta] Anthropic Skills API"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListSkillsResponse,
)
async def list_skills(
    fastapi_response: Response,
    request: Request,
    limit: Optional[int] = 10,
    after_id: Optional[str] = None,
    before_id: Optional[str] = None,
    custom_llm_provider: Optional[str] = "anthropic",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List skills on Anthropic.

    Requires `?beta=true` query parameter.

    Model-based routing (for multi-account support):
    - Pass model via header: `x-litellm-model: claude-account-1`
    - Pass model via query: `?model=claude-account-1`
    - Pass model via body: `{"model": "claude-account-1"}`

    Example usage:
    ```bash
    # Basic usage
    curl "http://localhost:4000/v1/skills?beta=true&limit=10" \
         -H "Authorization: Bearer your-key"

    # With model-based routing
    curl "http://localhost:4000/v1/skills?beta=true&limit=10" \
         -H "Authorization: Bearer your-key" \
         -H "x-litellm-model: claude-account-1"
    ```

    Returns: ListSkillsResponse with list of skills
    """
    # Imported lazily from proxy_server to pick up the current runtime state
    # (router, settings, CLI overrides) rather than import-time values.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    # Read request body. NOTE(review): accepting a JSON body on a GET is
    # unusual; body values take precedence over the declared query params.
    body = await request.body()
    data = orjson.loads(body) if body else {}

    # Use query params if not in body
    if "limit" not in data and limit is not None:
        data["limit"] = limit
    if "after_id" not in data and after_id is not None:
        data["after_id"] = after_id
    if "before_id" not in data and before_id is not None:
        data["before_id"] = before_id

    # Extract model for routing; precedence is body > query param >
    # x-litellm-model header (first non-empty value in that order wins).
    model = (
        data.get("model")
        or request.query_params.get("model")
        or request.headers.get("x-litellm-model")
    )
    if model:
        data["model"] = model

    # Set custom_llm_provider: body > query param > default
    if "custom_llm_provider" not in data:
        data["custom_llm_provider"] = custom_llm_provider

    # Process request using ProxyBaseLLMRequestProcessing
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="alist_skills",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=data.get("model"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # Convert provider/processing errors into the proxy's standard
        # HTTP error response.
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
|
||||
|
||||
|
||||
@router.get(
    "/v1/skills/{skill_id}",
    tags=["[beta] Anthropic Skills API"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=Skill,
)
async def get_skill(
    skill_id: str,
    fastapi_response: Response,
    request: Request,
    custom_llm_provider: Optional[str] = "anthropic",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get a specific skill by ID from Anthropic.

    Requires `?beta=true` query parameter.

    Model-based routing (for multi-account support):
    - Pass model via header: `x-litellm-model: claude-account-1`
    - Pass model via query: `?model=claude-account-1`
    - Pass model via body: `{"model": "claude-account-1"}`

    When the model is set in more than one place, precedence is:
    body > query param > header.

    Example usage:
    ```bash
    # Basic usage
    curl "http://localhost:4000/v1/skills/skill_123?beta=true" \
        -H "Authorization: Bearer your-key"

    # With model-based routing
    curl "http://localhost:4000/v1/skills/skill_123?beta=true" \
        -H "Authorization: Bearer your-key" \
        -H "x-litellm-model: claude-account-1"
    ```

    Returns: Skill object
    """
    # Imported lazily from proxy_server to pick up the live, post-startup
    # module-level state (router, settings, logging) rather than import-time values.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    # Read request body (GET requests may still carry a JSON body here)
    body = await request.body()
    data = orjson.loads(body) if body else {}

    # Set skill_id from path parameter (always wins over any body value)
    data["skill_id"] = skill_id

    # Extract model for routing (precedence: body > query param > header)
    model = (
        data.get("model")
        or request.query_params.get("model")
        or request.headers.get("x-litellm-model")
    )
    if model:
        data["model"] = model

    # Set custom_llm_provider: body > query param > default ("anthropic")
    if "custom_llm_provider" not in data:
        data["custom_llm_provider"] = custom_llm_provider

    # Process request using ProxyBaseLLMRequestProcessing
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="aget_skill",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=data.get("model"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # Converts provider/internal errors into the proxy's HTTP error format
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
|
||||
|
||||
|
||||
@router.delete(
    "/v1/skills/{skill_id}",
    tags=["[beta] Anthropic Skills API"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=DeleteSkillResponse,
)
async def delete_skill(
    skill_id: str,
    fastapi_response: Response,
    request: Request,
    custom_llm_provider: Optional[str] = "anthropic",
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete an Anthropic skill by its ID.

    Requires `?beta=true` query parameter.

    Note: Anthropic does not allow deleting skills with existing versions.

    Model-based routing (for multi-account support) — first match wins:
    - body: `{"model": "claude-account-1"}`
    - query: `?model=claude-account-1`
    - header: `x-litellm-model: claude-account-1`

    Example usage:
    ```bash
    # Basic usage
    curl -X DELETE "http://localhost:4000/v1/skills/skill_123?beta=true" \
        -H "Authorization: Bearer your-key"

    # With model-based routing
    curl -X DELETE "http://localhost:4000/v1/skills/skill_123?beta=true" \
        -H "Authorization: Bearer your-key" \
        -H "x-litellm-model: claude-account-1"
    ```

    Returns: DeleteSkillResponse with type="skill_deleted"
    """
    # Lazy import so we read the live proxy state, not import-time snapshots.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    # DELETE requests may still carry a JSON body.
    raw_body = await request.body()
    data = orjson.loads(raw_body) if raw_body else {}

    # The skill to delete always comes from the URL path.
    data["skill_id"] = skill_id

    # Resolve the routing model; first match wins: body, then query param,
    # then the x-litellm-model header.
    routing_model = (
        data.get("model")
        or request.query_params.get("model")
        or request.headers.get("x-litellm-model")
    )
    if routing_model:
        data["model"] = routing_model

    # A provider already present in the body wins over the query-param/default value.
    data.setdefault("custom_llm_provider", custom_llm_provider)

    # Hand off to the shared proxy request pipeline.
    request_processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await request_processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="adelete_skill",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=data.get("model"),
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # Normalize any failure into the proxy's standard HTTP error shape.
        raise await request_processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Auth Checks for Organizations
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import status
|
||||
|
||||
from litellm.proxy._types import *
|
||||
|
||||
|
||||
def organization_role_based_access_check(
    request_body: dict,
    user_object: Optional[LiteLLM_UserTable],
    route: str,
):
    """
    Role based access control checks only run if a user is part of an Organization

    Organization Checks:
    ONLY RUN IF user_object.organization_memberships is not None

    1. Only Proxy Admins can access /organization/new
    2. IF route is a LiteLLMRoutes.org_admin_only_routes, then check if user is an Org Admin for that organization

    Raises:
        ProxyException (401) when the user lacks the required role for the route.
        Returns None (allows the request) in all other cases.
    """

    # No user record -> nothing to check at the org level here.
    if user_object is None:
        return

    passed_organization_id: Optional[str] = request_body.get("organization_id", None)

    if route == "/organization/new":
        if user_object.user_role != LitellmUserRoles.PROXY_ADMIN.value:
            raise ProxyException(
                message=f"Only proxy admins can create new organizations. You are {user_object.user_role}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )

    # Proxy admins bypass every org-level restriction below.
    if user_object.user_role == LitellmUserRoles.PROXY_ADMIN.value:
        return

    # Checks if route is an Org Admin Only Route
    if route in LiteLLMRoutes.org_admin_only_routes.value:
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)

        if user_object.organization_memberships is None:
            raise ProxyException(
                message=f"Tried to access route={route} but you are not a member of any organization. Please contact the proxy admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        if passed_organization_id is None:
            raise ProxyException(
                message="Passed organization_id is None, please pass an organization_id in your request",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        user_role: Optional[LitellmUserRoles] = _user_organization_role_mapping.get(
            passed_organization_id
        )
        if user_role is None:
            raise ProxyException(
                message=f"You do not have a role within the selected organization. Passed organization_id: {passed_organization_id}. Please contact the organization admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        # NOTE(review): compares a LitellmUserRoles member against `.value` (a str);
        # this assumes LitellmUserRoles is a str-enum so `member == value` holds — confirm.
        if user_role != LitellmUserRoles.ORG_ADMIN.value:
            raise ProxyException(
                message=f"You do not have the required role to perform {route} in Organization {passed_organization_id}. Your role is {user_role} in Organization {passed_organization_id}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )
    elif route == "/team/new":
        # if user is part of multiple organizations, then they need to specify the organization_id
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)
        if (
            user_object.organization_memberships is not None
            and len(user_object.organization_memberships) > 0
        ):
            if passed_organization_id is None:
                raise ProxyException(
                    message=f"Passed organization_id is None, please specify the organization_id in your request. You are part of multiple organizations: {_user_organizations}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="organization_id",
                    code=status.HTTP_401_UNAUTHORIZED,
                )

            _user_role_in_passed_org = _user_organization_role_mapping.get(
                passed_organization_id
            )
            if _user_role_in_passed_org != LitellmUserRoles.ORG_ADMIN.value:
                raise ProxyException(
                    message=f"You do not have the required role to call {route}. Your role is {_user_role_in_passed_org} in Organization {passed_organization_id}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="user_role",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
|
||||
|
||||
|
||||
def get_user_organization_info(
    user_object: LiteLLM_UserTable,
) -> Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]:
    """
    Collect the organizations a user belongs to and the user's role in each.

    Args:
        user_object (LiteLLM_UserTable): user record whose
            `organization_memberships` are inspected.

    Returns:
        Tuple of:
        - list of organization IDs the user is a member of
        - mapping of organization ID -> the user's role in that organization
    """
    org_ids: List[str] = []
    role_by_org: Dict[str, Optional[LitellmUserRoles]] = {}

    memberships = user_object.organization_memberships
    if memberships is not None:
        for membership in memberships:
            org_id = membership.organization_id
            if org_id is None:
                continue
            org_ids.append(org_id)
            role_by_org[org_id] = membership.user_role  # type: ignore

    return org_ids, role_by_org
|
||||
|
||||
|
||||
def _user_is_org_admin(
    request_data: dict,
    user_object: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Check whether the user is an ORG_ADMIN for any organization referenced
    in the request body.

    Inspects both request fields:
    - `organization_id` (single string) — legacy callers
    - `organizations` (list of strings) — used by /user/new
    """
    if user_object is None:
        return False

    if user_object.organization_memberships is None:
        return False

    # Gather every organization ID the request refers to.
    candidate_org_ids: List[str] = []
    single_org_id = request_data.get("organization_id", None)
    if single_org_id is not None:
        candidate_org_ids.append(single_org_id)
    org_id_list = request_data.get("organizations", None)
    if isinstance(org_id_list, list):
        candidate_org_ids.extend(org_id_list)

    if not candidate_org_ids:
        return False

    # Admin in any one of the referenced organizations is sufficient.
    return any(
        membership.organization_id in candidate_org_ids
        and membership.user_role == LitellmUserRoles.ORG_ADMIN.value
        for membership in user_object.organization_memberships
    )
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Handles Authentication Errors
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_utils import _get_request_ip_address
|
||||
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class UserAPIKeyAuthExceptionHandler:
    """Centralizes error handling for virtual-key authentication failures."""

    @staticmethod
    async def _handle_authentication_error(
        e: Exception,
        request: Request,
        request_data: dict,
        route: str,
        parent_otel_span: Optional[Span],
        api_key: str,
    ) -> UserAPIKeyAuth:
        """
        Handles Connection Errors when reading a Virtual Key from LiteLLM DB
        Use this if you don't want failed DB queries to block LLM API requests

        Reliability scenarios this covers:
        - DB is down and having an outage
        - Unable to read / recover a key from the DB

        Returns:
        - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True

        Raises:
        - Original Exception in all other cases (wrapped as ProxyException)
        """
        # Lazy import to read the live proxy state at call time.
        from litellm.proxy.proxy_server import (
            general_settings,
            litellm_proxy_admin_name,
            proxy_logging_obj,
        )

        if (
            PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
            and PrismaDBExceptionHandler.is_database_connection_error(e)
        ):
            # log this as a DB failure on prometheus
            # NOTE(review): this call is not awaited — if service_failure_hook is a
            # coroutine function the metric would silently never be recorded; confirm.
            proxy_logging_obj.service_logging_obj.service_failure_hook(
                service=ServiceTypes.DB,
                call_type="get_key_object",
                error=e,
                duration=0.0,
            )

            # Fail open: treat the caller as the proxy admin for this request.
            return UserAPIKeyAuth(
                key_name="failed-to-connect-to-db",
                token="failed-to-connect-to-db",
                user_id=litellm_proxy_admin_name,
                request_route=route,
            )
        else:
            # raise the exception to the caller
            requester_ip = _get_request_ip_address(
                request=request,
                use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
            )
            verbose_proxy_logger.exception(
                "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
                    str(e),
                    requester_ip,
                ),
                extra={"requester_ip": requester_ip},
            )

            # Log this exception to OTEL, Datadog etc
            user_api_key_dict = UserAPIKeyAuth(
                parent_otel_span=parent_otel_span,
                api_key=api_key,
                request_route=route,
            )
            # Allow callbacks to transform the error response
            transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                request_data=request_data,
                original_exception=e,
                user_api_key_dict=user_api_key_dict,
                error_type=ProxyErrorTypes.auth_error,
                route=route,
            )
            # Use transformed exception if callback returned one, otherwise use original
            if transformed_exception is not None:
                e = transformed_exception

            # Budget errors surface as 400s, not auth errors.
            if isinstance(e, litellm.BudgetExceededError):
                raise ProxyException(
                    message=e.message,
                    type=ProxyErrorTypes.budget_exceeded,
                    param=None,
                    code=400,
                )
            if isinstance(e, HTTPException):
                raise ProxyException(
                    message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                    type=ProxyErrorTypes.auth_error,
                    param=getattr(e, "param", "None"),
                    code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
                )
            elif isinstance(e, ProxyException):
                raise e
            # Fallback: wrap anything else as a generic 401 auth error.
            raise ProxyException(
                message="Authentication Error, " + str(e),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=status.HTTP_401_UNAUTHORIZED,
            )
|
||||
@@ -0,0 +1,835 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from litellm import Router, provider_list
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import STANDARD_CUSTOMER_ID_HEADERS
|
||||
from litellm.proxy._types import *
|
||||
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
|
||||
|
||||
|
||||
def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    """
    Resolve the caller's IP for the incoming request.

    Prefers the `x-forwarded-for` header when explicitly enabled, otherwise
    falls back to the socket peer; returns "" when neither is available.
    """
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        return request.headers["x-forwarded-for"]
    if request.client is not None:
        return request.client.host
    return ""
|
||||
|
||||
|
||||
def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Check whether the caller's IP is permitted.

    Returns (is_allowed, client_ip). With no allow-list configured every
    caller is allowed and the IP is not resolved (None is returned for it).
    """
    # No allow-list configured -> allow everyone.
    if allowed_ips is None:
        return True, None

    # Honors general_settings "use_x_forwarded_for" when resolving the IP.
    client_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )

    return (client_ip in allowed_ips), client_ip
|
||||
|
||||
|
||||
def check_complete_credentials(request_body: dict) -> bool:
    """
    if 'api_base' in request body. Check if complete credentials given. Prevent malicious attacks.

    A body counts as "complete" only when it names a model, that model does not
    belong to a complex-credential provider (sagemaker / bedrock / vertex_ai),
    and an `api_key` is supplied alongside.
    """
    given_model: Optional[str] = request_body.get("model")
    if given_model is None:
        return False

    # complex credentials - easier to make a malicious request
    complex_credential_markers = ("sagemaker", "bedrock", "vertex_ai", "vertex_ai_beta")
    if any(marker in given_model for marker in complex_credential_markers):
        return False

    return "api_key" in request_body
|
||||
|
||||
|
||||
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    Check whether `request_body_value` matches `regex_str` as a regex
    (anchored at the start, via re.match) or is exactly equal to it.
    """
    matched = re.match(regex_str, request_body_value)
    return bool(matched) or regex_str == request_body_value
|
||||
|
||||
|
||||
def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """
    Check whether `param` is in the allowed client-side auth params.

    Entries may be plain strings (exact param-name match) or dicts with an
    `api_base` regex that the request's value must satisfy.
    """
    if configurable_clientside_auth_params is None:
        return False

    for allowed_entry in configurable_clientside_auth_params:
        if isinstance(allowed_entry, str):
            if allowed_entry == param:
                return True
            continue
        if isinstance(allowed_entry, Dict):
            # dict entries only constrain api_base, matched by regex or exact string
            if param == "api_base" and check_regex_or_str_match(
                request_body_value=request_body_value,
                regex_str=allowed_entry["api_base"],
            ):
                return True

    return False
|
||||
|
||||
|
||||
def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """
    Check if a model deployment allows this client-side configurable param.

    Resolution order:
    - look up the model group on the router by the full model name
    - if not found and the model is provider-prefixed (wildcard routing,
      e.g. "openai/my-model"), retry with just the provider prefix
    - finally check the group's `configurable_clientside_auth_params`

    Returns False when there is no router, no matching model group, or the
    group defines no configurable client-side auth params.
    """
    if llm_router is None:
        return False

    # exact model-group match first
    model_info = llm_router.get_model_group_info(model_group=model)
    if model_info is None:
        # wildcard match - retry with the provider prefix (split computed once)
        provider_prefix = model.split("/", 1)[0]
        if provider_prefix in provider_list:
            model_info = llm_router.get_model_group_info(model_group=provider_prefix)

    # single combined guard (the original repeated the None check twice)
    if model_info is None or model_info.configurable_clientside_auth_params is None:
        return False

    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
    )
|
||||
|
||||
|
||||
def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997

    Allowed paths for client-supplied api_base/base_url:
    - the body carries complete credentials (model + api_key), or
    - `general_settings.allow_client_side_credentials` is True, or
    - the matched model group explicitly allows the param.

    Returns:
        True when the body is safe.

    Raises:
        ValueError: when a banned param is present and none of the allow
            conditions hold.
    """
    banned_params = ["api_base", "base_url"]

    for param in banned_params:
        if (
            param in request_body
            and not check_complete_credentials(  # allow client-credentials to be passed to proxy
                request_body=request_body
            )
        ):
            # global opt-in takes precedence over per-model configuration
            if general_settings.get("allow_client_side_credentials") is True:
                return True
            elif (
                _allow_model_level_clientside_configurable_parameters(
                    model=model,
                    param=param,
                    request_body_value=request_body[param],
                    llm_router=llm_router,
                )
                is True
            ):
                return True
            raise ValueError(
                f"Rejected Request: {param} is not allowed in request body. "
                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
            )

    return True
|
||||
|
||||
|
||||
async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """
    1. Checks if request size is under max_request_size_mb (if set)
    2. Check if request body is safe (example user has not set api_base in request body)
    3. Check if IP address is allowed (if set)
    4. Check if request route is an allowed route on the proxy (if set)

    Returns:
    - None when every check passes

    Raises:
    - HTTPException (403) if the IP or route is not allowed
    - ProxyException if the request is too large
    - ValueError if the request body is unsafe
    """
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

    # Check 1. request size
    await check_if_request_size_is_safe(request=request)

    # Check 2. Request body is safe
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=request_data.get(
            "model", ""
        ),  # [TODO] use model passed in url as well (azure openai routes)
    )

    # Check 3. Check if IP address is allowed
    is_valid_ip, passed_in_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )

    if not is_valid_ip:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
        )

    # Check 4. Check if request route is an allowed route on the proxy
    if "allowed_routes" in general_settings:
        _allowed_routes = general_settings["allowed_routes"]
        # NOTE(review): non-premium users only get a logged error here —
        # the allowed_routes restriction is still enforced below; confirm intended.
        if premium_user is not True:
            verbose_proxy_logger.error(
                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
            )
        if route not in _allowed_routes:
            verbose_proxy_logger.error(
                f"Route {route} not in allowed_routes={_allowed_routes}"
            )
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access forbidden: Route {route} not allowed",
            )
|
||||
|
||||
|
||||
def route_in_additonal_public_routes(current_route: str):
    """
    Helper to check if the user defined public_routes on config.yaml

    Parameters:
    - current_route: str - the route the user is trying to call

    Returns:
    - bool - True if the route is defined in public_routes
    - bool - False if the route is not defined in public_routes

    Supports wildcard patterns (e.g., "/api/*" matches "/api/users", "/api/users/123")

    In order to use this the litellm config.yaml should have the following in general_settings:

    ```yaml
    general_settings:
        master_key: sk-1234
        public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate", "/api/*"]
    ```
    """
    from litellm.proxy.auth.route_checks import RouteChecks
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        # custom public routes are an enterprise-only feature
        if premium_user is not True:
            return False
        if general_settings is None:
            return False

        public_route_patterns = general_settings.get("public_routes", [])

        # exact match short-circuits the wildcard scan
        if current_route in public_route_patterns:
            return True

        return any(
            RouteChecks._route_matches_wildcard_pattern(
                route=current_route, pattern=pattern
            )
            for pattern in public_route_patterns
        )
    except Exception as e:
        # fail closed: any lookup error means "not public"
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False
|
||||
|
||||
|
||||
def get_request_route(request: Request) -> str:
    """
    Helper to get the route from the request

    remove base url from path if set e.g. `/genai/chat/completions` -> `/chat/completions`

    Falls back to the raw `request.url.path` on any error.
    """
    try:
        if hasattr(request, "base_url") and request.url.path.startswith(
            request.base_url.path
        ):
            # remove base_url from path
            # the `- 1` keeps the character at index len-1 so the result still
            # starts with "/" (assumes base_url.path ends with "/" — TODO confirm)
            return request.url.path[len(request.base_url.path) - 1 :]
        else:
            return request.url.path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return request.url.path
|
||||
|
||||
|
||||
# Route patterns with dynamic path IDs, pre-compiled once at import time so a
# cache miss only runs already-compiled regexes instead of re-parsing ~40
# pattern strings on every call (the original rebuilt the list per invocation).
# Format: (compiled_pattern, replacement_template). Order matters: more
# specific patterns (e.g. ".../cancel") must come before the generic ones.
_DYNAMIC_ROUTE_PATTERNS: List[Tuple[re.Pattern, str]] = [
    (re.compile(pattern), replacement)
    for pattern, replacement in [
        # Responses API - must come before generic patterns
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)$", r"\1/{response_id}"),
        (r"^(/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)$", r"\1/{response_id}"),
        # Threads API
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}\5/{step_id}",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/cancel)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/submit_tool_outputs)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)$", r"\1/{thread_id}\3"),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)/([^/]+)$",
            r"\1/{thread_id}\3/{message_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)$", r"\1/{thread_id}\3"),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)$", r"\1/{thread_id}"),
        # Vector Stores API
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)/([^/]+)$",
            r"\1/{vector_store_id}\3/{file_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)$",
            r"\1/{vector_store_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)/([^/]+)$",
            r"\1/{vector_store_id}\3/{batch_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)$",
            r"\1/{vector_store_id}\3",
        ),
        (r"^(/(?:openai/)?v1/vector_stores)/([^/]+)$", r"\1/{vector_store_id}"),
        # Assistants API
        (r"^(/(?:openai/)?v1/assistants)/([^/]+)$", r"\1/{assistant_id}"),
        # Files API
        (r"^(/(?:openai/)?v1/files)/([^/]+)(/content)$", r"\1/{file_id}\3"),
        (r"^(/(?:openai/)?v1/files)/([^/]+)$", r"\1/{file_id}"),
        # Batches API
        (r"^(/(?:openai/)?v1/batches)/([^/]+)(/cancel)$", r"\1/{batch_id}\3"),
        (r"^(/(?:openai/)?v1/batches)/([^/]+)$", r"\1/{batch_id}"),
        # Fine-tuning API
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/events)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/cancel)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/checkpoints)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)$", r"\1/{fine_tuning_job_id}"),
        # Models API
        (r"^(/(?:openai/)?v1/models)/([^/]+)$", r"\1/{model}"),
    ]
]


@lru_cache(maxsize=256)
def normalize_request_route(route: str) -> str:
    """
    Normalize request routes by replacing dynamic path parameters with placeholders.

    This prevents high cardinality in Prometheus metrics by collapsing routes like:
    - /v1/responses/1234567890 -> /v1/responses/{response_id}
    - /v1/threads/thread_123 -> /v1/threads/{thread_id}

    Args:
        route: The request route path

    Returns:
        Normalized route with dynamic parameters replaced by placeholders

    Examples:
        >>> normalize_request_route("/v1/responses/abc123")
        '/v1/responses/{response_id}'
        >>> normalize_request_route("/v1/responses/abc123/cancel")
        '/v1/responses/{response_id}/cancel'
        >>> normalize_request_route("/chat/completions")
        '/chat/completions'
    """
    # Apply patterns in order; first pattern that changes the route wins.
    for compiled_pattern, replacement in _DYNAMIC_ROUTE_PATTERNS:
        normalized = compiled_pattern.sub(replacement, route)
        if normalized != route:
            return normalized

    # Return original route if no pattern matched
    return route
|
||||
|
||||
|
||||
async def check_if_request_size_is_safe(request: Request) -> bool:
    """
    Enterprise Only:
    - Checks if the request size is within the limit

    Args:
        request (Request): The incoming request.

    Returns:
        bool: True if the request size is within the limit

    Raises:
        ProxyException: If the request size is too large

    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_request_size_mb = general_settings.get("max_request_size_mb", None)
    if max_request_size_mb is None:
        # No limit configured - nothing to enforce.
        return True

    # Enforcing a request-size limit is an enterprise-only feature.
    if premium_user is not True:
        verbose_proxy_logger.warning(
            f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
        )
        return True

    content_length = request.headers.get("content-length")
    if content_length:
        # Prefer the Content-Length header so we don't have to read the body.
        request_size_mb = bytes_to_mb(bytes_value=int(content_length))
        verbose_proxy_logger.debug(
            f"content_length request size in MB={request_size_mb}"
        )
    else:
        # Header missing - fall back to measuring the actual body.
        body = await request.body()
        request_size_mb = bytes_to_mb(bytes_value=len(body))
        verbose_proxy_logger.debug(
            f"request body request size in MB={request_size_mb}"
        )

    if request_size_mb > max_request_size_mb:
        raise ProxyException(
            message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
            type=ProxyErrorTypes.bad_request_error.value,
            code=400,
            param="content-length",
        )

    return True
|
||||
|
||||
|
||||
async def check_response_size_is_safe(response: Any) -> bool:
    """
    Enterprise Only:
    - Checks if the response size is within the limit

    Args:
        response (Any): The response to check.

    Returns:
        bool: True if the response size is within the limit

    Raises:
        ProxyException: If the response size is too large

    """

    # Imported at call time to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is not None:
        # Check if premium user - response size limits are enterprise-only.
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        # NOTE(review): sys.getsizeof is shallow - it measures only the top-level
        # object, not the data it references, so this under-reports the size of
        # container/model objects. TODO confirm this is the intended measure.
        response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
        verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
        if response_size_mb > max_response_size_mb:
            raise ProxyException(
                message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
                type=ProxyErrorTypes.bad_request_error.value,
                code=400,
                param="content-length",
            )

    return True
|
||||
|
||||
|
||||
def bytes_to_mb(bytes_value: int):
    """Convert a byte count into megabytes (1 MB = 1024 * 1024 bytes)."""
    bytes_per_mb = 1024 * 1024
    return bytes_value / bytes_per_mb
|
||||
|
||||
|
||||
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
|
||||
def get_key_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the model rpm limit for a given api key.

    Priority order (returns first found):
    1. Key metadata (model_rpm_limit)
    2. Key model_max_budget (rpm_limit per model)
    3. Team metadata (model_rpm_limit)
    """
    # 1. Key-level metadata wins if it carries a truthy limit mapping.
    metadata = user_api_key_dict.metadata
    if metadata and metadata.get("model_rpm_limit"):
        return metadata.get("model_rpm_limit")

    # 2. Collect per-model rpm limits declared inside model_max_budget.
    budgets = user_api_key_dict.model_max_budget
    if budgets:
        collected: Dict[str, Any] = {
            model: budget["rpm_limit"]
            for model, budget in budgets.items()
            if isinstance(budget, dict) and budget.get("rpm_limit") is not None
        }
        if collected:
            return collected

    # 3. Fall back to the team's metadata.
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        return team_metadata.get("model_rpm_limit")

    return None
|
||||
|
||||
|
||||
def get_key_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the model tpm limit for a given api key.

    Priority order (returns first found):
    1. Key metadata (model_tpm_limit)
    2. Key model_max_budget (tpm_limit per model)
    3. Team metadata (model_tpm_limit)
    """
    # 1. Key-level metadata wins if it carries a truthy limit mapping.
    metadata = user_api_key_dict.metadata
    if metadata and metadata.get("model_tpm_limit"):
        return metadata.get("model_tpm_limit")

    # 2. Collect per-model tpm limits declared inside model_max_budget.
    budgets = user_api_key_dict.model_max_budget
    if budgets:
        collected: Dict[str, Any] = {
            model: budget["tpm_limit"]
            for model, budget in budgets.items()
            if isinstance(budget, dict) and budget.get("tpm_limit") is not None
        }
        if collected:
            return collected

    # 3. Fall back to the team's metadata.
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        return team_metadata.get("model_tpm_limit")

    return None
|
||||
|
||||
|
||||
def get_model_rate_limit_from_metadata(
    user_api_key_dict: UserAPIKeyAuth,
    metadata_accessor_key: Literal["team_metadata", "organization_metadata"],
    rate_limit_key: Literal["model_rpm_limit", "model_tpm_limit"],
) -> Optional[Dict[str, int]]:
    """Read a per-model rate-limit mapping from the given metadata attribute.

    Returns None when the metadata attribute is missing/empty or lacks the key.
    """
    metadata = getattr(user_api_key_dict, metadata_accessor_key)
    if not metadata:
        return None
    return metadata.get(rate_limit_key)
|
||||
|
||||
|
||||
def get_team_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """Get per-model RPM limits from the key's team metadata, if set.

    Delegates to get_model_rate_limit_from_metadata so team- and
    organization-level accessors share one implementation.
    """
    return get_model_rate_limit_from_metadata(
        user_api_key_dict,
        metadata_accessor_key="team_metadata",
        rate_limit_key="model_rpm_limit",
    )
|
||||
|
||||
|
||||
def get_team_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """Get per-model TPM limits from the key's team metadata, if set.

    Delegates to get_model_rate_limit_from_metadata so team- and
    organization-level accessors share one implementation.
    """
    return get_model_rate_limit_from_metadata(
        user_api_key_dict,
        metadata_accessor_key="team_metadata",
        rate_limit_key="model_tpm_limit",
    )
|
||||
|
||||
|
||||
def is_pass_through_provider_route(route: str) -> bool:
    """Return True when the route contains a provider-specific pass-through prefix."""
    provider_passthrough_prefixes = [
        "vertex-ai",
    ]

    # Substring match: the prefix may appear anywhere in the route.
    return any(prefix in route for prefix in provider_passthrough_prefixes)
|
||||
|
||||
|
||||
def _has_user_setup_sso():
|
||||
"""
|
||||
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID or generic client ID and UI username environment variables.
|
||||
Returns a boolean indicating whether SSO has been set up.
|
||||
"""
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
|
||||
sso_setup = (
|
||||
(microsoft_client_id is not None)
|
||||
or (google_client_id is not None)
|
||||
or (generic_client_id is not None)
|
||||
)
|
||||
|
||||
return sso_setup
|
||||
|
||||
|
||||
def get_customer_user_header_from_mapping(user_id_mapping) -> Optional[str]:
    """Return the header_name mapped to CUSTOMER role, if any (dict-based)."""
    if not user_id_mapping:
        return None

    # Accept either a single mapping dict or a list of mapping dicts.
    mappings = (
        user_id_mapping if isinstance(user_id_mapping, list) else [user_id_mapping]
    )
    customer_role = str(LitellmUserRoles.CUSTOMER).lower()

    for mapping in mappings:
        if not isinstance(mapping, dict):
            continue
        role = mapping.get("litellm_user_role")
        header_name = mapping.get("header_name")
        # Skip malformed entries missing either field.
        if role is None or not header_name:
            continue
        if str(role).lower() == customer_role:
            return header_name

    return None
|
||||
|
||||
|
||||
def _get_customer_id_from_standard_headers(
    request_headers: Optional[dict],
) -> Optional[str]:
    """
    Check standard customer ID headers for a customer/end-user ID.

    This enables tools like Claude Code to pass customer IDs via ANTHROPIC_CUSTOM_HEADERS.
    No configuration required - these headers are always checked.

    Args:
        request_headers: The request headers dict

    Returns:
        The customer ID if found in standard headers, None otherwise
    """
    if request_headers is None:
        return None

    # Header-name comparison is case-insensitive; first non-blank value wins.
    for standard_header in STANDARD_CUSTOMER_ID_HEADERS:
        target_name = standard_header.lower()
        for header_name, header_value in request_headers.items():
            if header_name.lower() != target_name:
                continue
            candidate = "" if header_value is None else str(header_value)
            if candidate.strip():
                return candidate
    return None
|
||||
|
||||
|
||||
def get_end_user_id_from_request_body(
    request_body: dict, request_headers: Optional[dict] = None
) -> Optional[str]:
    """
    Resolve the end-user/customer id for an incoming request.

    Sources are consulted in priority order; the first non-empty hit wins:
    1. Standard customer-id headers (always checked, no configuration required)
    2. A configured header: `user_header_mappings` (CUSTOMER role), falling back
       to the deprecated `user_header_name` general setting
    3. `user` field in the request body (commonly OpenAI)
    4. `litellm_metadata.user` (commonly Anthropic)
    5. `metadata.user_id`
    6. `safety_identifier` (OpenAI Responses API; caller-controlled - see note below)

    Returns:
        The id coerced to str, or None when no source yields one.
    """
    # Import general_settings here to avoid potential circular import issues at module level
    # and to ensure it's fetched at runtime.
    from litellm.proxy.proxy_server import general_settings

    # Check 1: Standard customer ID headers (always checked, no configuration required)
    customer_id = _get_customer_id_from_standard_headers(
        request_headers=request_headers
    )
    if customer_id is not None:
        return customer_id

    # Check 2: Follow the user header mappings feature, if not found, then check for deprecated user_header_name (only if request_headers is provided)
    # User query: "system not respecting user_header_name property"
    # This implies the key in general_settings is 'user_header_name'.
    if request_headers is not None:
        custom_header_name_to_check: Optional[str] = None

        # Prefer user mappings (new behavior)
        user_id_mapping = general_settings.get("user_header_mappings", None)
        if user_id_mapping:
            custom_header_name_to_check = get_customer_user_header_from_mapping(
                user_id_mapping
            )

        # Fallback to deprecated user_header_name if mapping did not specify
        if not custom_header_name_to_check:
            user_id_header_config_key = "user_header_name"
            value = general_settings.get(user_id_header_config_key)
            if isinstance(value, str) and value.strip() != "":
                custom_header_name_to_check = value

        # If we have a header name to check, try to read it from request headers
        # (header-name match is case-insensitive; blank values are skipped).
        if isinstance(custom_header_name_to_check, str):
            for header_name, header_value in request_headers.items():
                if header_name.lower() == custom_header_name_to_check.lower():
                    user_id_from_header = header_value
                    user_id_str = (
                        str(user_id_from_header)
                        if user_id_from_header is not None
                        else ""
                    )
                    if user_id_str.strip():
                        return user_id_str

    # Check 3: 'user' field in request_body (commonly OpenAI)
    if "user" in request_body and request_body["user"] is not None:
        user_from_body_user_field = request_body["user"]
        return str(user_from_body_user_field)

    # Check 4: 'litellm_metadata.user' in request_body (commonly Anthropic)
    litellm_metadata = request_body.get("litellm_metadata")
    if isinstance(litellm_metadata, dict):
        user_from_litellm_metadata = litellm_metadata.get("user")
        if user_from_litellm_metadata is not None:
            return str(user_from_litellm_metadata)

    # Check 5: 'metadata.user_id' in request_body (another common pattern)
    metadata_dict = request_body.get("metadata")
    if isinstance(metadata_dict, dict):
        user_id_from_metadata_field = metadata_dict.get("user_id")
        if user_id_from_metadata_field is not None:
            return str(user_id_from_metadata_field)

    # Check 6: 'safety_identifier' in request body (OpenAI Responses API parameter)
    # SECURITY NOTE: safety_identifier can be set by any caller in the request body.
    # Only use this for end-user identification in trusted environments where you control
    # the calling application. For untrusted callers, prefer using headers or server-side
    # middleware to set the end_user_id to prevent impersonation.
    if request_body.get("safety_identifier") is not None:
        user_from_body_user_field = request_body["safety_identifier"]
        return str(user_from_body_user_field)

    return None
|
||||
|
||||
|
||||
def get_model_from_request(
    request_data: dict, route: str
) -> Optional[Union[str, List[str]]]:
    """Resolve the target model(s) for a request.

    Resolution order:
    1. `model` / `target_model_names` in the request body (comma-separated
       values become a list; a single value stays a string)
    2. Azure-style route: /openai/deployments/{model}/...
    3. Google generateContent-style routes (model may contain "/"):
       /v1beta/models/{model}:..., /beta/models/{model}:..., /models/{model}:...
    4. Vertex AI pass-through routes: /vertex.../models/{model}:...

    Returns None when no source yields a model.
    """
    # 1. Body takes priority over route parsing.
    raw_model = request_data.get("model") or request_data.get("target_model_names")
    if raw_model is not None:
        parts = [part.strip() for part in raw_model.split(",")]
        return parts[0] if len(parts) == 1 else parts

    # 2. Azure deployments route: /openai/deployments/{model}/*
    deployment_match = re.match(r"/openai/deployments/([^/]+)", route)
    if deployment_match:
        return deployment_match.group(1)

    if not route.lower().startswith("/vertex"):
        # 3. Google generateContent-style routes; model id may contain "/".
        for pattern in (r"/(?:v1beta|beta)/models/([^:]+):", r"^/models/([^:]+):"):
            google_match = re.search(pattern, route)
            if google_match:
                return google_match.group(1)
    else:
        # 4. Vertex AI passthrough: /vertex_ai/.../models/{model_id}:*
        vertex_match = re.search(r"/models/([^:]+)", route)
        if vertex_match:
            return vertex_match.group(1)

    return None
|
||||
|
||||
|
||||
def abbreviate_api_key(api_key: str) -> str:
    """Return a redacted display form of an API key: 'sk-...' plus its last 4 characters."""
    visible_suffix = api_key[-4:]
    return "sk-..." + visible_suffix
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
IP address utilities for MCP public/private access control.
|
||||
|
||||
Internal callers (private IPs) see all MCP servers.
|
||||
External callers (public IPs) only see servers with available_on_public_internet=True.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.auth_utils import _get_request_ip_address
|
||||
|
||||
|
||||
class IPAddressUtils:
    """Static utilities for IP-based MCP access control."""

    # RFC 1918 private ranges, loopback (v4/v6), and IPv6 unique-local addresses.
    # Used whenever no (valid) ranges are configured.
    _DEFAULT_INTERNAL_NETWORKS = [
        ipaddress.ip_network("10.0.0.0/8"),
        ipaddress.ip_network("172.16.0.0/12"),
        ipaddress.ip_network("192.168.0.0/16"),
        ipaddress.ip_network("127.0.0.0/8"),
        ipaddress.ip_network("::1/128"),
        ipaddress.ip_network("fc00::/7"),
    ]

    @staticmethod
    def parse_internal_networks(
        configured_ranges: Optional[List[str]],
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """Parse configured CIDR ranges into network objects, falling back to defaults.

        Invalid CIDR strings are logged and skipped; if no valid ranges remain,
        the defaults are returned.
        NOTE(review): the fallback returns the shared _DEFAULT_INTERNAL_NETWORKS
        list itself - callers must not mutate the result.
        """
        if not configured_ranges:
            return IPAddressUtils._DEFAULT_INTERNAL_NETWORKS
        networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
        for cidr in configured_ranges:
            try:
                # strict=False tolerates host bits set in the CIDR (e.g. "10.0.0.1/8")
                networks.append(ipaddress.ip_network(cidr, strict=False))
            except ValueError:
                verbose_proxy_logger.warning(
                    "Invalid CIDR in mcp_internal_ip_ranges: %s, skipping", cidr
                )
        return networks if networks else IPAddressUtils._DEFAULT_INTERNAL_NETWORKS

    @staticmethod
    def parse_trusted_proxy_networks(
        configured_ranges: Optional[List[str]],
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """
        Parse trusted proxy CIDR ranges for XFF validation.
        Returns empty list if not configured (XFF will not be trusted).
        """
        if not configured_ranges:
            return []
        networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = []
        for cidr in configured_ranges:
            try:
                networks.append(ipaddress.ip_network(cidr, strict=False))
            except ValueError:
                # Bad entries are skipped, not fatal - remaining ranges still apply.
                verbose_proxy_logger.warning(
                    "Invalid CIDR in mcp_trusted_proxy_ranges: %s, skipping", cidr
                )
        return networks

    @staticmethod
    def is_trusted_proxy(
        proxy_ip: Optional[str],
        trusted_networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]],
    ) -> bool:
        """Check if the direct connection IP is from a trusted proxy.

        Fails closed: missing/unparseable IPs and an empty trusted list return False.
        """
        if not proxy_ip or not trusted_networks:
            return False
        try:
            addr = ipaddress.ip_address(proxy_ip.strip())
            return any(addr in network for network in trusted_networks)
        except ValueError:
            return False

    @staticmethod
    def is_internal_ip(
        client_ip: Optional[str],
        internal_networks: Optional[
            List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]
        ] = None,
    ) -> bool:
        """
        Check if a client IP is from an internal/private network.

        Handles X-Forwarded-For comma chains (takes leftmost = original client).
        Fails closed: empty/invalid IPs are treated as external.

        NOTE(review): because of the `or` below, an explicitly-passed EMPTY
        internal_networks list also falls back to the defaults - confirm that
        is intended (vs. treating an empty list as "nothing is internal").
        """
        if not client_ip:
            return False

        # X-Forwarded-For may contain comma-separated chain; leftmost is original client
        if "," in client_ip:
            client_ip = client_ip.split(",")[0].strip()

        networks = internal_networks or IPAddressUtils._DEFAULT_INTERNAL_NETWORKS

        try:
            addr = ipaddress.ip_address(client_ip.strip())
        except ValueError:
            return False

        return any(addr in network for network in networks)

    @staticmethod
    def get_mcp_client_ip(
        request: Request,
        general_settings: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Extract client IP from a FastAPI request for MCP access control.

        Security: Only trusts X-Forwarded-For if:
        1. use_x_forwarded_for is enabled in settings
        2. The direct connection is from a trusted proxy (if mcp_trusted_proxy_ranges configured)

        Args:
            request: FastAPI request object
            general_settings: Optional settings dict. If not provided, imports from proxy_server.
        """
        if general_settings is None:
            try:
                # Runtime import avoids a circular dependency on proxy_server.
                from litellm.proxy.proxy_server import (
                    general_settings as proxy_general_settings,
                )

                general_settings = proxy_general_settings
            except ImportError:
                general_settings = {}

        # Handle case where general_settings is still None after import
        if general_settings is None:
            general_settings = {}

        use_xff = general_settings.get("use_x_forwarded_for", False)

        # If XFF is enabled, validate the request comes from a trusted proxy
        if use_xff and "x-forwarded-for" in request.headers:
            trusted_ranges = general_settings.get("mcp_trusted_proxy_ranges")
            if trusted_ranges:
                # Validate direct connection is from trusted proxy
                direct_ip = request.client.host if request.client else None
                trusted_networks = IPAddressUtils.parse_trusted_proxy_networks(
                    trusted_ranges
                )
                if not IPAddressUtils.is_trusted_proxy(direct_ip, trusted_networks):
                    # Untrusted source trying to set XFF - ignore XFF, use direct IP
                    verbose_proxy_logger.warning(
                        "XFF header from untrusted IP %s, ignoring", direct_ip
                    )
                    return direct_ip
        # Delegate final extraction (honoring XFF only when enabled) to the shared helper.
        return _get_request_ip_address(request, use_x_forwarded_for=use_xff)
|
||||
@@ -0,0 +1,214 @@
|
||||
# What is this?
|
||||
## If litellm license in env, checks if it's valid
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import NON_LLM_CONNECTION_TIMEOUT
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import EnterpriseLicenseData
|
||||
|
||||
|
||||
class LicenseCheck:
    """
    - Check if license in env
    - Returns if license is valid
    """

    base_url = "https://license.litellm.ai"

    def __init__(self) -> None:
        self.license_str = os.getenv("LITELLM_LICENSE", None)
        verbose_proxy_logger.debug("License Str value - {}".format(self.license_str))
        self.http_handler = HTTPHandler(timeout=NON_LLM_CONNECTION_TIMEOUT)
        # Log the premium check details only once to avoid spamming the debug log.
        self._premium_check_logged = False
        self.public_key = None
        self.read_public_key()
        # Populated by verify_license_without_api_request on successful local verification.
        self.airgapped_license_data: Optional["EnterpriseLicenseData"] = None

    def read_public_key(self):
        """Load public_key.pem (shipped next to this module) for offline license verification.

        Leaves self.public_key as None when the file is missing or unreadable.
        """
        try:
            from cryptography.hazmat.primitives import serialization

            # current dir
            current_dir = os.path.dirname(os.path.realpath(__file__))

            # check if public_key.pem exists
            _path_to_public_key = os.path.join(current_dir, "public_key.pem")
            if os.path.exists(_path_to_public_key):
                with open(_path_to_public_key, "rb") as key_file:
                    self.public_key = serialization.load_pem_public_key(key_file.read())
            else:
                self.public_key = None
        except Exception as e:
            verbose_proxy_logger.error(f"Error reading public key: {str(e)}")

    def _verify(self, license_str: str) -> bool:
        """Verify the license against the hosted license server.

        Retries up to 3 times on HTTP status errors; returns False on any failure
        so license issues never take down the proxy.
        """
        verbose_proxy_logger.debug(
            "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
                self.base_url, license_str
            )
        )
        url = "{}/verify_license/{}".format(self.base_url, license_str)

        response: Optional[httpx.Response] = None
        try:  # don't impact user, if call fails
            num_retries = 3
            for i in range(num_retries):
                try:
                    response = self.http_handler.get(url=url)
                    if response is None:
                        raise Exception("No response from license server")
                    response.raise_for_status()
                    # Success - stop retrying. (Bug fix: previously the loop
                    # re-issued the request on every iteration even after a
                    # successful response.)
                    break
                except httpx.HTTPStatusError:
                    if i == num_retries - 1:
                        raise

            if response is None:
                raise Exception("No response from license server")

            response_json = response.json()

            premium = response_json["verify"]

            assert isinstance(premium, bool)

            verbose_proxy_logger.debug(
                "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
                    license_str, premium
                )
            )
            return premium
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
                    license_str, str(e)
                )
            )
            return False

    def is_premium(self) -> bool:
        """
        1. verify_license_without_api_request: checks if license was generated using private / public key pair
        2. _verify: checks if license is valid calling litellm API. This is the old way we were generating/validating license
        """
        try:
            if not self._premium_check_logged:
                verbose_proxy_logger.debug(
                    "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
                        self.license_str
                    )
                )

            # Re-read the env var in case the license was set after startup.
            if self.license_str is None:
                self.license_str = os.getenv("LITELLM_LICENSE", None)

            if not self._premium_check_logged:
                verbose_proxy_logger.debug(
                    "litellm.proxy.auth.litellm_license.py::is_premium() - Updated 'self.license_str' - {}".format(
                        self.license_str
                    )
                )
                self._premium_check_logged = True

            if self.license_str is None:
                return False
            elif (
                self.verify_license_without_api_request(
                    public_key=self.public_key, license_key=self.license_str
                )
                is True
            ):
                return True
            elif self._verify(license_str=self.license_str) is True:
                return True
            return False
        except Exception:
            return False

    def is_over_limit(self, total_users: int) -> bool:
        """
        Check if the license is over its max_users limit.

        Returns False when no airgapped license data (or no integer max_users)
        is available - i.e. no limit to enforce.
        """
        if self.airgapped_license_data is None:
            return False
        if "max_users" not in self.airgapped_license_data or not isinstance(
            self.airgapped_license_data["max_users"], int
        ):
            return False
        return total_users > self.airgapped_license_data["max_users"]

    def is_team_count_over_limit(self, team_count: int) -> bool:
        """
        Check if the license is over its max_teams limit.

        Returns False when no airgapped license data (or no integer max_teams)
        is available - i.e. no limit to enforce.
        """
        if self.airgapped_license_data is None:
            return False

        _max_teams_in_license: Optional[int] = self.airgapped_license_data.get(
            "max_teams"
        )
        if "max_teams" not in self.airgapped_license_data or not isinstance(
            _max_teams_in_license, int
        ):
            return False
        return team_count > _max_teams_in_license

    def verify_license_without_api_request(self, public_key, license_key):
        """Verify a signed, base64-encoded license locally (airgapped mode).

        Checks the RSA-PSS signature against the bundled public key, stores the
        decoded license payload on self.airgapped_license_data, and enforces the
        expiration date. Returns True on success, False on any failure.
        """
        try:
            from cryptography.hazmat.primitives import hashes
            from cryptography.hazmat.primitives.asymmetric import padding

            from litellm.proxy._types import EnterpriseLicenseData

            # Decode the license key - add padding if needed for base64
            # Base64 strings need to be a multiple of 4 characters
            padding_needed = len(license_key) % 4
            if padding_needed:
                license_key += "=" * (4 - padding_needed)

            decoded = base64.b64decode(license_key)
            message, signature = decoded.split(b".", 1)

            # Verify the signature
            public_key.verify(
                signature,
                message,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH,
                ),
                hashes.SHA256(),
            )

            # Decode and parse the data
            license_data = json.loads(message.decode())

            self.airgapped_license_data = EnterpriseLicenseData(**license_data)

            # debug information provided in license data
            verbose_proxy_logger.debug("License data: %s", license_data)

            # Check expiration date
            expiration_date = datetime.strptime(
                license_data["expiration_date"], "%Y-%m-%d"
            )
            if expiration_date < datetime.now():
                # Bug fix: previously returned the tuple (False, "License has
                # expired"), which is truthy and violates this method's bool
                # contract; callers comparing `is True` happened to work, but
                # any truthiness check would have accepted an expired license.
                return False

            return True

        except Exception as e:
            verbose_proxy_logger.debug(
                "litellm.proxy.auth.litellm_license.py::verify_license_without_api_request - Unable to verify License locally. - {}".format(
                    str(e)
                )
            )
            return False
|
||||
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
Login utilities for handling user authentication in the proxy server.
|
||||
|
||||
This module contains the core login logic that can be reused across different
|
||||
login endpoints (e.g., /login and /v2/login).
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.constants import LITELLM_PROXY_ADMIN_NAME, LITELLM_UI_SESSION_DURATION
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
UpdateUserRequest,
|
||||
UserAPIKeyAuth,
|
||||
hash_token,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.ui_sso import (
|
||||
get_disabled_non_admin_personal_key_creation,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, get_server_root_path
|
||||
from litellm.secret_managers.main import get_secret_bool
|
||||
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
|
||||
|
||||
|
||||
def get_ui_credentials(master_key: Optional[str]) -> tuple[str, str]:
    """
    Get UI username and password from environment variables or master key.

    Args:
        master_key: Master key for the proxy (used as fallback for password)

    Returns:
        tuple[str, str]: A tuple containing (ui_username, ui_password)

    Raises:
        ProxyException: If neither UI_PASSWORD nor master_key is available
    """
    ui_username = os.getenv("UI_USERNAME", "admin")

    # Password precedence: explicit UI_PASSWORD env var, then the master key.
    ui_password: Optional[str] = os.getenv("UI_PASSWORD", None)
    if ui_password is None and master_key is not None:
        ui_password = str(master_key)

    if ui_password is None:
        raise ProxyException(
            message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="UI_PASSWORD",
            code=500,
        )
    return ui_username, ui_password
|
||||
|
||||
|
||||
class LoginResult:
    """Container for the identity and credentials produced by a successful login."""

    user_id: str
    key: str
    user_email: Optional[str]
    user_role: str
    login_method: Literal["sso", "username_password"]

    def __init__(
        self,
        user_id: str,
        key: str,
        user_email: Optional[str],
        user_role: str,
        login_method: Literal["sso", "username_password"] = "username_password",
    ):
        # Store all fields verbatim; no validation happens here.
        self.user_id = user_id
        self.user_email = user_email
        self.user_role = user_role
        self.key = key
        self.login_method = login_method
|
||||
|
||||
|
||||
async def authenticate_user(  # noqa: PLR0915
    username: str,
    password: str,
    master_key: Optional[str],
    prisma_client: Optional[PrismaClient],
) -> LoginResult:
    """
    Authenticate a user and generate an API key for UI access.

    This function handles two login scenarios:
    1. Admin login using UI_USERNAME and UI_PASSWORD
    2. User login using email and password from database

    Args:
        username: Username or email from the login form
        password: Password from the login form
        master_key: Master key for the proxy (required)
        prisma_client: Prisma database client (optional)

    Returns:
        LoginResult: Object containing authentication data

    Raises:
        ProxyException: If authentication fails or required configuration is missing
    """
    # Hard requirement: a master key must be configured before any UI login.
    if master_key is None:
        raise ProxyException(
            message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="master_key",
            code=500,
        )

    ui_username, ui_password = get_ui_credentials(master_key)

    # Check if we can find the `username` in the db. On the UI, users can enter username=their email
    _user_row: Optional[LiteLLM_UserTable] = None
    user_role: Optional[
        Literal[
            LitellmUserRoles.PROXY_ADMIN,
            LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
            LitellmUserRoles.INTERNAL_USER,
            LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
        ]
    ] = None

    if prisma_client is not None:
        # Case-insensitive email lookup so "User@x.com" matches "user@x.com".
        _user_row = cast(
            Optional[LiteLLM_UserTable],
            await prisma_client.db.litellm_usertable.find_first(
                where={"user_email": {"equals": username, "mode": "insensitive"}}
            ),
        )

    """
    To login to Admin UI, we support the following
    - Login with UI_USERNAME and UI_PASSWORD
    - Login with Invite Link `user_email` and `password` combination
    """
    # compare_digest is used for constant-time comparison (timing-attack safe).
    if secrets.compare_digest(
        username.encode("utf-8"), ui_username.encode("utf-8")
    ) and secrets.compare_digest(password.encode("utf-8"), ui_password.encode("utf-8")):
        # Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin
        user_role = LitellmUserRoles.PROXY_ADMIN
        user_id = LITELLM_PROXY_ADMIN_NAME

        # we want the key created to have PROXY_ADMIN_PERMISSIONS
        key_user_id = LITELLM_PROXY_ADMIN_NAME
        if (
            os.getenv("PROXY_ADMIN_ID", None) is not None
            and os.environ["PROXY_ADMIN_ID"] == user_id
        ) or user_id == LITELLM_PROXY_ADMIN_NAME:
            # checks if user is admin
            key_user_id = os.getenv("PROXY_ADMIN_ID", LITELLM_PROXY_ADMIN_NAME)

        # Admin is Authe'd in - generate key for the UI to access Proxy

        # ensure this user is set as the proxy admin, in this route there is no sso, we can assume this user is only the admin
        await user_update(
            data=UpdateUserRequest(
                user_id=key_user_id,
                user_role=user_role,
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
            ),
        )

        if os.getenv("DATABASE_URL") is not None:
            # Mint a short-lived session key scoped to the dashboard team.
            response = await generate_key_helper_fn(
                request_type="key",
                **{
                    "user_role": LitellmUserRoles.PROXY_ADMIN,
                    "duration": LITELLM_UI_SESSION_DURATION,
                    "key_max_budget": litellm.max_ui_session_budget,
                    "models": [],
                    "aliases": {},
                    "config": {},
                    "spend": 0,
                    "user_id": key_user_id,
                    "team_id": "litellm-dashboard",
                },  # type: ignore
            )
        else:
            raise ProxyException(
                message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
                type=ProxyErrorTypes.auth_error,
                param="DATABASE_URL",
                code=500,
            )

        key = response["token"]  # type: ignore

        if get_secret_bool("EXPERIMENTAL_UI_LOGIN"):
            # Experimental path: replace the DB-backed key with a JWT.
            from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken

            user_info: Optional[LiteLLM_UserTable] = None
            if _user_row is not None:
                user_info = _user_row
            elif (
                user_id is not None
            ):  # if user_id is not None, we are using the UI_USERNAME and UI_PASSWORD
                user_info = LiteLLM_UserTable(
                    user_id=user_id,
                    user_role=user_role,
                    models=[],
                    max_budget=litellm.max_ui_session_budget,
                )
            if user_info is None:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": "User Information is required for experimental UI login"
                    },
                )

            key = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
                user_info
            )

        return LoginResult(
            user_id=user_id,
            key=key,
            user_email=None,
            user_role=user_role,
            login_method="username_password",
        )

    elif _user_row is not None:
        """
        When sharing invite links

        -> if the user has no role in the DB assume they are only a viewer
        """
        user_id = getattr(_user_row, "user_id", "unknown")
        user_role = getattr(
            _user_row, "user_role", LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
        )
        user_email = getattr(_user_row, "user_email", "unknown")
        _password = getattr(_user_row, "password", "unknown")

        if _password is None:
            raise ProxyException(
                message="User has no password set. Please set a password for the user via `/user/update`.",
                type=ProxyErrorTypes.auth_error,
                param="password",
                code=401,
            )

        # check if password == _user_row.password
        # NOTE(review): the plain-text compare below supports legacy rows where
        # the stored password was not hashed - confirm whether plain storage is
        # still expected anywhere.
        hash_password = hash_token(token=password)
        if secrets.compare_digest(
            password.encode("utf-8"), _password.encode("utf-8")
        ) or secrets.compare_digest(
            hash_password.encode("utf-8"), _password.encode("utf-8")
        ):
            if os.getenv("DATABASE_URL") is not None:
                response = await generate_key_helper_fn(
                    request_type="key",
                    **{  # type: ignore
                        "user_role": user_role,
                        "duration": LITELLM_UI_SESSION_DURATION,
                        "key_max_budget": litellm.max_ui_session_budget,
                        "models": [],
                        "aliases": {},
                        "config": {},
                        "spend": 0,
                        "user_id": user_id,
                        "team_id": "litellm-dashboard",
                    },
                )
            else:
                raise ProxyException(
                    message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
                    type=ProxyErrorTypes.auth_error,
                    param="DATABASE_URL",
                    code=500,
                )

            key = response["token"]  # type: ignore

            return LoginResult(
                user_id=user_id,
                key=key,
                user_email=user_email,
                user_role=cast(str, user_role),
                login_method="username_password",
            )
        else:
            # Password did not match either plain or hashed value.
            raise ProxyException(
                message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
                type=ProxyErrorTypes.auth_error,
                param="invalid_credentials",
                code=401,
            )
    else:
        # Not the admin credentials and no matching DB user.
        raise ProxyException(
            message="Invalid credentials used to access UI.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
            type=ProxyErrorTypes.auth_error,
            param="invalid_credentials",
            code=401,
        )
|
||||
|
||||
|
||||
def create_ui_token_object(
    login_result: LoginResult,
    general_settings: dict,
    premium_user: bool,
) -> ReturnedUITokenObject:
    """
    Build the UI session token payload from a completed login.

    Args:
        login_result: The result from authenticate_user
        general_settings: Proxy-wide settings; ``litellm_key_header_name`` is
            read from it (defaults to "Authorization")
        premium_user: Whether premium features are enabled

    Returns:
        ReturnedUITokenObject: Token object ready for JWT encoding
    """
    personal_key_creation_disabled = get_disabled_non_admin_personal_key_creation()
    auth_header = general_settings.get("litellm_key_header_name", "Authorization")

    return ReturnedUITokenObject(
        user_id=login_result.user_id,
        key=login_result.key,
        user_email=login_result.user_email,
        user_role=login_result.user_role,
        login_method=login_result.login_method,
        premium_user=premium_user,
        auth_header_name=auth_header,
        disabled_non_admin_personal_key_creation=personal_key_creation_disabled,
        server_root_path=get_server_root_path(),
    )
|
||||
@@ -0,0 +1,381 @@
|
||||
# What is this?
|
||||
## Common checks for /v1/models and `/model/info`
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
|
||||
from litellm.router import Router
|
||||
from litellm.router_utils.fallback_event_handlers import get_fallback_model_group
|
||||
from litellm.types.router import LiteLLM_Params
|
||||
from litellm.utils import get_valid_models
|
||||
|
||||
|
||||
def _check_wildcard_routing(model: str) -> bool:
|
||||
"""
|
||||
Returns True if a model is a provider wildcard.
|
||||
|
||||
eg:
|
||||
- anthropic/*
|
||||
- openai/*
|
||||
- *
|
||||
"""
|
||||
if "*" in model:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Return the known model names for ``provider``.

    "*" means "every provider"; an unrecognized provider yields ``None``.
    """
    if provider == "*":
        # Wildcard provider: every model litellm knows about.
        return get_valid_models(litellm_params=litellm_params)

    if provider not in litellm.models_by_provider:
        return None

    return get_valid_models(
        custom_llm_provider=provider, litellm_params=litellm_params
    )
|
||||
|
||||
|
||||
def _get_models_from_access_groups(
|
||||
model_access_groups: Dict[str, List[str]],
|
||||
all_models: List[str],
|
||||
include_model_access_groups: Optional[bool] = False,
|
||||
) -> List[str]:
|
||||
idx_to_remove = []
|
||||
new_models = []
|
||||
for idx, model in enumerate(all_models):
|
||||
if model in model_access_groups:
|
||||
if (
|
||||
not include_model_access_groups
|
||||
): # remove access group, unless requested - e.g. when creating a key
|
||||
idx_to_remove.append(idx)
|
||||
new_models.extend(model_access_groups[model])
|
||||
|
||||
for idx in sorted(idx_to_remove, reverse=True):
|
||||
all_models.pop(idx)
|
||||
|
||||
all_models.extend(new_models)
|
||||
return all_models
|
||||
|
||||
|
||||
async def get_mcp_server_ids(
    user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
    """
    Returns the list of MCP server ids for a given key by querying the object_permission table.

    Returns an empty list when there is no DB, the key has no
    object_permission_id, no row is found, or the lookup fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    # No database configured -> nothing to resolve.
    if prisma_client is None:
        return []

    # Key has no object-permission record attached.
    if user_api_key_dict.object_permission_id is None:
        return []

    # Make a direct SQL query to get just the mcp_servers
    try:
        result = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": user_api_key_dict.object_permission_id},
        )
        if result and result.mcp_servers:
            return result.mcp_servers
        return []
    except Exception:
        # Best-effort lookup: any DB error degrades to "no servers" rather
        # than failing the caller's request.
        return []
|
||||
|
||||
|
||||
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth,
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Returns:
    - List of model name strings
    - Empty list if no models set
    - If model_access_groups is provided, only return models that are in the access groups
    - If include_model_access_groups is True, it includes the 'keys' of the model_access_groups
        in the response - {"beta-models": ["gpt-4", "claude-v1"]} -> returns 'beta-models'

    NOTE(review): `only_model_access_groups` is accepted but never read in this
    function - confirm whether filtering was intended here or is handled by
    the caller (cf. get_complete_model_list, which does use it).
    """
    all_models: List[str] = []
    if len(user_api_key_dict.models) > 0:
        all_models = list(
            user_api_key_dict.models
        )  # copy to avoid mutating cached objects
        if SpecialModelNames.all_team_models.value in all_models:
            # key delegates to its team's model list
            all_models = list(
                user_api_key_dict.team_models
            )  # copy to avoid mutating cached objects
        if SpecialModelNames.all_proxy_models.value in all_models:
            # key (or its team) may access every model on the proxy
            all_models = list(proxy_model_list)  # copy to avoid mutating caller's list
            if include_model_access_groups:
                all_models.extend(model_access_groups.keys())

    # Expand any access-group aliases into their member models.
    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=all_models,
        include_model_access_groups=include_model_access_groups,
    )

    # deduplicate while preserving order
    all_models = list(dict.fromkeys(all_models))

    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
    return all_models
|
||||
|
||||
|
||||
def get_team_models(
    team_models: List[str],
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Returns:
    - List of model name strings
    - Empty list if no models set
    - If model_access_groups is provided, only return models that are in the access groups
    """
    all_models_set: Set[str] = set()
    if len(team_models) > 0:
        all_models_set.update(team_models)
        if SpecialModelNames.all_team_models.value in all_models_set:
            # NOTE(review): this re-adds `team_models`, which is a no-op since
            # the set already contains them - confirm whether a different
            # source list was intended (cf. get_key_models, which swaps in the
            # team's models here).
            all_models_set.update(team_models)
        if SpecialModelNames.all_proxy_models.value in all_models_set:
            # the special "all proxy models" marker grants every proxy model
            all_models_set.update(proxy_model_list)
            if include_model_access_groups:
                all_models_set.update(model_access_groups.keys())

    # Expand any access-group aliases into their member models.
    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=list(all_models_set),
        include_model_access_groups=include_model_access_groups,
    )

    # deduplicate while preserving order
    all_models = list(dict.fromkeys(all_models))

    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models
|
||||
|
||||
|
||||
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
    model_access_groups: Dict[str, List[str]] = {},
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """Logic for returning complete model list for a given key + team pair"""

    """
    - If key list is empty -> defer to team list
    - If team list is empty -> defer to proxy model list

    If list contains wildcard -> return known provider models
    """
    # NOTE: the mutable default `model_access_groups={}` is only read here,
    # never mutated, so it is safe.

    unique_models = []

    def append_unique(models):
        # preserve first-seen order while skipping duplicates
        for model in models:
            if model not in unique_models:
                unique_models.append(model)

    # Precedence: key-scoped models > team models > full proxy list.
    if key_models:
        append_unique(key_models)
    elif team_models:
        append_unique(team_models)
    else:
        append_unique(proxy_model_list)
        if include_model_access_groups:
            append_unique(list(model_access_groups.keys()))  # TODO: keys order

    if user_model:
        append_unique([user_model])

    if infer_model_from_keys:
        # add every model reachable with the provider keys in the environment
        valid_models = get_valid_models()
        append_unique(valid_models)

    if only_model_access_groups:
        # Caller only wants the access-group names present in the list.
        model_access_groups_to_return: List[str] = []
        for model in unique_models:
            if model in model_access_groups:
                model_access_groups_to_return.append(model)
        return model_access_groups_to_return

    # Expands wildcard routes; may remove expanded wildcards from unique_models.
    all_wildcard_models = _get_wildcard_models(
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )

    complete_model_list = unique_models + all_wildcard_models

    return complete_model_list
|
||||
|
||||
|
||||
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    """
    Expand a wildcard route (e.g. ``openai/*`` or ``gemini/gemini-*``) into the
    list of known, provider-prefixed model names.

    Args:
        wildcard_model: Route of the form ``<prefix>/<pattern>``. Returns []
            when there is no "/" to split on.
        litellm_params: Optional deployment params. When provided, the provider
            is taken from ``litellm_params.model`` instead of the wildcard
            prefix (e.g., "openai" from "openai/*" - needed for BYOK where the
            wildcard isn't in the router).

    Returns:
        List[str]: expanded model names, each carrying the wildcard's provider
        prefix; empty list if the provider is unknown.
    """
    try:
        wildcard_provider_prefix, wildcard_suffix = wildcard_model.split("/", 1)
    except ValueError:  # no "/" in the route -> not expandable; safely fail
        return []

    # Use provider from litellm_params when available, otherwise from wildcard prefix.
    # (The previous `except ValueError` around this split was unreachable:
    # str.split with a non-empty separator never raises, and [0] always exists.)
    if litellm_params is not None:
        provider = litellm_params.model.split("/", 1)[0]
    else:
        provider = wildcard_provider_prefix

    # get all known provider models
    wildcard_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if wildcard_models is None:
        return []

    if wildcard_suffix != "*":
        ## CHECK IF PARTIAL FILTER e.g. `gemini-*`
        model_prefix = wildcard_suffix.replace("*", "")

        is_partial_filter = any(
            wc_model.startswith(model_prefix) for wc_model in wildcard_models
        )
        if is_partial_filter:
            # Keep only the models matching the partial pattern.
            wildcard_models = [
                wc_model
                for wc_model in wildcard_models
                if wc_model.startswith(model_prefix)
            ]
        else:
            # add model prefix to wildcard models
            wildcard_models = [f"{model_prefix}{model}" for model in wildcard_models]

    # Ensure every returned model carries the wildcard's provider prefix.
    suffix_appended_wildcard_models = []
    for model in wildcard_models:
        if not model.startswith(wildcard_provider_prefix):
            model = f"{wildcard_provider_prefix}/{model}"
        suffix_appended_wildcard_models.append(model)
    return suffix_appended_wildcard_models
|
||||
|
||||
|
||||
def _get_wildcard_models(
    unique_models: List[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Expand wildcard routes found in ``unique_models`` into concrete model names.

    Side effect: wildcard entries that were expanded via the provider-model
    fallback are removed from ``unique_models`` in place. Wildcards expanded
    through router deployments are left in ``unique_models`` as-is.
    """
    models_to_remove = set()
    all_wildcard_models = []
    for model in unique_models:
        if _check_wildcard_routing(model=model):
            if (
                return_wildcard_routes
            ):  # will add the wildcard route to the list eg: anthropic/*.
                all_wildcard_models.append(model)

            ## get litellm params from model
            if llm_router is not None:
                model_list = llm_router.get_model_list(model_name=model)
                if model_list:
                    # Expand once per matching router deployment, using that
                    # deployment's litellm_params to pick the provider.
                    for router_model in model_list:
                        wildcard_models = get_known_models_from_wildcard(
                            wildcard_model=model,
                            litellm_params=LiteLLM_Params(
                                **router_model["litellm_params"]  # type: ignore
                            ),
                        )
                        all_wildcard_models.extend(wildcard_models)
                else:
                    # Router has no deployment for this wildcard (e.g., BYOK team models)
                    # Fall back to expanding from known provider models
                    wildcard_models = get_known_models_from_wildcard(
                        wildcard_model=model, litellm_params=None
                    )
                    if wildcard_models:
                        models_to_remove.add(model)
                        all_wildcard_models.extend(wildcard_models)
            else:
                # get all known provider models
                wildcard_models = get_known_models_from_wildcard(
                    wildcard_model=model, litellm_params=None
                )

                if wildcard_models:
                    models_to_remove.add(model)
                    all_wildcard_models.extend(wildcard_models)

    # Drop wildcard entries that were replaced by concrete expansions.
    for model in models_to_remove:
        unique_models.remove(model)

    return all_wildcard_models
|
||||
|
||||
|
||||
def get_all_fallbacks(
    model: str,
    llm_router: Optional[Router] = None,
    fallback_type: str = "general",
) -> List[str]:
    """
    Get all fallbacks for a given model from the router's fallback configuration.

    Args:
        model: The model name to get fallbacks for
        llm_router: The LiteLLM router instance
        fallback_type: Type of fallback ("general", "context_window", "content_policy")

    Returns:
        List of fallback model names. Empty list if no fallbacks found.
    """
    if llm_router is None:
        return []

    # Map each supported fallback type to its router attribute.
    router_attr_by_type = {
        "general": "fallbacks",
        "context_window": "context_window_fallbacks",
        "content_policy": "content_policy_fallbacks",
    }
    attr_name = router_attr_by_type.get(fallback_type)
    if attr_name is None:
        verbose_proxy_logger.warning(f"Unknown fallback_type: {fallback_type}")
        return []

    fallbacks_config: list = getattr(llm_router, attr_name, [])
    if not fallbacks_config:
        return []

    try:
        # Use existing function to get fallback model group
        fallback_model_group, _ = get_fallback_model_group(
            fallbacks=fallbacks_config, model_group=model
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting fallbacks for model {model}: {e}")
        return []

    if fallback_model_group is None:
        return []
    return fallback_model_group
|
||||
@@ -0,0 +1,222 @@
|
||||
import base64
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
|
||||
|
||||
class Oauth2Handler:
    """
    Handles OAuth2 token validation.

    Supports both RFC 7662 introspection endpoints (POST with form data and
    client authentication) and generic token-info endpoints (GET with a
    Bearer token). Configured entirely via environment variables.
    """

    @staticmethod
    def _is_introspection_endpoint(
        token_info_endpoint: str,
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
    ) -> bool:
        """
        Determine if this is an introspection endpoint (requires POST) or token info endpoint (uses GET).

        Args:
            token_info_endpoint: The OAuth2 endpoint URL
            oauth_client_id: OAuth2 client ID
            oauth_client_secret: OAuth2 client secret

        Returns:
            bool: True if this is an introspection endpoint
        """
        # Heuristic: URL mentions "introspect" AND both client credentials exist.
        return (
            "introspect" in token_info_endpoint.lower()
            and oauth_client_id is not None
            and oauth_client_secret is not None
        )

    @staticmethod
    def _prepare_introspection_request(
        token: str,
        oauth_client_id: Optional[str],
        oauth_client_secret: Optional[str],
    ) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        Prepare headers and data for OAuth2 introspection endpoint (RFC 7662).

        Args:
            token: The OAuth2 token to validate
            oauth_client_id: OAuth2 client ID
            oauth_client_secret: OAuth2 client secret

        Returns:
            Tuple of (headers, data) for the introspection request
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {"token": token}

        # Add client authentication if credentials are provided
        if oauth_client_id and oauth_client_secret:
            # Use HTTP Basic authentication for client credentials
            credentials = base64.b64encode(
                f"{oauth_client_id}:{oauth_client_secret}".encode()
            ).decode()
            headers["Authorization"] = f"Basic {credentials}"
        elif oauth_client_id:
            # For public clients, include client_id in the request body
            data["client_id"] = oauth_client_id

        return headers, data

    @staticmethod
    def _prepare_token_info_request(token: str) -> Dict[str, str]:
        """
        Prepare headers for generic token info endpoint.

        Args:
            token: The OAuth2 token to validate

        Returns:
            Dict of headers for the token info request
        """
        return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    @staticmethod
    def _extract_user_info(
        response_data: Dict,
        user_id_field_name: str,
        user_role_field_name: str,
        user_team_id_field_name: str,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Extract user information from OAuth2 response.

        Args:
            response_data: The response data from OAuth2 endpoint
            user_id_field_name: Field name for user ID
            user_role_field_name: Field name for user role
            user_team_id_field_name: Field name for team ID

        Returns:
            Tuple of (user_id, user_role, user_team_id)
        """
        # Missing fields simply resolve to None; no validation here.
        user_id = response_data.get(user_id_field_name)
        user_team_id = response_data.get(user_team_id_field_name)
        user_role = response_data.get(user_role_field_name)

        return user_id, user_role, user_team_id

    @staticmethod
    async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
        """
        Makes a request to the token introspection endpoint to validate the OAuth2 token.

        This function implements OAuth2 token introspection according to RFC 7662.
        It supports both generic token info endpoints (GET) and OAuth2 introspection endpoints (POST).

        Args:
            token (str): The OAuth2 token to validate.

        Returns:
            UserAPIKeyAuth: If the token is valid, containing user information.

        Raises:
            ValueError: If the token is invalid, the request fails, or the token info endpoint is not set.
        """
        from litellm.proxy.proxy_server import premium_user

        # Premium-only feature gate.
        if premium_user is not True:
            raise ValueError(
                "Oauth2 token validation is only available for premium users"
                + CommonProxyErrors.not_premium_user.value
            )

        # NOTE(review): this logs the raw bearer token at debug level - confirm
        # that debug logs are never enabled in environments where that matters.
        verbose_proxy_logger.debug("Oauth2 token validation for token=%s", token)

        # Get the token info endpoint from environment variable
        token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
        user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
        user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
        user_team_id_field_name = os.environ.get(
            "OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id"
        )

        # OAuth2 client credentials for introspection endpoint authentication
        oauth_client_id = os.environ.get("OAUTH_CLIENT_ID")
        oauth_client_secret = os.environ.get("OAUTH_CLIENT_SECRET")

        if not token_info_endpoint:
            raise ValueError(
                "OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set"
            )

        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)

        # Determine if this is an introspection endpoint (requires POST) or token info endpoint (uses GET)
        is_introspection_endpoint = Oauth2Handler._is_introspection_endpoint(
            token_info_endpoint=token_info_endpoint,
            oauth_client_id=oauth_client_id,
            oauth_client_secret=oauth_client_secret,
        )

        try:
            if is_introspection_endpoint:
                # OAuth2 Token Introspection (RFC 7662) - requires POST with form data
                verbose_proxy_logger.debug("Using OAuth2 introspection endpoint (POST)")

                headers, data = Oauth2Handler._prepare_introspection_request(
                    token=token,
                    oauth_client_id=oauth_client_id,
                    oauth_client_secret=oauth_client_secret,
                )

                response = await client.post(
                    token_info_endpoint, headers=headers, data=data
                )
            else:
                # Generic token info endpoint - uses GET with Bearer token
                verbose_proxy_logger.debug("Using generic token info endpoint (GET)")
                headers = Oauth2Handler._prepare_token_info_request(token=token)
                response = await client.get(token_info_endpoint, headers=headers)

            # if it's a bad token we expect it to raise an HTTPStatusError
            response.raise_for_status()

            # If we get here, the request was successful.
            # NOTE: `data` is reused here - from this point on it holds the
            # response JSON, not the request form body.
            data = response.json()

            verbose_proxy_logger.debug(
                "Oauth2 token validation for token=%s, response from endpoint=%s",
                token,
                data,
            )

            # For introspection endpoints, check if token is active
            if is_introspection_endpoint and not data.get("active", True):
                raise ValueError("Token is not active")

            # Extract user information from response
            user_id, user_role, user_team_id = Oauth2Handler._extract_user_info(
                response_data=data,
                user_id_field_name=user_id_field_name,
                user_role_field_name=user_role_field_name,
                user_team_id_field_name=user_team_id_field_name,
            )

            return UserAPIKeyAuth(
                api_key=token,
                team_id=user_team_id,
                user_id=user_id,
                user_role=cast(LitellmUserRoles, user_role),
            )
        except httpx.HTTPStatusError as e:
            # This will catch any 4xx or 5xx errors
            raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
        except Exception as e:
            # This will catch any other errors (like network issues)
            raise ValueError(f"An error occurred during token validation: {e}")
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth:
    """
    Handle request from oauth2 proxy.

    Maps trusted headers (set by an upstream oauth2 proxy) onto UserAPIKeyAuth
    fields using ``general_settings["oauth2_config_mappings"]`` (a dict of
    UserAPIKeyAuth field name -> header name).

    Raises:
        ValueError: if ``oauth2_config_mappings`` is missing or empty.
    """
    from litellm.proxy.proxy_server import general_settings

    verbose_proxy_logger.debug("Handling oauth2 proxy request")
    # Define the OAuth2 config mappings
    oauth2_config_mappings: Dict[str, str] = (
        general_settings.get("oauth2_config_mappings") or {}
    )
    verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}")

    if not oauth2_config_mappings:
        raise ValueError("Oauth2 config mappings not found in general_settings")
    # Initialize a dictionary to store the mapped values
    auth_data: Dict[str, Any] = {}

    # Extract values from headers based on the mappings.
    # Missing or empty headers are simply skipped (truthiness check).
    for key, header in oauth2_config_mappings.items():
        value = request.headers.get(header)
        if value:
            # Convert max_budget to float if present
            # NOTE(review): float(value) raises ValueError on a malformed
            # header value - confirm whether that should be handled here.
            if key == "max_budget":
                auth_data[key] = float(value)
            # Convert models to list if present
            elif key == "models":
                auth_data[key] = [model.strip() for model in value.split(",")]
            else:
                auth_data[key] = value
    verbose_proxy_logger.debug(
        f"Auth data before creating UserAPIKeyAuth object: {auth_data}"
    )
    user_api_key_auth = UserAPIKeyAuth(**auth_data)
    verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}")
    # Create and return UserAPIKeyAuth object
    return user_api_key_auth
|
||||
@@ -0,0 +1,9 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwcNBabWBZzrDhFAuA4Fh
|
||||
FhIcA3rF7vrLb8+1yhF2U62AghQp9nStyuJRjxMUuldWgJ1yRJ2s7UffVw5r8DeA
|
||||
dqXPD+w+3LCNwqJGaIKN08QGJXNArM3QtMaN0RTzAyQ4iibN1r6609W5muK9wGp0
|
||||
b1j5+iDUmf0ynItnhvaX6B8Xoaflc3WD/UBdrygLmsU5uR3XC86+/8ILoSZH3HtN
|
||||
6FJmWhlhjS2TR1cKZv8K5D0WuADTFf5MF8jYFR+uORPj5Pe/EJlLGN26Lfn2QnGu
|
||||
XgbPF6nCGwZ0hwH1Xkn3xzGaJ4xBEC761wqp5cHxWSDktHyFKnLbP3jVeegjVIHh
|
||||
pQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
@@ -0,0 +1,187 @@
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def init_rds_client(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """
    Initialize a boto3 RDS client.

    Credential resolution order:
      1. Web-identity (OIDC) token + role + session name -> STS assume_role_with_web_identity
      2. Role name + session name                        -> STS assume_role
      3. Explicit access key / secret key
      4. Named AWS profile (~/.aws/credentials)
      5. boto3's default environment-based resolution

    Any string parameter may be passed as "os.environ/<VAR>" to have its
    value resolved via the secret manager.

    Raises:
        Exception: if no AWS region can be resolved, or the OIDC token
            cannot be retrieved.
    """
    from litellm.secret_managers.main import get_secret

    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)
    ## CHECK IS 'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Iterate over parameters and resolve "os.environ/..." references via the secret manager
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)  # type: ignore
    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    ### SET REGION NAME
    # Explicit argument wins, then LiteLLM's AWS_REGION_NAME, then standard AWS_REGION.
    if aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise Exception(
            "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
        )

    import boto3

    # Build the boto3 client config from the requested timeout.
    if isinstance(timeout, float):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()  # type: ignore

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        try:
            # aws_web_identity_token may be a filepath to the OIDC token;
            # use a context manager so the file handle is always closed.
            with open(aws_web_identity_token) as token_file:
                oidc_token = token_file.read()
        except Exception:
            # Not a readable file - fall back to resolving it as a secret.
            oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise Exception(
                "OIDC token could not be retrieved from secret manager.",
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )

    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            service_name="rds",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            config=config,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )

    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        client = boto3.client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )

    return client
|
||||
|
||||
|
||||
def generate_iam_auth_token(
    db_host, db_port, db_user, client: Optional[Any] = None
) -> str:
    """
    Generate a URL-encoded RDS IAM auth token for the given database.

    Uses the provided RDS client when supplied; otherwise builds one from
    the AWS_* environment variables.
    """
    from urllib.parse import quote

    rds_client = client
    if rds_client is None:
        # No client supplied - construct one from environment configuration.
        rds_client = init_rds_client(
            aws_region_name=os.getenv("AWS_REGION_NAME"),
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_name=os.getenv("AWS_SESSION_NAME"),
            aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
            aws_role_name=os.getenv("AWS_ROLE_NAME", os.getenv("AWS_ROLE_ARN")),
            aws_web_identity_token=os.getenv(
                "AWS_WEB_IDENTITY_TOKEN", os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
            ),
        )

    raw_token = rds_client.generate_db_auth_token(
        DBHostname=db_host, Port=db_port, DBUsername=db_user
    )
    # Percent-encode everything (safe="") so the token can be embedded in a DB URL.
    return quote(raw_token, safe="")
|
||||
@@ -0,0 +1,669 @@
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLMRoutes,
|
||||
LitellmUserRoles,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
|
||||
from .auth_checks_organization import _user_is_org_admin
|
||||
|
||||
|
||||
class RouteChecks:
    """Static helpers that decide whether a key/user may call a given proxy route."""

    @staticmethod
    def should_call_route(route: str, valid_token: UserAPIKeyAuth):
        """
        Check if management route is disabled and raise exception
        """
        try:
            from litellm_enterprise.proxy.auth.route_checks import EnterpriseRouteChecks

            EnterpriseRouteChecks.should_call_route(route=route)
        except HTTPException as e:
            raise e
        except Exception:
            # Enterprise package not installed - enterprise route checks are optional.
            pass

        # Check if Virtual Key is allowed to call the route - Applies to all Roles
        RouteChecks.is_virtual_key_allowed_to_call_route(
            route=route, valid_token=valid_token
        )
        return True

    @staticmethod
    def is_virtual_key_allowed_to_call_route(
        route: str, valid_token: UserAPIKeyAuth
    ) -> bool:
        """
        Raises Exception if Virtual Key is not allowed to call the route
        """

        # Only check if valid_token.allowed_routes is set and is a list with at least one item
        if valid_token.allowed_routes is None:
            return True
        if not isinstance(valid_token.allowed_routes, list):
            return True
        if len(valid_token.allowed_routes) == 0:
            return True

        # explicit check for allowed routes (exact match or prefix match)
        for allowed_route in valid_token.allowed_routes:
            if RouteChecks._route_matches_allowed_route(
                route=route, allowed_route=allowed_route
            ):
                return True

        ## check if 'allowed_route' is a field name in LiteLLMRoutes
        if any(
            allowed_route in LiteLLMRoutes._member_names_
            for allowed_route in valid_token.allowed_routes
        ):
            for allowed_route in valid_token.allowed_routes:
                if allowed_route in LiteLLMRoutes._member_names_:
                    if RouteChecks.check_route_access(
                        route=route,
                        allowed_routes=LiteLLMRoutes._member_map_[allowed_route].value,
                    ):
                        return True

                    ################################################
                    # For llm_api_routes, also check registered pass-through endpoints
                    ################################################
                    if allowed_route == "llm_api_routes":
                        from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
                            InitPassThroughEndpointHelpers,
                        )

                        if InitPassThroughEndpointHelpers.is_registered_pass_through_route(
                            route=route
                        ):
                            return True

        # check if wildcard pattern is allowed
        for allowed_route in valid_token.allowed_routes:
            if RouteChecks._route_matches_wildcard_pattern(
                route=route, pattern=allowed_route
            ):
                return True

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Virtual key is not allowed to call this route. Only allowed to call routes: {valid_token.allowed_routes}. Tried to call route: {route}",
        )

    @staticmethod
    def _mask_user_id(user_id: str) -> str:
        """
        Mask user_id to prevent leaking sensitive information in error messages

        Args:
            user_id (str): The user_id to mask

        Returns:
            str: Masked user_id showing only first 2 and last 2 characters
        """
        from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker

        if not user_id or len(user_id) <= 4:
            # Too short to partially reveal - mask entirely.
            return "***"

        # Use SensitiveDataMasker with custom configuration for user_id
        masker = SensitiveDataMasker(visible_prefix=6, visible_suffix=2, mask_char="*")

        return masker._mask_value(user_id)

    @staticmethod
    def _raise_admin_only_route_exception(
        user_obj: Optional[LiteLLM_UserTable],
        route: str,
    ) -> None:
        """
        Raise exception for routes that require proxy admin access

        Args:
            user_obj (Optional[LiteLLM_UserTable]): The user object
            route (str): The route being accessed

        Raises:
            Exception: With user role and masked user_id information
        """
        user_role = "unknown"
        user_id = "unknown"
        if user_obj is not None:
            user_role = user_obj.user_role or "unknown"
            user_id = user_obj.user_id or "unknown"

        masked_user_id = RouteChecks._mask_user_id(user_id)
        raise Exception(
            f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={masked_user_id}"
        )

    @staticmethod
    def non_proxy_admin_allowed_routes_check(
        user_obj: Optional[LiteLLM_UserTable],
        _user_role: Optional[LitellmUserRoles],
        route: str,
        request: Request,
        valid_token: UserAPIKeyAuth,
        request_data: dict,
    ):
        """
        Checks if Non Proxy Admin User is allowed to access the route
        """

        # Check user has defined custom admin routes
        RouteChecks.custom_admin_only_route_check(
            route=route,
        )

        if RouteChecks.is_llm_api_route(route=route):
            pass
        elif RouteChecks.is_info_route(route=route):
            # check if user allowed to call an info route
            if route == "/key/info":
                # handled by function itself
                pass
            elif route == "/user/info":
                # check if user can access this route
                query_params = request.query_params
                user_id = query_params.get("user_id")
                verbose_proxy_logger.debug(
                    f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
                )
                if user_id and user_id != valid_token.user_id:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="key not allowed to access this user's info. user_id={}, key's user_id={}".format(
                            user_id, valid_token.user_id
                        ),
                    )
            elif route == "/model/info":
                # /model/info just shows models user has access to
                pass
            elif route == "/team/info":
                pass  # handled by function itself
        elif (
            route in LiteLLMRoutes.global_spend_tracking_routes.value
            and getattr(valid_token, "permissions", None) is not None
            and "get_spend_routes" in getattr(valid_token, "permissions", [])
        ):
            pass
        elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
            RouteChecks._check_proxy_admin_viewer_access(
                route=route,
                _user_role=_user_role,
                request_data=request_data,
            )
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER.value
            and RouteChecks.check_route_access(
                route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
            )
        ):
            pass
        elif _user_is_org_admin(
            request_data=request_data, user_object=user_obj
        ) and RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.org_admin_allowed_routes.value
        ):
            pass
        elif (
            _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
            and RouteChecks.check_route_access(
                route=route,
                allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
            )
        ):
            pass
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value
        ):  # routes that manage their own allowed/disallowed logic
            pass
        elif route.startswith("/v1/mcp/") or route.startswith("/mcp-rest/"):
            pass  # authN/authZ handled by api itself
        elif RouteChecks.check_passthrough_route_access(
            route=route, user_api_key_dict=valid_token
        ):
            pass
        elif valid_token.allowed_routes is not None:
            # check if route is in allowed_routes (exact match or prefix match)
            route_allowed = False
            for allowed_route in valid_token.allowed_routes:
                if RouteChecks._route_matches_allowed_route(
                    route=route, allowed_route=allowed_route
                ):
                    route_allowed = True
                    break

                if RouteChecks._route_matches_wildcard_pattern(
                    route=route, pattern=allowed_route
                ):
                    route_allowed = True
                    break

            if not route_allowed:
                RouteChecks._raise_admin_only_route_exception(
                    user_obj=user_obj, route=route
                )
        else:
            RouteChecks._raise_admin_only_route_exception(
                user_obj=user_obj, route=route
            )

    @staticmethod
    def custom_admin_only_route_check(route: str):
        """Raise 403 when `route` is listed in the user-configured `admin_only_routes` (enterprise-only feature)."""
        from litellm.proxy.proxy_server import general_settings, premium_user

        if "admin_only_routes" in general_settings:
            if premium_user is not True:
                verbose_proxy_logger.error(
                    f"Trying to use 'admin_only_routes' this is an Enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
                )
                return
            if route in general_settings["admin_only_routes"]:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"user not allowed to access this route. Route={route} is an admin only route",
                )

    @staticmethod
    def is_llm_api_route(route: str) -> bool:
        """
        Helper to checks if provided route is an OpenAI route


        Returns:
            - True: if route is an OpenAI route
            - False: if route is not an OpenAI route
        """
        # Ensure route is a string before performing checks
        if not isinstance(route, str):
            return False

        if route in LiteLLMRoutes.openai_routes.value:
            return True

        if route in LiteLLMRoutes.anthropic_routes.value:
            return True

        if route in LiteLLMRoutes.google_routes.value:
            return True

        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.mcp_routes.value
        ):
            return True

        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.agent_routes.value
        ):
            return True

        if route in LiteLLMRoutes.litellm_native_routes.value:
            return True

        # fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        # Check for routes with placeholders or wildcard patterns
        for openai_route in LiteLLMRoutes.openai_routes.value:
            # Replace placeholders with regex pattern
            # placeholders are written as "/threads/{thread_id}"
            if "{" in openai_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=openai_route
                ):
                    return True
            # Check for wildcard patterns like "/containers/*"
            if RouteChecks._is_wildcard_pattern(pattern=openai_route):
                if RouteChecks._route_matches_wildcard_pattern(
                    route=route, pattern=openai_route
                ):
                    return True

        # Check for Google routes with placeholders like "/v1beta/models/{model_name}:generateContent"
        for google_route in LiteLLMRoutes.google_routes.value:
            if "{" in google_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=google_route
                ):
                    return True

        # Check for Anthropic routes with placeholders
        for anthropic_route in LiteLLMRoutes.anthropic_routes.value:
            if "{" in anthropic_route:
                if RouteChecks._route_matches_pattern(
                    route=route, pattern=anthropic_route
                ):
                    return True

        if RouteChecks._is_azure_openai_route(route=route):
            return True

        for _llm_passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value:
            if _llm_passthrough_route in route:
                return True
        return False

    @staticmethod
    def is_management_route(route: str) -> bool:
        """
        Check if route is a management route
        """
        return route in LiteLLMRoutes.management_routes.value

    @staticmethod
    def is_info_route(route: str) -> bool:
        """
        Check if route is an info route
        """
        return route in LiteLLMRoutes.info_routes.value

    @staticmethod
    def _is_azure_openai_route(route: str) -> bool:
        """
        Check if route is a route from AzureOpenAI SDK client

        eg.
        route='/openai/deployments/vertex_ai/gemini-1.5-flash/chat/completions'
        """
        # Ensure route is a string before attempting regex matching
        if not isinstance(route, str):
            return False
        # Add support for deployment and engine model paths
        deployment_pattern = r"^/openai/deployments/[^/]+/[^/]+/chat/completions$"
        engine_pattern = r"^/engines/[^/]+/chat/completions$"

        if re.match(deployment_pattern, route) or re.match(engine_pattern, route):
            return True
        return False

    @staticmethod
    def _route_matches_pattern(route: str, pattern: str) -> bool:
        """
        Check if route matches the pattern placed in proxy/_types.py

        Example:
        - pattern: "/threads/{thread_id}"
        - route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        - returns: True


        - pattern: "/key/{token_id}/regenerate"
        - route: "/key/regenerate/82akk800000000jjsk"
        - returns: False, pattern is "/key/{token_id}/regenerate"
        """
        # Ensure route is a string before attempting regex matching
        if not isinstance(route, str):
            return False

        def _placeholder_to_regex(match: re.Match) -> str:
            placeholder = match.group(0).strip("{}")
            if placeholder.endswith(":path"):
                # allow "/" in the placeholder value, but don't eat the route suffix after ":"
                return r"[^:]+"
            return r"[^/]+"

        pattern = re.sub(r"\{[^}]+\}", _placeholder_to_regex, pattern)
        # Anchor the pattern to match the entire string
        pattern = f"^{pattern}$"
        if re.match(pattern, route):
            return True
        return False

    @staticmethod
    def _is_wildcard_pattern(pattern: str) -> bool:
        """
        Check if pattern is a wildcard pattern
        """
        return pattern.endswith("*")

    @staticmethod
    def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
        """
        Check if route matches the wildcard pattern

        eg.

        pattern: "/scim/v2/*"
        route: "/scim/v2/Users"
        - returns: True

        pattern: "/scim/v2/*"
        route: "/chat/completions"
        - returns: False


        pattern: "/scim/v2/*"
        route: "/scim/v2/Users/123"
        - returns: True

        """
        if pattern.endswith("*"):
            # Get the prefix (everything before the wildcard)
            prefix = pattern[:-1]
            return route.startswith(prefix)
        else:
            # If there's no wildcard, the pattern and route should match exactly
            return route == pattern

    @staticmethod
    def _route_matches_allowed_route(route: str, allowed_route: str) -> bool:
        """
        Check if route matches the allowed_route pattern.
        Supports both exact match and prefix match.

        Examples:
        - allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-6" -> True (exact match)
        - allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-6/v1/chat/completions" -> True (prefix match)
        - allowed_route="/fake-openai-proxy-6", route="/fake-openai-proxy-600" -> False (not a valid prefix)

        Args:
            route: The actual route being accessed
            allowed_route: The allowed route pattern

        Returns:
            bool: True if route matches (exact or prefix), False otherwise
        """
        # Exact match
        if route == allowed_route:
            return True
        # Prefix match - ensure we add "/" to prevent false matches like /fake-openai-proxy-600
        if route.startswith(allowed_route + "/"):
            return True
        return False

    @staticmethod
    def check_route_access(route: str, allowed_routes: List[str]) -> bool:
        """
        Check if a route has access by checking both exact matches and patterns

        Args:
            route (str): The route to check
            allowed_routes (list): List of allowed routes/patterns

        Returns:
            bool: True if route is allowed, False otherwise
        """
        #########################################################
        # exact match route is in allowed_routes
        #########################################################
        if route in allowed_routes:
            return True

        #########################################################
        # wildcard match route is in allowed_routes
        # e.g calling /anthropic/v1/messages is allowed if allowed_routes has /anthropic/*
        #########################################################
        wildcard_allowed_routes = [
            route
            for route in allowed_routes
            if RouteChecks._is_wildcard_pattern(pattern=route)
        ]
        for allowed_route in wildcard_allowed_routes:
            if RouteChecks._route_matches_wildcard_pattern(
                route=route, pattern=allowed_route
            ):
                return True

        #########################################################
        # pattern match route is in allowed_routes
        # pattern: "/threads/{thread_id}"
        # route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        # returns: True
        #########################################################
        if any(  # Check pattern match
            RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
            for allowed_route in allowed_routes
        ):
            return True

        return False

    @staticmethod
    def check_passthrough_route_access(
        route: str, user_api_key_dict: UserAPIKeyAuth
    ) -> bool:
        """
        Check if route is a passthrough route.
        Supports both exact match and prefix match.
        """
        # Coerce both to {} so the membership checks below never run on None.
        # (Previously `metadata` could be None here and `"..." not in metadata`
        # raised TypeError, since `team_metadata` was already coerced and the
        # `metadata is None and team_metadata is None` guard never fired.)
        metadata = user_api_key_dict.metadata or {}
        team_metadata = user_api_key_dict.team_metadata or {}
        if (
            "allowed_passthrough_routes" not in metadata
            and "allowed_passthrough_routes" not in team_metadata
        ):
            return False
        if (
            metadata.get("allowed_passthrough_routes") is None
            and team_metadata.get("allowed_passthrough_routes") is None
        ):
            return False

        # Key-level config takes precedence over team-level config.
        allowed_passthrough_routes = (
            metadata.get("allowed_passthrough_routes")
            or team_metadata.get("allowed_passthrough_routes")
            or []
        )

        # Check if route matches any allowed passthrough route (exact or prefix match)
        for allowed_route in allowed_passthrough_routes:
            if RouteChecks._route_matches_allowed_route(
                route=route, allowed_route=allowed_route
            ):
                return True

        return False

    @staticmethod
    def _is_assistants_api_request(request: Request) -> bool:
        """
        Returns True if `thread` or `assistant` is in the request path

        Args:
            request (Request): The request object

        Returns:
            bool: True if `thread` or `assistant` is in the request path, False otherwise
        """
        if "thread" in request.url.path or "assistant" in request.url.path:
            return True
        return False

    @staticmethod
    def is_generate_content_route(route: str) -> bool:
        """
        Returns True if this is a google generateContent or streamGenerateContent route

        These routes from google allow passing key=api_key in the query params
        """
        if "generateContent" in route:
            return True
        if "streamGenerateContent" in route:
            return True
        return False

    @staticmethod
    def _check_proxy_admin_viewer_access(
        route: str,
        _user_role: str,
        request_data: dict,
    ) -> None:
        """
        Check access for PROXY_ADMIN_VIEW_ONLY role
        """
        if RouteChecks.is_llm_api_route(route=route):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"user not allowed to access this OpenAI routes, role= {_user_role}",
            )

        # Check if this is a write operation on management routes
        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.management_routes.value
        ):
            # For management routes, only allow read operations or specific allowed updates
            if route == "/user/update":
                # Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY
                if request_data is not None and isinstance(request_data, dict):
                    _params_updated = request_data.keys()
                    for param in _params_updated:
                        if param not in ["user_email", "password"]:
                            raise HTTPException(
                                status_code=status.HTTP_403_FORBIDDEN,
                                detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route} and updating invalid param: {param}. only user_email and password can be updated",
                            )
            elif (
                route
                in [
                    "/user/new",
                    "/user/delete",
                    "/team/new",
                    "/team/update",
                    "/team/delete",
                    "/model/new",
                    "/model/update",
                    "/model/delete",
                    "/key/generate",
                    "/key/delete",
                    "/key/update",
                    "/key/regenerate",
                    "/key/service-account/generate",
                    "/key/block",
                    "/key/unblock",
                ]
                or (route.startswith("/key/") and route.endswith("/regenerate"))
            ):
                # Block write operations for PROXY_ADMIN_VIEW_ONLY
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
                )
            # Allow read operations on management routes (like /user/info, /team/info, /model/info)
            return
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.admin_viewer_routes.value
        ):
            # Allow access to admin viewer routes (read-only admin endpoints)
            return
        elif RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.global_spend_tracking_routes.value
        ):
            # Allow access to global spend tracking routes (read-only spend endpoints)
            # proxy_admin_viewer role description: "view all keys, view all spend"
            return
        else:
            # For other routes, block access
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
            )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,967 @@
|
||||
######################################################################
|
||||
|
||||
# /v1/batches Endpoints
|
||||
|
||||
|
||||
######################################################################
|
||||
import asyncio
|
||||
from typing import Dict, Optional, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.batches.main import CancelBatchRequest, RetrieveBatchRequest
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_headers,
|
||||
get_custom_llm_provider_from_request_query,
|
||||
)
|
||||
from litellm.proxy.openai_files_endpoints.common_utils import (
|
||||
_is_base64_encoded_unified_file_id,
|
||||
decode_model_from_file_id,
|
||||
encode_batch_response_ids,
|
||||
encode_file_id_with_model,
|
||||
get_batch_from_database,
|
||||
get_credentials_for_model,
|
||||
get_model_id_from_unified_batch_id,
|
||||
get_models_from_unified_file_id,
|
||||
get_original_file_id,
|
||||
prepare_data_with_credentials,
|
||||
resolve_input_file_id_to_unified,
|
||||
update_batch_in_database,
|
||||
)
|
||||
from litellm.proxy.utils import handle_exception_on_proxy, is_known_model
|
||||
from litellm.types.llms.openai import LiteLLMBatchCreateRequest
|
||||
|
||||
# Router collecting every /v1/batches-compatible endpoint in this module;
# mounted onto the main proxy FastAPI app.
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/{provider}/v1/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.post(
    "/v1/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.post(
    "/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
async def create_batch(  # noqa: PLR0915
    request: Request,
    fastapi_response: Response,
    provider: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create large batches of API requests for asynchronous processing.
    This is the equivalent of POST https://api.openai.com/v1/batch
    Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch

    Routing scenarios (checked in this order):
      1. ``input_file_id`` is encoded with model info -> use that model's credentials.
      2. ``litellm.enable_loadbalancing_on_batch_endpoints`` and ``model`` is a
         known router model -> route through ``llm_router``.
      3. ``input_file_id`` is a base64 "unified" file id -> route through
         ``llm_router`` using the single model embedded in the id.
      4. Otherwise: model from body/query/header if present, else fall back to
         ``custom_llm_provider`` (env-variable credentials).

    Example Curl
    ```
    curl http://localhost:4000/v1/batches \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \
        -d '{
            "input_file_id": "file-abc123",
            "endpoint": "/v1/chat/completions",
            "completion_window": "24h"
    }'
    ```
    """
    # Imported lazily to avoid a circular import with proxy_server at module load.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    data: Dict = {}
    try:
        data = await _read_request_body(request=request)
        # NOTE(review): `json` is not imported explicitly here — presumably it
        # arrives via the wildcard `from litellm.proxy._types import *`; confirm.
        verbose_proxy_logger.debug(
            "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
        )
        # Pre-call hook: enriches/normalizes `data` and builds the logging object.
        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        (
            data,
            litellm_logging_obj,
        ) = await base_llm_response_processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="acreate_batch",
        )

        ## check if model is a loadbalanced model
        router_model: Optional[str] = None
        is_router_model = False
        if litellm.enable_loadbalancing_on_batch_endpoints is True:
            router_model = data.get("model", None)
            is_router_model = is_known_model(model=router_model, llm_router=llm_router)

        # Provider precedence: path param > body > request headers > "openai".
        custom_llm_provider = (
            provider
            or data.pop("custom_llm_provider", None)
            or get_custom_llm_provider_from_request_headers(request=request)
            or "openai"
        )
        _create_batch_data = LiteLLMBatchCreateRequest(**data)

        # Apply team-level batch output expiry enforcement
        team_metadata = user_api_key_dict.team_metadata or {}
        enforced_batch_expiry = team_metadata.get("enforced_batch_output_expires_after")
        if enforced_batch_expiry is not None:
            # Malformed team config is a server-side (500) error, not a client error.
            if (
                "anchor" not in enforced_batch_expiry
                or "seconds" not in enforced_batch_expiry
            ):
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "Server configuration error: team metadata field 'enforced_batch_output_expires_after' is malformed - must contain 'anchor' and 'seconds' keys. Contact your team or proxy admin to fix this setting.",
                    },
                )
            if enforced_batch_expiry["anchor"] != "created_at":
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": f"Server configuration error: team metadata field 'enforced_batch_output_expires_after' has invalid anchor '{enforced_batch_expiry['anchor']}' - must be 'created_at'. Contact your team or proxy admin to fix this setting.",
                    },
                )
            # Team policy overrides any caller-supplied expiry.
            _create_batch_data["output_expires_after"] = {
                "anchor": "created_at",
                "seconds": int(enforced_batch_expiry["seconds"]),
            }

        input_file_id = _create_batch_data.get("input_file_id", None)
        # False when the id is not a unified id; otherwise the decoded unified id string.
        unified_file_id: Union[str, Literal[False]] = False

        model_from_file_id = None
        if input_file_id:
            model_from_file_id = decode_model_from_file_id(input_file_id)
            unified_file_id = _is_base64_encoded_unified_file_id(input_file_id)

        # SCENARIO 1: File ID is encoded with model info
        if model_from_file_id is not None and input_file_id:
            credentials = get_credentials_for_model(
                llm_router=llm_router,
                model_id=model_from_file_id,
                operation_context="batch creation (file created with model)",
            )

            # Provider must receive its own raw file id, not our encoded one.
            original_file_id = get_original_file_id(input_file_id)
            _create_batch_data["input_file_id"] = original_file_id
            prepare_data_with_credentials(
                data=_create_batch_data,  # type: ignore
                credentials=credentials,
            )

            # Create batch using model credentials
            response = await litellm.acreate_batch(
                custom_llm_provider=credentials["custom_llm_provider"],
                **_create_batch_data,  # type: ignore
            )

            # Encode the batch ID and related file IDs with model information
            # so subsequent retrieve/cancel/file calls can route to the same model.
            if response and hasattr(response, "id") and response.id:
                original_batch_id = response.id
                encoded_batch_id = encode_file_id_with_model(
                    file_id=original_batch_id,
                    model=model_from_file_id,
                    id_type="batch",
                )
                response.id = encoded_batch_id

                if hasattr(response, "output_file_id") and response.output_file_id:
                    response.output_file_id = encode_file_id_with_model(
                        file_id=response.output_file_id, model=model_from_file_id
                    )

                if hasattr(response, "error_file_id") and response.error_file_id:
                    response.error_file_id = encode_file_id_with_model(
                        file_id=response.error_file_id, model=model_from_file_id
                    )

                verbose_proxy_logger.debug(
                    f"Created batch using model: {model_from_file_id}, "
                    f"original_batch_id: {original_batch_id}, encoded: {encoded_batch_id}"
                )

            # Echo back the (encoded) file id the client originally sent.
            response.input_file_id = input_file_id

        # SCENARIO: loadbalanced router model
        elif (
            litellm.enable_loadbalancing_on_batch_endpoints is True
            and is_router_model
            and router_model is not None
        ):
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )

            response = await llm_router.acreate_batch(**_create_batch_data)  # type: ignore
        elif (
            unified_file_id and input_file_id
        ):  # litellm_proxy:application/octet-stream;unified_id,c4843482-b176-4901-8292-7523fd0f2c6e;target_model_names,gpt-4o-mini
            target_model_names = get_models_from_unified_file_id(unified_file_id)
            ## EXPECTS 1 MODEL
            if len(target_model_names) != 1:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Expected 1 model, got {}".format(
                            len(target_model_names)
                        )
                    },
                )
            model = target_model_names[0]
            _create_batch_data["model"] = model
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )

            response = await llm_router.acreate_batch(**_create_batch_data)
            response.input_file_id = input_file_id
            # Stash the unified id so post-call hooks can re-encode response ids.
            response._hidden_params["unified_file_id"] = unified_file_id
        else:
            # Check if model specified via header/query/body param
            model_param = (
                data.get("model")
                or request.query_params.get("model")
                or request.headers.get("x-litellm-model")
            )

            # SCENARIO 2 & 3: Model from header/query OR custom_llm_provider fallback
            if model_param:
                # SCENARIO 2: Use model-based routing from header/query/body
                credentials = get_credentials_for_model(
                    llm_router=llm_router,
                    model_id=model_param,
                    operation_context="batch creation",
                )

                prepare_data_with_credentials(
                    data=_create_batch_data,  # type: ignore
                    credentials=credentials,
                )

                # Create batch using model credentials
                response = await litellm.acreate_batch(
                    custom_llm_provider=credentials["custom_llm_provider"],
                    **_create_batch_data,  # type: ignore
                )

                encode_batch_response_ids(response, model=model_param)

                verbose_proxy_logger.debug(f"Created batch using model: {model_param}")
            else:
                # SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
                response = await litellm.acreate_batch(
                    custom_llm_provider=custom_llm_provider, **_create_batch_data  # type: ignore
                )

        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        ### ALERTING ###
        # Fire-and-forget: do not block the response on status bookkeeping.
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
    "/{provider}/v1/batches/{batch_id:path}",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.get(
    "/v1/batches/{batch_id:path}",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.get(
    "/batches/{batch_id:path}",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
async def retrieve_batch(  # noqa: PLR0915
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    provider: Optional[str] = None,
    batch_id: str = Path(
        title="Batch ID to retrieve", description="The ID of the batch to retrieve"
    ),
):
    """
    Retrieves a batch.
    This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id}
    Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve

    Read path: first consult the proxy's own database (ManagedObjectTable); if
    the stored batch is already terminal (completed/failed/cancelled/expired)
    return it without touching the provider. Otherwise sync with the provider
    (routing via encoded-model id, llm_router, or custom_llm_provider fallback)
    and write the fresh state back to the database.

    Example Curl
    ```
    curl http://localhost:4000/v1/batches/batch_abc123 \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \

    ```
    """
    # Imported lazily to avoid a circular import with proxy_server at module load.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    data: Dict = {}
    try:
        # Non-None when the batch id was encoded with its model at creation time.
        model_from_id = decode_model_from_file_id(batch_id)
        _retrieve_batch_request = RetrieveBatchRequest(
            batch_id=batch_id,
        )

        data = cast(dict, _retrieve_batch_request)
        # False, or the decoded unified batch id string.
        unified_batch_id = _is_base64_encoded_unified_file_id(batch_id)

        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        (
            data,
            litellm_logging_obj,
        ) = await base_llm_response_processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="aretrieve_batch",
        )

        # FIX: First, try to read from ManagedObjectTable for consistent state
        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        from litellm.proxy.proxy_server import prisma_client

        db_batch_object, response = await get_batch_from_database(
            batch_id=batch_id,
            unified_batch_id=unified_batch_id,
            managed_files_obj=managed_files_obj,
            prisma_client=prisma_client,
            verbose_proxy_logger=verbose_proxy_logger,
        )

        # If batch is in a terminal state, return immediately
        # (terminal states can no longer change at the provider).
        if response is not None and response.status in [
            "completed",
            "failed",
            "cancelled",
            "expired",
        ]:
            # Call hooks and return
            response = await proxy_logging_obj.post_call_success_hook(
                data=data, user_api_key_dict=user_api_key_dict, response=response
            )

            # async_post_call_success_hook replaces batch.id and output_file_id with unified IDs
            # but not input_file_id. Resolve raw provider ID to unified ID.
            if unified_batch_id:
                await resolve_input_file_id_to_unified(response, prisma_client)

            # Fire-and-forget status bookkeeping.
            asyncio.create_task(
                proxy_logging_obj.update_request_status(
                    litellm_call_id=data.get("litellm_call_id", ""), status="success"
                )
            )

            hidden_params = getattr(response, "_hidden_params", {}) or {}
            model_id = hidden_params.get("model_id", None) or ""
            cache_key = hidden_params.get("cache_key", None) or ""
            api_base = hidden_params.get("api_base", None) or ""

            fastapi_response.headers.update(
                ProxyBaseLLMRequestProcessing.get_custom_headers(
                    user_api_key_dict=user_api_key_dict,
                    model_id=model_id,
                    cache_key=cache_key,
                    api_base=api_base,
                    version=version,
                    model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                    request_data=data,
                )
            )

            return response

        # If batch is still processing, sync with provider to get latest state
        if response is not None:
            verbose_proxy_logger.debug(
                f"Batch {batch_id} is in non-terminal state {response.status}, syncing with provider"
            )

        # Retrieve from provider (for non-terminal states or if DB lookup failed)
        # SCENARIO 1: Batch ID is encoded with model info
        if model_from_id is not None:
            credentials = get_credentials_for_model(
                llm_router=llm_router,
                model_id=model_from_id,
                operation_context="batch retrieval (batch created with model)",
            )

            original_batch_id = get_original_file_id(batch_id)
            prepare_data_with_credentials(
                data=data,
                credentials=credentials,
                file_id=original_batch_id,  # Sets data["batch_id"] = original_batch_id
            )
            # Fix: The helper sets "file_id" but we need "batch_id"
            data["batch_id"] = data.pop("file_id", original_batch_id)

            # Retrieve batch using model credentials
            response = await litellm.aretrieve_batch(
                custom_llm_provider=credentials["custom_llm_provider"],
                **data,  # type: ignore
            )

            # Re-encode provider ids so follow-up calls keep routing to this model.
            encode_batch_response_ids(response, model=model_from_id)

            verbose_proxy_logger.debug(
                f"Retrieved batch using model: {model_from_id}, original_id: {original_batch_id}"
            )

        # SCENARIO 2: loadbalancing enabled OR unified batch id -> route via llm_router
        elif (
            litellm.enable_loadbalancing_on_batch_endpoints is True or unified_batch_id
        ):
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )

            response = await llm_router.aretrieve_batch(**data)  # type: ignore
            response._hidden_params["unified_batch_id"] = unified_batch_id
            if unified_batch_id:
                # Surface the model id embedded in the unified batch id so
                # post-call hooks can re-encode response ids correctly.
                model_id_from_batch = get_model_id_from_unified_batch_id(
                    unified_batch_id
                )
                if model_id_from_batch:
                    response._hidden_params["model_id"] = model_id_from_batch

        # SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
        else:
            custom_llm_provider = (
                provider
                or get_custom_llm_provider_from_request_headers(request=request)
                or get_custom_llm_provider_from_request_query(request=request)
                or "openai"
            )
            response = await litellm.aretrieve_batch(
                custom_llm_provider=custom_llm_provider, **data  # type: ignore
            )

        # FIX: Update the database with the latest state from provider
        await update_batch_in_database(
            batch_id=batch_id,
            unified_batch_id=unified_batch_id,
            response=response,
            managed_files_obj=managed_files_obj,
            prisma_client=prisma_client,
            verbose_proxy_logger=verbose_proxy_logger,
            db_batch_object=db_batch_object,
            operation="retrieve",
        )

        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        # Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id
        # Resolve raw provider input_file_id to unified ID.
        if unified_batch_id:
            await resolve_input_file_id_to_unified(response, prisma_client)

        ### ALERTING ###
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
    "/{provider}/v1/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.get(
    "/v1/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.get(
    "/batches",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
async def list_batches(
    request: Request,
    fastapi_response: Response,
    provider: Optional[str] = None,
    limit: Optional[int] = None,
    after: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    target_model_names: Optional[str] = None,
):
    """
    Lists batches.
    This is the equivalent of GET https://api.openai.com/v1/batches/
    Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/list

    Routing (first match wins):
      1. Managed-objects table listing (returns proxy-encoded batch ids).
      2. Model from body/query/``x-litellm-model`` header -> model credentials.
      3. ``target_model_names`` (first entry) -> llm_router.
      4. ``custom_llm_provider`` fallback (env-variable credentials).

    Example Curl
    ```
    curl http://localhost:4000/v1/batches?limit=2 \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \

    ```
    """
    # Imported lazily to avoid a circular import with proxy_server at module load.
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
    try:
        if llm_router is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.no_llm_router.value},
            )

        # Include original request and headers in the data
        data = await _read_request_body(request=request)
        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        (
            data,
            litellm_logging_obj,
        ) = await base_llm_response_processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="alist_batches",
        )

        # Try to use managed objects table for listing batches (returns encoded IDs)
        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        if managed_files_obj is not None and hasattr(
            managed_files_obj, "list_user_batches"
        ):
            verbose_proxy_logger.debug("Using managed objects table for batch listing")
            response = await managed_files_obj.list_user_batches(
                user_api_key_dict=user_api_key_dict,
                limit=limit,
                after=after,
                provider=provider,
                target_model_names=target_model_names,
                llm_router=llm_router,
            )
        elif model_param := (
            data.get("model")
            or request.query_params.get("model")
            or request.headers.get("x-litellm-model")
        ):
            # SCENARIO 2: Use model-based routing from header/query/body
            credentials = get_credentials_for_model(
                llm_router=llm_router,
                model_id=model_param,
                operation_context="batch listing",
            )

            data.update(credentials)

            response = await litellm.alist_batches(
                custom_llm_provider=credentials["custom_llm_provider"],
                after=after,
                limit=limit,
                **data,  # type: ignore
            )

            # Encode batch IDs in the list response so clients can use
            # them for retrieve/cancel/file downloads through the proxy.
            if response and hasattr(response, "data") and response.data:
                for batch in response.data:
                    encode_batch_response_ids(batch, model=model_param)

            verbose_proxy_logger.debug(f"Listed batches using model: {model_param}")

        # SCENARIO 2 (alternative): target_model_names based routing
        elif target_model_names or data.get("target_model_names", None):
            target_model_names = target_model_names or data.get(
                "target_model_names", None
            )
            if target_model_names is None:
                raise ValueError(
                    "target_model_names is required for this routing scenario"
                )
            # Only the first model of a comma-separated list is used for listing.
            model = target_model_names.split(",")[0]
            data.pop("model", None)
            response = await llm_router.alist_batches(
                model=model,
                after=after,
                limit=limit,
                **data,
            )

        # SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
        else:
            custom_llm_provider = (
                provider
                or get_custom_llm_provider_from_request_headers(request=request)
                or get_custom_llm_provider_from_request_query(request=request)
                or "openai"
            )
            response = await litellm.alist_batches(
                custom_llm_provider=custom_llm_provider,  # type: ignore
                after=after,
                limit=limit,
                **data,
            )

        ### POST CALL HOOKS ###
        _response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response  # type: ignore
        )
        # Only accept the hook's result when it returned the same response type;
        # some hooks return None or unrelated objects.
        if _response is not None and type(response) is type(_response):
            response = _response

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                # Consistency fix: sibling batch endpoints all pass request_data.
                request_data=data,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data={"after": after, "limit": limit},
        )
        # Fixes: previously logged the wrong function name ("retrieve_batch"),
        # used .error (dropping the traceback) and misspelled "occurred".
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.list_batches(): Exception occurred - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.post(
    "/{provider}/v1/batches/{batch_id:path}/cancel",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.post(
    "/v1/batches/{batch_id:path}/cancel",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
@router.post(
    "/batches/{batch_id:path}/cancel",
    dependencies=[Depends(user_api_key_auth)],
    tags=["batch"],
)
async def cancel_batch(
    request: Request,
    batch_id: str,
    fastapi_response: Response,
    provider: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Cancel a batch.
    This is the equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel

    Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/cancel

    Routing (first match wins):
      1. Batch id encoded with model info -> that model's credentials.
      2. Unified batch id -> llm_router.
      3. ``custom_llm_provider`` fallback (env-variable credentials).
    After cancelling, the proxy's database copy of the batch is updated.

    Example Curl
    ```
    curl http://localhost:4000/v1/batches/batch_abc123/cancel \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \
        -X POST

    ```
    """
    # Imported lazily to avoid a circular import with proxy_server at module load.
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
    )

    data: Dict = {}
    try:
        # Check for encoded batch ID with model info
        model_from_id = decode_model_from_file_id(batch_id)

        # Create CancelBatchRequest with batch_id to enable ownership checking
        _cancel_batch_request = CancelBatchRequest(
            batch_id=batch_id,
        )
        data = cast(dict, _cancel_batch_request)

        unified_batch_id = _is_base64_encoded_unified_file_id(batch_id)

        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        (
            data,
            litellm_logging_obj,
        ) = await base_llm_response_processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="acancel_batch",
        )

        # Include original request and headers in the data
        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        # SCENARIO 1: Batch ID is encoded with model info
        if model_from_id is not None:
            credentials = get_credentials_for_model(
                llm_router=llm_router,
                model_id=model_from_id,
                operation_context="batch cancellation (batch created with model)",
            )

            original_batch_id = get_original_file_id(batch_id)
            prepare_data_with_credentials(
                data=data,
                credentials=credentials,
                file_id=original_batch_id,
            )
            # Fix: The helper sets "file_id" but we need "batch_id"
            data["batch_id"] = data.pop("file_id", original_batch_id)

            # Cancel batch using model credentials
            response = await litellm.acancel_batch(
                custom_llm_provider=credentials["custom_llm_provider"],
                **data,  # type: ignore
            )

            encode_batch_response_ids(response, model=model_from_id)

            verbose_proxy_logger.debug(
                f"Cancelled batch using model: {model_from_id}, original_id: {original_batch_id}"
            )

        # SCENARIO 2: target_model_names based routing
        elif unified_batch_id:
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )

            # Hook has already extracted model and unwrapped batch_id into data dict
            response = await llm_router.acancel_batch(**data)  # type: ignore
            response._hidden_params["unified_batch_id"] = unified_batch_id

            # Ensure model_id is set for the post_call_success_hook to re-encode IDs
            if not response._hidden_params.get("model_id") and data.get("model"):
                response._hidden_params["model_id"] = data["model"]

        # SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
        else:
            custom_llm_provider = (
                provider or data.pop("custom_llm_provider", None) or "openai"
            )
            # Extract batch_id from data to avoid "multiple values for keyword argument" error
            # data was cast from CancelBatchRequest which already contains batch_id
            data.pop("batch_id", None)
            _cancel_batch_data = CancelBatchRequest(batch_id=batch_id, **data)
            response = await litellm.acancel_batch(
                custom_llm_provider=custom_llm_provider,  # type: ignore
                **_cancel_batch_data,
            )

        # FIX: Update the database with the new cancelled state
        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        from litellm.proxy.proxy_server import prisma_client

        await update_batch_in_database(
            batch_id=batch_id,
            unified_batch_id=unified_batch_id,
            response=response,
            managed_files_obj=managed_files_obj,
            prisma_client=prisma_client,
            verbose_proxy_logger=verbose_proxy_logger,
            operation="cancel",
        )

        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        ### ALERTING ###
        # Fire-and-forget: do not block the response on status bookkeeping.
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        # Fixes: previously logged the wrong function name ("create_batch")
        # and misspelled "occurred".
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.cancel_batch(): Exception occurred - {}".format(
                str(e)
            )
        )
        raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
######################################################################
|
||||
|
||||
# END OF /v1/batches Endpoints Implementation
|
||||
|
||||
######################################################################
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
@@ -0,0 +1,257 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import RedisCache
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
from litellm.proxy._types import ProxyErrorTypes, ProxyException
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.caching import CachePingResponse, HealthCheckCacheParams
|
||||
|
||||
masker = SensitiveDataMasker()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/cache",
|
||||
tags=["caching"],
|
||||
)
|
||||
|
||||
|
||||
def _extract_cache_params() -> Dict[str, Any]:
    """
    Return a masked snapshot of the active cache's configuration.

    The health check UI displays these parameters so users can confirm how
    their cache is set up, e.g.

    {
        "host": "localhost",
        "port": 6379,
        "redis_kwargs": {"db": 0},
        "namespace": "test",
    }

    Returns:
        Dict containing cleaned and masked cache parameters; empty dict when
        no cache is configured or its parameters cannot be read.
    """
    # No cache configured -> nothing to report.
    if litellm.cache is None:
        return {}
    try:
        raw_params = vars(litellm.cache.cache)
        if not raw_params:
            return masker.mask_dict({})
        # Validate/normalize through the pydantic model, then mask secrets.
        validated = HealthCheckCacheParams(**raw_params).model_dump()
        return masker.mask_dict(validated)
    except (AttributeError, TypeError) as e:
        # Best-effort: an unreadable cache object should not break health checks.
        verbose_proxy_logger.debug(f"Error extracting cache params: {str(e)}")
        return {}
|
||||
|
||||
|
||||
@router.get(
    "/ping",
    response_model=CachePingResponse,
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_ping():
    """
    Endpoint for checking if cache can be pinged.

    For a redis cache this performs a real round-trip: PING followed by a
    test write via ``async_add_cache``. For any other cache type only the
    configuration is reported (no connectivity check is performed).

    Returns:
        CachePingResponse: status, cache type, and masked cache parameters.

    Raises:
        ProxyException: with code 503 on any failure (including "cache not
            initialized"); the message bundles the masked cache params and a
            traceback for the health-check UI.
    """
    # Collected up-front so the except-block can include whatever was
    # gathered before the failure in its error payload.
    litellm_cache_params: Dict[str, Any] = {}
    cleaned_cache_params: Dict[str, Any] = {}
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )
        litellm_cache_params = masker.mask_dict(vars(litellm.cache))
        # remove field that might reference itself
        litellm_cache_params.pop("cache", None)
        cleaned_cache_params = _extract_cache_params()

        if litellm.cache.type == "redis":
            # Round-trip check: PING, then a test write.
            ping_response = await litellm.cache.ping()
            verbose_proxy_logger.debug(
                "/cache/ping: ping_response: " + str(ping_response)
            )
            # add cache does not return anything
            await litellm.cache.async_add_cache(
                result="test_key",
                model="test-model",
                messages=[{"role": "user", "content": "test from litellm"}],
            )
            verbose_proxy_logger.debug("/cache/ping: done with set_cache()")

            return CachePingResponse(
                status="healthy",
                cache_type=str(litellm.cache.type),
                ping_response=True,
                set_cache_response="success",
                litellm_cache_params=safe_dumps(litellm_cache_params),
                health_check_cache_params=cleaned_cache_params,
            )
        else:
            # Non-redis caches: report config only; no connectivity probe.
            return CachePingResponse(
                status="healthy",
                cache_type=str(litellm.cache.type),
                litellm_cache_params=safe_dumps(litellm_cache_params),
            )
    except Exception as e:
        import traceback

        # Bundle diagnostics (masked params + traceback) into one JSON
        # message the health-check UI can render.
        error_message = {
            "message": f"Service Unhealthy ({str(e)})",
            "litellm_cache_params": safe_dumps(litellm_cache_params),
            "health_check_cache_params": safe_dumps(cleaned_cache_params),
            "traceback": traceback.format_exc(),
        }
        raise ProxyException(
            message=safe_dumps(error_message),
            type=ProxyErrorTypes.cache_ping_error,
            param="cache_ping",
            code=503,
        )
|
||||
|
||||
|
||||
@router.post(
    "/delete",
    tags=["caching"],
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_delete(request: Request):
    """
    Endpoint for deleting a key from the cache. All responses from litellm proxy have `x-litellm-cache-key` in the headers

    Parameters:
    - **keys**: *Optional[List[str]]* - A list of keys to delete from the cache. Example {"keys": ["key1", "key2"]}

    ```shell
    curl -X POST "http://0.0.0.0:4000/cache/delete" \
        -H "Authorization: Bearer sk-1234" \
        -d '{"keys": ["key1", "key2"]}'
    ```

    Raises:
        HTTPException: 503 when no cache is initialized, 500 when the cache
            type does not support deletion or the delete operation fails.
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        request_data = await request.json()
        keys = request_data.get("keys", None)

        if litellm.cache.type == "redis":
            await litellm.cache.delete_cache_keys(keys=keys)
            return {
                "status": "success",
            }
        else:
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support deleting a key. only `redis` is supported",
            )
    except HTTPException:
        # Re-raise deliberate HTTP errors as-is (e.g. the 503 "cache not
        # initialized") instead of mangling their status code into the
        # generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Cache Delete Failed({str(e)})",
        )
|
||||
|
||||
|
||||
def _get_redis_client_info(cache_instance) -> Tuple[List, int]:
|
||||
"""
|
||||
Helper function to safely get Redis client list information.
|
||||
|
||||
Returns:
|
||||
tuple: (client_list, num_clients) where num_clients is -1 if CLIENT LIST is unavailable
|
||||
"""
|
||||
try:
|
||||
client_list = cache_instance.client_list()
|
||||
return client_list, len(client_list)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"CLIENT LIST command failed (likely restricted on managed Redis): {str(e)}"
|
||||
)
|
||||
return ["CLIENT LIST command not available on this Redis instance"], -1
|
||||
|
||||
|
||||
@router.get(
    "/redis/info",
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_redis_info():
    """
    Endpoint for getting /redis/info

    Returns Redis server info plus the connected client list. CLIENT LIST
    restrictions on managed Redis are handled gracefully (num_clients = -1).

    Raises:
        HTTPException: 503 when no cache is initialized or the info call
            fails, 500 when the cache type is not a redis cache.
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        if not (
            litellm.cache.type == "redis"
            and isinstance(litellm.cache.cache, RedisCache)
        ):
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support redis info",
            )

        # Get client information (handles CLIENT LIST restrictions gracefully)
        client_list, num_clients = _get_redis_client_info(litellm.cache.cache)

        # Get Redis server information
        redis_info = litellm.cache.cache.info()

        return {
            "num_clients": num_clients,
            "clients": client_list,
            "info": redis_info,
        }
    except HTTPException:
        # Preserve the deliberate 503/500 errors raised above instead of
        # rewrapping them into the generic 503 below (which would lose the
        # specific status code and detail message).
        raise
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)})",
        )
|
||||
|
||||
|
||||
@router.post(
    "/flushall",
    tags=["caching"],
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_flushall():
    """
    A function to flush all items from the cache. (All items will be deleted from the cache with this)
    Raises HTTPException if the cache is not initialized or if the cache type does not support flushing.
    Returns a dictionary with the status of the operation.

    Usage:
    ```
    curl -X POST http://0.0.0.0:4000/cache/flushall -H "Authorization: Bearer sk-1234"
    ```
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )
        if litellm.cache.type == "redis" and isinstance(
            litellm.cache.cache, RedisCache
        ):
            litellm.cache.cache.flushall()
            return {
                "status": "success",
            }
        else:
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support flushing",
            )
    except HTTPException:
        # Re-raise deliberate HTTP errors unchanged (503 "cache not
        # initialized", 500 "does not support flushing") rather than
        # converting them into the generic 503 below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)})",
        )
|
||||
@@ -0,0 +1,394 @@
|
||||
# LiteLLM Proxy Client
|
||||
|
||||
A Python client library for interacting with the LiteLLM proxy server. This client provides a clean, typed interface for managing models, keys, credentials, and making chat completions.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install litellm
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from litellm.proxy.client import Client
|
||||
|
||||
# Initialize the client
|
||||
client = Client(
|
||||
base_url="http://localhost:4000", # Your LiteLLM proxy server URL
|
||||
api_key="sk-api-key" # Optional: API key for authentication
|
||||
)
|
||||
|
||||
# Make a chat completion request
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
The client is organized into several resource clients for different functionality:
|
||||
|
||||
- `chat`: Chat completions
|
||||
- `models`: Model management
|
||||
- `model_groups`: Model group management
|
||||
- `keys`: API key management
|
||||
- `credentials`: Credential management
|
||||
- `users`: User management
|
||||
|
||||
## Chat Completions
|
||||
|
||||
Make chat completion requests to your LiteLLM proxy:
|
||||
|
||||
```python
|
||||
# Basic chat completion
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the capital of France?"}
|
||||
]
|
||||
)
|
||||
|
||||
# Stream responses
|
||||
for chunk in client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Tell me a story"}],
|
||||
stream=True
|
||||
):
|
||||
print(chunk.choices[0].delta.content or "", end="")
|
||||
```
|
||||
|
||||
## Model Management
|
||||
|
||||
Manage available models on your proxy:
|
||||
|
||||
```python
|
||||
# List available models
|
||||
models = client.models.list()
|
||||
|
||||
# Add a new model
|
||||
client.models.add(
|
||||
model_name="gpt-4",
|
||||
litellm_params={
|
||||
"api_key": "your-openai-key",
|
||||
"api_base": "https://api.openai.com/v1"
|
||||
}
|
||||
)
|
||||
|
||||
# Delete a model
|
||||
client.models.delete(model_name="gpt-4")
|
||||
```
|
||||
|
||||
## API Key Management
|
||||
|
||||
Manage virtual API keys:
|
||||
|
||||
```python
|
||||
# Generate a new API key
|
||||
key = client.keys.generate(
|
||||
models=["gpt-4", "gpt-3.5-turbo"],
|
||||
aliases={"gpt4": "gpt-4"},
|
||||
duration="24h",
|
||||
key_alias="my-key",
|
||||
team_id="team123"
|
||||
)
|
||||
|
||||
# List all keys
|
||||
keys = client.keys.list(
|
||||
page=1,
|
||||
size=10,
|
||||
return_full_object=True
|
||||
)
|
||||
|
||||
# Delete keys
|
||||
client.keys.delete(
|
||||
keys=["sk-key1", "sk-key2"],
|
||||
key_aliases=["alias1", "alias2"]
|
||||
)
|
||||
```
|
||||
|
||||
## Credential Management
|
||||
|
||||
Manage model credentials:
|
||||
|
||||
```python
|
||||
# Create new credentials
|
||||
client.credentials.create(
|
||||
credential_name="azure1",
|
||||
credential_info={"api_type": "azure"},
|
||||
credential_values={
|
||||
"api_key": "your-azure-key",
|
||||
"api_base": "https://example.azure.openai.com"
|
||||
}
|
||||
)
|
||||
|
||||
# List all credentials
|
||||
credentials = client.credentials.list()
|
||||
|
||||
# Get a specific credential
|
||||
credential = client.credentials.get(credential_name="azure1")
|
||||
|
||||
# Delete credentials
|
||||
client.credentials.delete(credential_name="azure1")
|
||||
```
|
||||
|
||||
## Model Groups
|
||||
|
||||
Manage model groups for load balancing and fallbacks:
|
||||
|
||||
```python
|
||||
# Create a model group
|
||||
client.model_groups.create(
|
||||
name="gpt4-group",
|
||||
models=[
|
||||
{"model_name": "gpt-4", "litellm_params": {"api_key": "key1"}},
|
||||
{"model_name": "gpt-4-backup", "litellm_params": {"api_key": "key2"}}
|
||||
]
|
||||
)
|
||||
|
||||
# List model groups
|
||||
groups = client.model_groups.list()
|
||||
|
||||
# Delete a model group
|
||||
client.model_groups.delete(name="gpt4-group")
|
||||
```
|
||||
|
||||
## Users Management
|
||||
|
||||
Manage users on your proxy:
|
||||
|
||||
```python
|
||||
from litellm.proxy.client import UsersManagementClient
|
||||
|
||||
users = UsersManagementClient(base_url="http://localhost:4000", api_key="sk-test")
|
||||
|
||||
# List users
|
||||
user_list = users.list_users()
|
||||
|
||||
# Get user info
|
||||
user_info = users.get_user(user_id="u1")
|
||||
|
||||
# Create a new user
|
||||
created = users.create_user({
|
||||
"user_email": "a@b.com",
|
||||
"user_role": "internal_user",
|
||||
"user_alias": "Alice",
|
||||
"teams": ["team1"],
|
||||
"max_budget": 100.0
|
||||
})
|
||||
|
||||
# Delete users
|
||||
users.delete_user(["u1", "u2"])
|
||||
```
|
||||
|
||||
## Low-Level HTTP Client
|
||||
|
||||
The client provides access to a low-level HTTP client for making direct requests
|
||||
to the LiteLLM proxy server. This is useful when you need more control or when
|
||||
working with endpoints that don't yet have a high-level interface.
|
||||
|
||||
```python
|
||||
# Access the HTTP client
|
||||
client = Client(
|
||||
base_url="http://localhost:4000",
|
||||
api_key="sk-api-key"
|
||||
)
|
||||
|
||||
# Make a custom request
|
||||
response = client.http.request(
|
||||
method="POST",
|
||||
uri="/health/test_connection",
|
||||
json={
|
||||
"litellm_params": {
|
||||
"model": "gpt-4",
|
||||
"api_key": "your-api-key",
|
||||
"api_base": "https://api.openai.com/v1"
|
||||
},
|
||||
"mode": "chat"
|
||||
}
|
||||
)
|
||||
|
||||
# The response is automatically parsed from JSON
|
||||
print(response)
|
||||
```
|
||||
|
||||
### HTTP Client Features
|
||||
|
||||
- Automatic URL handling (handles trailing/leading slashes)
|
||||
- Built-in authentication (adds Bearer token if `api_key` is provided)
|
||||
- JSON request/response handling
|
||||
- Configurable timeout (default: 30 seconds)
|
||||
- Comprehensive error handling
|
||||
- Support for custom headers and request parameters
|
||||
|
||||
### HTTP Client `request` method parameters
|
||||
|
||||
- `method`: HTTP method (GET, POST, PUT, DELETE, etc.)
|
||||
- `uri`: URI path (will be appended to base_url)
|
||||
- `data`: (optional) Data to send in the request body
|
||||
- `json`: (optional) JSON data to send in the request body
|
||||
- `headers`: (optional) Custom HTTP headers
|
||||
- Additional keyword arguments are passed to the underlying requests library
|
||||
|
||||
## Error Handling
|
||||
|
||||
The client provides clear error handling with custom exceptions:
|
||||
|
||||
```python
|
||||
from litellm.proxy.client.exceptions import UnauthorizedError
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
except UnauthorizedError as e:
|
||||
print("Authentication failed:", e)
|
||||
except Exception as e:
|
||||
print("Request failed:", e)
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Request Customization
|
||||
|
||||
All methods support returning the raw request object for inspection or modification:
|
||||
|
||||
```python
|
||||
# Get the prepared request without sending it
|
||||
request = client.models.list(return_request=True)
|
||||
print(request.method) # GET
|
||||
print(request.url)     # http://localhost:4000/models
|
||||
print(request.headers) # {'Content-Type': 'application/json', ...}
|
||||
```
|
||||
|
||||
### Pagination
|
||||
|
||||
Methods that return lists support pagination:
|
||||
|
||||
```python
|
||||
# Get the first page of keys
|
||||
page1 = client.keys.list(page=1, size=10)
|
||||
|
||||
# Get the second page
|
||||
page2 = client.keys.list(page=2, size=10)
|
||||
```
|
||||
|
||||
### Filtering
|
||||
|
||||
Many list methods support filtering:
|
||||
|
||||
```python
|
||||
# Filter keys by user and team
|
||||
keys = client.keys.list(
|
||||
user_id="user123",
|
||||
team_id="team456",
|
||||
include_team_keys=True
|
||||
)
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please check out our [contributing guidelines](../../CONTRIBUTING.md) for details.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the [LICENSE](../../LICENSE) file for details.
|
||||
|
||||
## CLI Authentication Flow
|
||||
|
||||
The LiteLLM CLI supports SSO authentication through a polling-based approach that works with any OAuth-compatible SSO provider.
|
||||
|
||||
### How CLI Authentication Works
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant CLI as CLI
|
||||
participant Browser as Browser
|
||||
participant Proxy as LiteLLM Proxy
|
||||
participant SSO as SSO Provider
|
||||
|
||||
CLI->>CLI: Generate key ID (sk-uuid)
|
||||
CLI->>Browser: Open /sso/key/generate?source=litellm-cli&key=sk-uuid
|
||||
|
||||
Browser->>Proxy: GET /sso/key/generate?source=litellm-cli&key=sk-uuid
|
||||
Proxy->>Proxy: Set cli_state = litellm-session-token:sk-uuid
|
||||
Proxy->>SSO: Redirect with state=litellm-session-token:sk-uuid
|
||||
|
||||
SSO->>Browser: Show login page
|
||||
Browser->>SSO: User authenticates
|
||||
SSO->>Proxy: Redirect to /sso/callback?state=litellm-session-token:sk-uuid
|
||||
|
||||
Proxy->>Proxy: Check if state starts with "litellm-session-token:"
|
||||
Proxy->>Proxy: Generate API key with ID=sk-uuid
|
||||
Proxy->>Browser: Show success page
|
||||
|
||||
CLI->>Proxy: Poll /sso/cli/poll/sk-uuid
|
||||
Proxy->>CLI: Return {"status": "ready", "key": "sk-uuid"}
|
||||
CLI->>CLI: Save key to ~/.litellm/token.json
|
||||
```
|
||||
|
||||
### Authentication Commands
|
||||
|
||||
The CLI provides three authentication commands:
|
||||
|
||||
- **`litellm-proxy login`** - Start SSO authentication flow
|
||||
- **`litellm-proxy logout`** - Clear stored authentication token
|
||||
- **`litellm-proxy whoami`** - Show current authentication status
|
||||
|
||||
### Authentication Flow Steps
|
||||
|
||||
1. **Generate Session ID**: CLI generates a unique key ID (`sk-{uuid}`)
|
||||
2. **Open Browser**: CLI opens browser to `/sso/key/generate` with CLI source and key parameters
|
||||
3. **SSO Redirect**: Proxy sets the formatted state (`litellm-session-token:sk-uuid`) as OAuth state parameter and redirects to SSO provider
|
||||
4. **User Authentication**: User completes SSO authentication in browser
|
||||
5. **Callback Processing**: SSO provider redirects back to proxy with state parameter
|
||||
6. **Key Generation**: Proxy detects CLI login (state starts with "litellm-session-token:") and generates API key with pre-specified ID
|
||||
7. **Polling**: CLI polls `/sso/cli/poll/{key_id}` endpoint until key is ready
|
||||
8. **Token Storage**: CLI saves the authentication token to `~/.litellm/token.json`
|
||||
|
||||
### Benefits of This Approach
|
||||
|
||||
- **No Local Server**: No need to run a local callback server
|
||||
- **Standard OAuth**: Uses OAuth 2.0 state parameter correctly
|
||||
- **Remote Compatible**: Works with remote proxy servers
|
||||
- **Secure**: Uses UUID session identifiers
|
||||
- **Simple Setup**: No additional OAuth redirect URL configuration needed
|
||||
|
||||
### Token Storage
|
||||
|
||||
Authentication tokens are stored in `~/.litellm/token.json` with restricted file permissions (600). The stored token includes:
|
||||
|
||||
```json
|
||||
{
|
||||
"key": "sk-...",
|
||||
"user_id": "cli-user",
|
||||
"user_email": "user@example.com",
|
||||
"user_role": "cli",
|
||||
"auth_header_name": "Authorization",
|
||||
"timestamp": 1234567890
|
||||
}
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
Once authenticated, the CLI will automatically use the stored token for all requests. You no longer need to specify `--api-key` for subsequent commands.
|
||||
|
||||
```bash
|
||||
# Login
|
||||
litellm-proxy login
|
||||
|
||||
# Use CLI without specifying API key
|
||||
litellm-proxy models list
|
||||
|
||||
# Check authentication status
|
||||
litellm-proxy whoami
|
||||
|
||||
# Logout
|
||||
litellm-proxy logout
|
||||
```
|
||||
@@ -0,0 +1,17 @@
|
||||
from .client import Client
|
||||
from .chat import ChatClient
|
||||
from .models import ModelsManagementClient
|
||||
from .model_groups import ModelGroupsManagementClient
|
||||
from .exceptions import UnauthorizedError
|
||||
from .users import UsersManagementClient
|
||||
from .health import HealthManagementClient
|
||||
|
||||
__all__ = [
|
||||
"Client",
|
||||
"ChatClient",
|
||||
"ModelsManagementClient",
|
||||
"ModelGroupsManagementClient",
|
||||
"UsersManagementClient",
|
||||
"UnauthorizedError",
|
||||
"HealthManagementClient",
|
||||
]
|
||||
@@ -0,0 +1,185 @@
|
||||
import json
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from .exceptions import UnauthorizedError
|
||||
|
||||
|
||||
class ChatClient:
    """Client for the LiteLLM proxy's OpenAI-compatible chat completion routes.

    Supports both blocking (`completions`) and streaming
    (`completions_stream`) requests against ``{base_url}/chat/completions``.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the ChatClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:4000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # Remove trailing slash if present
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """
        Get the headers for API requests, including authorization if api_key is set.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    @staticmethod
    def _build_completion_payload(
        model: str,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        user: Optional[str] = None,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """Build the JSON body for a /chat/completions request.

        Only parameters that were explicitly provided (not None) are
        included, so the proxy's own defaults apply otherwise. Shared by
        `completions` and `completions_stream` to avoid duplicating the
        optional-parameter handling.
        """
        data: Dict[str, Any] = {"model": model, "messages": messages}
        if stream:
            data["stream"] = True
        optional_params = {
            "temperature": temperature,
            "top_p": top_p,
            "n": n,
            "max_tokens": max_tokens,
            "presence_penalty": presence_penalty,
            "frequency_penalty": frequency_penalty,
            "user": user,
        }
        for key, value in optional_params.items():
            if value is not None:
                data[key] = value
        return data

    def completions(
        self,
        model: str,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        user: Optional[str] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Create a chat completion.

        Args:
            model (str): The model to use for completion
            messages (List[Dict[str, str]]): The messages to generate a completion for
            temperature (Optional[float]): Sampling temperature between 0 and 2
            top_p (Optional[float]): Nucleus sampling parameter between 0 and 1
            n (Optional[int]): Number of completions to generate
            max_tokens (Optional[int]): Maximum number of tokens to generate
            presence_penalty (Optional[float]): Presence penalty between -2.0 and 2.0
            frequency_penalty (Optional[float]): Frequency penalty between -2.0 and 2.0
            user (Optional[str]): Unique identifier for the end user
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the completion response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/chat/completions"
        data = self._build_completion_payload(
            model=model,
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            user=user,
        )

        request = requests.Request("POST", url, headers=self._get_headers(), json=data)

        if return_request:
            return request

        # Session is context-managed so its connection pool is released
        # even when the request raises.
        with requests.Session() as session:
            try:
                response = session.send(request.prepare())
                response.raise_for_status()
                return response.json()
            except requests.exceptions.HTTPError as e:
                # Surface auth failures as a dedicated exception type.
                if e.response.status_code == 401:
                    raise UnauthorizedError(e)
                raise

    def completions_stream(
        self,
        model: str,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        user: Optional[str] = None,
    ) -> Iterator[Dict[str, Any]]:
        """
        Create a streaming chat completion.

        Args:
            model (str): The model to use for completion
            messages (List[Dict[str, str]]): The messages to generate a completion for
            temperature (Optional[float]): Sampling temperature between 0 and 2
            top_p (Optional[float]): Nucleus sampling parameter between 0 and 1
            n (Optional[int]): Number of completions to generate
            max_tokens (Optional[int]): Maximum number of tokens to generate
            presence_penalty (Optional[float]): Presence penalty between -2.0 and 2.0
            frequency_penalty (Optional[float]): Frequency penalty between -2.0 and 2.0
            user (Optional[str]): Unique identifier for the end user

        Yields:
            Dict[str, Any]: Streaming response chunks from the server

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/chat/completions"
        data = self._build_completion_payload(
            model=model,
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            user=user,
            stream=True,
        )

        with requests.Session() as session:
            try:
                response = session.post(
                    url, headers=self._get_headers(), json=data, stream=True
                )
                response.raise_for_status()

                # Parse the server-sent-events stream: each payload line is
                # prefixed with "data: "; "[DONE]" terminates the stream.
                for line in response.iter_lines():
                    if line:
                        line = line.decode("utf-8")
                        if line.startswith("data: "):
                            data_str = line[6:]  # Remove 'data: ' prefix
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                chunk = json.loads(data_str)
                                yield chunk
                            except json.JSONDecodeError:
                                # Skip malformed keep-alive/partial lines.
                                continue

            except requests.exceptions.HTTPError as e:
                if e.response.status_code == 401:
                    raise UnauthorizedError(e)
                raise
|
||||
@@ -0,0 +1,536 @@
|
||||
# LiteLLM Proxy CLI
|
||||
|
||||
The LiteLLM Proxy CLI is a command-line tool for managing your LiteLLM proxy server. It provides commands for managing models, viewing server status, and interacting with the proxy server.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install 'litellm[proxy]'
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The CLI can be configured using environment variables or command-line options:
|
||||
|
||||
- `LITELLM_PROXY_URL`: Base URL of the LiteLLM proxy server (default: http://localhost:4000)
|
||||
- `LITELLM_PROXY_API_KEY`: API key for authentication
|
||||
|
||||
## Global Options
|
||||
|
||||
- `--version`, `-v`: Print the LiteLLM Proxy client and server version and exit.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
litellm-proxy version
|
||||
# or
|
||||
litellm-proxy --version
|
||||
# or
|
||||
litellm-proxy -v
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
### Models Management
|
||||
|
||||
The CLI provides several commands for managing models on your LiteLLM proxy server:
|
||||
|
||||
#### List Models
|
||||
|
||||
View all available models:
|
||||
|
||||
```bash
|
||||
litellm-proxy models list [--format table|json]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--format`: Output format (table or json, default: table)
|
||||
|
||||
#### Model Information
|
||||
|
||||
Get detailed information about all models:
|
||||
|
||||
```bash
|
||||
litellm-proxy models info [options]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--format`: Output format (table or json, default: table)
|
||||
- `--columns`: Comma-separated list of columns to display. Valid columns:
|
||||
- `public_model`
|
||||
- `upstream_model`
|
||||
- `credential_name`
|
||||
- `created_at`
|
||||
- `updated_at`
|
||||
- `id`
|
||||
- `input_cost`
|
||||
- `output_cost`
|
||||
|
||||
Default columns: `public_model`, `upstream_model`, `updated_at`
|
||||
|
||||
#### Add Model
|
||||
|
||||
Add a new model to the proxy:
|
||||
|
||||
```bash
|
||||
litellm-proxy models add <model-name> [options]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--param`, `-p`: Model parameters in key=value format (can be specified multiple times)
|
||||
- `--info`, `-i`: Model info in key=value format (can be specified multiple times)
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
litellm-proxy models add gpt-4 -p api_key=sk-123 -p api_base=https://api.openai.com -i description="GPT-4 model"
|
||||
```
|
||||
|
||||
#### Get Model Info
|
||||
|
||||
Get information about a specific model:
|
||||
|
||||
```bash
|
||||
litellm-proxy models get [--id MODEL_ID] [--name MODEL_NAME]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--id`: ID of the model to retrieve
|
||||
- `--name`: Name of the model to retrieve
|
||||
|
||||
#### Delete Model
|
||||
|
||||
Delete a model from the proxy:
|
||||
|
||||
```bash
|
||||
litellm-proxy models delete <model-id>
|
||||
```
|
||||
|
||||
#### Update Model
|
||||
|
||||
Update an existing model's configuration:
|
||||
|
||||
```bash
|
||||
litellm-proxy models update <model-id> [options]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--param`, `-p`: Model parameters in key=value format (can be specified multiple times)
|
||||
- `--info`, `-i`: Model info in key=value format (can be specified multiple times)
|
||||
|
||||
#### Import Models
|
||||
|
||||
Import models from a YAML file:
|
||||
|
||||
```bash
|
||||
litellm-proxy models import models.yaml
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--dry-run`: Show what would be imported without making any changes.
|
||||
- `--only-models-matching-regex <regex>`: Only import models where `litellm_params.model` matches the given regex.
|
||||
- `--only-access-groups-matching-regex <regex>`: Only import models where at least one item in `model_info.access_groups` matches the given regex.
|
||||
|
||||
Examples:
|
||||
|
||||
1. Import all models from a YAML file:
|
||||
|
||||
```bash
|
||||
litellm-proxy models import models.yaml
|
||||
```
|
||||
|
||||
2. Dry run (show what would be imported):
|
||||
|
||||
```bash
|
||||
litellm-proxy models import models.yaml --dry-run
|
||||
```
|
||||
|
||||
3. Only import models where the model name contains 'gpt':
|
||||
|
||||
```bash
|
||||
litellm-proxy models import models.yaml --only-models-matching-regex gpt
|
||||
```
|
||||
|
||||
4. Only import models with access group containing 'beta':
|
||||
|
||||
```bash
|
||||
litellm-proxy models import models.yaml --only-access-groups-matching-regex beta
|
||||
```
|
||||
|
||||
5. Combine both filters:
|
||||
|
||||
```bash
|
||||
litellm-proxy models import models.yaml --only-models-matching-regex gpt --only-access-groups-matching-regex beta
|
||||
```
|
||||
|
||||
### Credentials Management
|
||||
|
||||
The CLI provides commands for managing credentials on your LiteLLM proxy server:
|
||||
|
||||
#### List Credentials
|
||||
|
||||
View all available credentials:
|
||||
|
||||
```bash
|
||||
litellm-proxy credentials list [--format table|json]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--format`: Output format (table or json, default: table)
|
||||
|
||||
The table format displays:
|
||||
- Credential Name
|
||||
- Custom LLM Provider
|
||||
|
||||
#### Create Credential
|
||||
|
||||
Create a new credential:
|
||||
|
||||
```bash
|
||||
litellm-proxy credentials create <credential-name> --info <json-string> --values <json-string>
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--info`: JSON string containing credential info (e.g., custom_llm_provider)
|
||||
- `--values`: JSON string containing credential values (e.g., api_key)
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
litellm-proxy credentials create azure-cred \
|
||||
--info '{"custom_llm_provider": "azure"}' \
|
||||
--values '{"api_key": "sk-123", "api_base": "https://example.azure.openai.com"}'
|
||||
```
|
||||
|
||||
#### Get Credential
|
||||
|
||||
Get information about a specific credential:
|
||||
|
||||
```bash
|
||||
litellm-proxy credentials get <credential-name>
|
||||
```
|
||||
|
||||
#### Delete Credential
|
||||
|
||||
Delete a credential:
|
||||
|
||||
```bash
|
||||
litellm-proxy credentials delete <credential-name>
|
||||
```
|
||||
|
||||
### Keys Management
|
||||
|
||||
The CLI provides commands for managing API keys on your LiteLLM proxy server:
|
||||
|
||||
#### List Keys
|
||||
|
||||
View all API keys:
|
||||
|
||||
```bash
|
||||
litellm-proxy keys list [--format table|json] [options]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--format`: Output format (table or json, default: table)
|
||||
- `--page`: Page number for pagination
|
||||
- `--size`: Number of items per page
|
||||
- `--user-id`: Filter keys by user ID
|
||||
- `--team-id`: Filter keys by team ID
|
||||
- `--organization-id`: Filter keys by organization ID
|
||||
- `--key-hash`: Filter by specific key hash
|
||||
- `--key-alias`: Filter by key alias
|
||||
- `--return-full-object`: Return the full key object
|
||||
- `--include-team-keys`: Include team keys in the response
|
||||
|
||||
#### Generate Key
|
||||
|
||||
Generate a new API key:
|
||||
|
||||
```bash
|
||||
litellm-proxy keys generate [options]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--models`: Comma-separated list of allowed models
|
||||
- `--aliases`: JSON string of model alias mappings
|
||||
- `--spend`: Maximum spend limit for this key
|
||||
- `--duration`: Duration for which the key is valid (e.g. '24h', '7d')
|
||||
- `--key-alias`: Alias/name for the key
|
||||
- `--team-id`: Team ID to associate the key with
|
||||
- `--user-id`: User ID to associate the key with
|
||||
- `--budget-id`: Budget ID to associate the key with
|
||||
- `--config`: JSON string of additional configuration parameters
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
litellm-proxy keys generate --models gpt-4,gpt-3.5-turbo --spend 100 --duration 24h --key-alias my-key --team-id team123
|
||||
```
|
||||
|
||||
#### Delete Keys
|
||||
|
||||
Delete API keys by key or alias:
|
||||
|
||||
```bash
|
||||
litellm-proxy keys delete [--keys <comma-separated-keys>] [--key-aliases <comma-separated-aliases>]
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--keys`: Comma-separated list of API keys to delete
|
||||
- `--key-aliases`: Comma-separated list of key aliases to delete
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
litellm-proxy keys delete --keys sk-key1,sk-key2 --key-aliases alias1,alias2
|
||||
```
|
||||
|
||||
#### Get Key Info
|
||||
|
||||
Get information about a specific API key:
|
||||
|
||||
```bash
|
||||
litellm-proxy keys info --key <key-hash>
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `--key`: The key hash to get information about
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
litellm-proxy keys info --key sk-key1
|
||||
```
|
||||
|
||||
### User Management
|
||||
|
||||
The CLI provides commands for managing users on your LiteLLM proxy server:
|
||||
|
||||
#### List Users
|
||||
|
||||
View all users:
|
||||
|
||||
```bash
|
||||
litellm-proxy users list
|
||||
```
|
||||
|
||||
#### Get User Info
|
||||
|
||||
Get information about a specific user:
|
||||
|
||||
```bash
|
||||
litellm-proxy users get --id <user-id>
|
||||
```
|
||||
|
||||
#### Create User
|
||||
|
||||
Create a new user:
|
||||
|
||||
```bash
|
||||
litellm-proxy users create --email user@example.com --role internal_user --alias "Alice" --team team1 --max-budget 100.0
|
||||
```
|
||||
|
||||
#### Delete User
|
||||
|
||||
Delete one or more users by user_id:
|
||||
|
||||
```bash
|
||||
litellm-proxy users delete <user-id-1> <user-id-2>
|
||||
```
|
||||
|
||||
### Chat Commands
|
||||
|
||||
The CLI provides commands for interacting with chat models through your LiteLLM proxy server:
|
||||
|
||||
#### Chat Completions
|
||||
|
||||
Create a chat completion:
|
||||
|
||||
```bash
|
||||
litellm-proxy chat completions <model> [options]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
- `model`: The model to use (e.g., gpt-4, claude-2)
|
||||
|
||||
Options:
|
||||
- `--message`, `-m`: Messages in 'role:content' format. Can be specified multiple times to create a conversation.
|
||||
- `--temperature`, `-t`: Sampling temperature between 0 and 2
|
||||
- `--top-p`: Nucleus sampling parameter between 0 and 1
|
||||
- `--n`: Number of completions to generate
|
||||
- `--max-tokens`: Maximum number of tokens to generate
|
||||
- `--presence-penalty`: Presence penalty between -2.0 and 2.0
|
||||
- `--frequency-penalty`: Frequency penalty between -2.0 and 2.0
|
||||
- `--user`: Unique identifier for the end user
|
||||
|
||||
Examples:
|
||||
|
||||
1. Simple completion:
|
||||
```bash
|
||||
litellm-proxy chat completions gpt-4 -m "user:Hello, how are you?"
|
||||
```
|
||||
|
||||
2. Multi-message conversation:
|
||||
```bash
|
||||
litellm-proxy chat completions gpt-4 \
|
||||
-m "system:You are a helpful assistant" \
|
||||
-m "user:What's the capital of France?" \
|
||||
-m "assistant:The capital of France is Paris." \
|
||||
-m "user:What's its population?"
|
||||
```
|
||||
|
||||
3. With generation parameters:
|
||||
```bash
|
||||
litellm-proxy chat completions gpt-4 \
|
||||
-m "user:Write a story" \
|
||||
--temperature 0.7 \
|
||||
--max-tokens 500 \
|
||||
--top-p 0.9
|
||||
```
|
||||
|
||||
### HTTP Commands
|
||||
|
||||
The CLI provides commands for making direct HTTP requests to your LiteLLM proxy server:
|
||||
|
||||
#### Make HTTP Request
|
||||
|
||||
Make an HTTP request to any endpoint:
|
||||
|
||||
```bash
|
||||
litellm-proxy http request <method> <uri> [options]
|
||||
```
|
||||
|
||||
Arguments:
|
||||
- `method`: HTTP method (GET, POST, PUT, DELETE, etc.)
|
||||
- `uri`: URI path (will be appended to base_url)
|
||||
|
||||
Options:
|
||||
- `--data`, `-d`: Data to send in the request body (as JSON string)
|
||||
- `--json`, `-j`: JSON data to send in the request body (as JSON string)
|
||||
- `--header`, `-H`: HTTP headers in 'key:value' format. Can be specified multiple times.
|
||||
|
||||
Examples:
|
||||
|
||||
1. List models:
|
||||
```bash
|
||||
litellm-proxy http request GET /models
|
||||
```
|
||||
|
||||
2. Create a chat completion:
|
||||
```bash
|
||||
litellm-proxy http request POST /chat/completions -j '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}'
|
||||
```
|
||||
|
||||
3. Test connection with custom headers:
|
||||
```bash
|
||||
litellm-proxy http request GET /health/test_connection -H "X-Custom-Header:value"
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
The CLI respects the following environment variables:
|
||||
|
||||
- `LITELLM_PROXY_URL`: Base URL of the proxy server
|
||||
- `LITELLM_PROXY_API_KEY`: API key for authentication
|
||||
|
||||
## Examples
|
||||
|
||||
1. List all models in table format:
|
||||
|
||||
```bash
|
||||
litellm-proxy models list
|
||||
```
|
||||
|
||||
2. Add a new model with parameters:
|
||||
|
||||
```bash
|
||||
litellm-proxy models add gpt-4 -p api_key=sk-123 -p max_tokens=2048
|
||||
```
|
||||
|
||||
3. Get model information in JSON format:
|
||||
|
||||
```bash
|
||||
litellm-proxy models info --format json
|
||||
```
|
||||
|
||||
4. Update model parameters:
|
||||
|
||||
```bash
|
||||
litellm-proxy models update model-123 -p temperature=0.7 -i description="Updated model"
|
||||
```
|
||||
|
||||
5. List all credentials in table format:
|
||||
|
||||
```bash
|
||||
litellm-proxy credentials list
|
||||
```
|
||||
|
||||
6. Create a new credential for Azure:
|
||||
|
||||
```bash
|
||||
litellm-proxy credentials create azure-prod \
|
||||
--info '{"custom_llm_provider": "azure"}' \
|
||||
--values '{"api_key": "sk-123", "api_base": "https://prod.azure.openai.com"}'
|
||||
```
|
||||
|
||||
7. Make a custom HTTP request:
|
||||
|
||||
```bash
|
||||
litellm-proxy http request POST /chat/completions \
|
||||
-j '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}' \
|
||||
-H "X-Custom-Header:value"
|
||||
```
|
||||
|
||||
8. User management:
|
||||
|
||||
```bash
|
||||
# List users
|
||||
litellm-proxy users list
|
||||
|
||||
# Get user info
|
||||
litellm-proxy users get --id u1
|
||||
|
||||
# Create a user
|
||||
litellm-proxy users create --email a@b.com --role internal_user --alias "Alice" --team team1 --max-budget 100.0
|
||||
|
||||
# Delete users
|
||||
litellm-proxy users delete u1 u2
|
||||
```
|
||||
|
||||
9. Import models from a YAML file (with filters):
|
||||
|
||||
```bash
|
||||
# Only import models where the model name contains 'gpt'
|
||||
litellm-proxy models import models.yaml --only-models-matching-regex gpt
|
||||
|
||||
# Only import models with access group containing 'beta'
|
||||
litellm-proxy models import models.yaml --only-access-groups-matching-regex beta
|
||||
|
||||
# Combine both filters
|
||||
litellm-proxy models import models.yaml --only-models-matching-regex gpt --only-access-groups-matching-regex beta
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The CLI will display appropriate error messages when:
|
||||
|
||||
- The proxy server is not accessible
|
||||
- Authentication fails
|
||||
- Invalid parameters are provided
|
||||
- The requested model or credential doesn't exist
|
||||
- Invalid JSON is provided for credential creation
|
||||
- Any other operation fails
|
||||
|
||||
For detailed debugging, use the `--debug` flag with any command.
|
||||
@@ -0,0 +1,5 @@
|
||||
"""CLI package for LiteLLM Proxy Client."""
|
||||
|
||||
from .main import cli
|
||||
|
||||
__all__ = ["cli"]
|
||||
@@ -0,0 +1 @@
|
||||
"""Command groups for the LiteLLM proxy CLI."""
|
||||
@@ -0,0 +1,623 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import click
|
||||
import requests
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from litellm.constants import CLI_JWT_EXPIRATION_HOURS
|
||||
|
||||
|
||||
# Token storage utilities
|
||||
def get_token_file_path() -> str:
    """Return the path used to persist the CLI authentication token.

    Ensures ``~/.litellm`` exists before returning the token file path.
    """
    config_dir = Path.home() / ".litellm"
    config_dir.mkdir(exist_ok=True)
    return str(config_dir / "token.json")
|
||||
|
||||
|
||||
def save_token(token_data: Dict[str, Any]) -> None:
    """Persist *token_data* as pretty-printed JSON in the token file."""
    path = get_token_file_path()
    with open(path, "w") as fh:
        json.dump(token_data, fh, indent=2)
    # Restrict permissions so other local users cannot read the credential.
    os.chmod(path, 0o600)
|
||||
|
||||
|
||||
def load_token() -> Optional[Dict[str, Any]]:
    """Read and return the stored token data.

    Returns ``None`` when the token file is missing, unreadable, or does not
    contain valid JSON.
    """
    try:
        with open(get_token_file_path(), "r") as fh:
            return json.load(fh)
    # IOError (== OSError) also covers the file-not-found case, so no
    # separate existence check is needed.
    except (json.JSONDecodeError, IOError):
        return None
|
||||
|
||||
|
||||
def clear_token() -> None:
    """Delete the stored token file, if one exists."""
    try:
        os.remove(get_token_file_path())
    except FileNotFoundError:
        # Nothing stored — already "cleared".
        pass
|
||||
|
||||
|
||||
def get_stored_api_key() -> Optional[str]:
    """Return the stored API key via the SDK-level token utility."""
    # Imported lazily so the module can load without the SDK helper resolved.
    from litellm.litellm_core_utils.cli_token_utils import get_litellm_gateway_api_key

    return get_litellm_gateway_api_key()
|
||||
|
||||
|
||||
# Team selection utilities
|
||||
def display_teams_table(teams: List[Dict[str, Any]]) -> None:
    """Render *teams* as a Rich table (index, alias, id, models, budget)."""
    console = Console()

    if not teams:
        console.print("❌ No teams found for your user.")
        return

    table = Table(title="Available Teams")
    table.add_column("Index", style="cyan", no_wrap=True)
    table.add_column("Team Alias", style="magenta")
    table.add_column("Team ID", style="green")
    table.add_column("Models", style="yellow")
    table.add_column("Max Budget", style="blue")

    for position, team in enumerate(teams, start=1):
        models = team.get("models", [])

        # Show at most three model names, summarizing the remainder;
        # an empty list means the team is unrestricted.
        if not models:
            models_str = "All models"
        elif len(models) > 3:
            models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
        else:
            models_str = ", ".join(models)

        max_budget = team.get("max_budget")
        budget_str = f"${max_budget}" if max_budget else "Unlimited"

        table.add_row(
            str(position),
            team.get("team_alias") or "N/A",
            team.get("team_id", "N/A"),
            models_str,
            budget_str,
        )

    console.print(table)
|
||||
|
||||
|
||||
def get_key_input():
    """Get a single key input from the user (cross-platform).

    Returns one of ``"up"``, ``"down"``, ``"enter"``, ``"escape"``,
    ``"quit"``, or ``None`` when the key is unrecognized or raw terminal
    input is unavailable.
    """
    try:
        if sys.platform == "win32":
            import msvcrt

            key = msvcrt.getch()
            if key == b"\xe0":  # Arrow keys on Windows
                # Arrow keys arrive as a two-byte sequence; read the second byte.
                key = msvcrt.getch()
                if key == b"H":  # Up arrow
                    return "up"
                elif key == b"P":  # Down arrow
                    return "down"
            elif key == b"\r":  # Enter key
                return "enter"
            elif key == b"\x1b":  # Escape key
                return "escape"
            elif key == b"q":
                return "quit"
            return None
        else:
            import termios
            import tty

            fd = sys.stdin.fileno()
            old_settings = termios.tcgetattr(fd)
            try:
                # Raw mode so single keypresses are delivered without Enter.
                tty.setraw(sys.stdin.fileno())
                key = sys.stdin.read(1)

                if key == "\x1b":  # Escape sequence
                    # NOTE(review): read(2) appears to assume two more bytes
                    # follow ESC (as for arrow keys); a bare ESC press with no
                    # trailing bytes may block here — confirm acceptable.
                    key += sys.stdin.read(2)
                    if key == "\x1b[A":  # Up arrow
                        return "up"
                    elif key == "\x1b[B":  # Down arrow
                        return "down"
                    elif key == "\x1b":  # Just escape
                        return "escape"
                elif key == "\r" or key == "\n":  # Enter key
                    return "enter"
                elif key == "q":
                    return "quit"
                return None
            finally:
                # Always restore the terminal to its previous mode.
                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    except ImportError:
        # Fallback to simple input if termios/msvcrt not available
        return None
|
||||
|
||||
|
||||
def display_interactive_team_selection(
    teams: List[Dict[str, Any]], selected_index: int = 0
) -> None:
    """Redraw the team list with the entry at *selected_index* highlighted."""
    console = Console()

    # Full redraw on every keypress keeps the highlight in sync.
    console.clear()
    console.print("🎯 Select a Team (Use ↑↓ arrows, Enter to select, 'q' to skip):\n")

    for idx, team in enumerate(teams):
        alias = team.get("team_alias") or "N/A"
        tid = team.get("team_id", "N/A")
        models = team.get("models", [])
        max_budget = team.get("max_budget")

        # Summarize long model lists; empty list means unrestricted access.
        if not models:
            models_str = "All models"
        elif len(models) > 3:
            models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
        else:
            models_str = ", ".join(models)

        budget_str = f"${max_budget}" if max_budget else "Unlimited"

        if idx == selected_index:
            console.print(f"➤ [bold cyan]{alias}[/bold cyan] ({tid})")
            console.print(f"   Models: [yellow]{models_str}[/yellow]")
            console.print(f"   Budget: [blue]{budget_str}[/blue]\n")
        else:
            console.print(f"  [dim]{alias}[/dim] ({tid})")
            console.print(f"   Models: [dim]{models_str}[/dim]")
            console.print(f"   Budget: [dim]{budget_str}[/dim]\n")
|
||||
|
||||
|
||||
def prompt_team_selection(teams: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Interactively pick a team with arrow keys, with a numeric fallback.

    Returns the selected team dict, or ``None`` when the selection is
    skipped or cancelled.
    """
    if not teams:
        return None

    selected_index = 0

    try:
        # Pipes/CI have no TTY, so raw keypress navigation cannot work there.
        if not sys.stdin.isatty():
            return prompt_team_selection_fallback(teams)

        while True:
            display_interactive_team_selection(teams, selected_index)
            key = get_key_input()

            if key == "up":
                selected_index = (selected_index - 1) % len(teams)
            elif key == "down":
                selected_index = (selected_index + 1) % len(teams)
            elif key == "enter":
                chosen = teams[selected_index]
                Console().clear()
                click.echo(
                    f"✅ Selected team: {chosen.get('team_alias', 'N/A')} ({chosen.get('team_id')})"
                )
                return chosen
            elif key in ("quit", "escape"):
                Console().clear()
                click.echo("ℹ️ Team selection skipped.")
                return None
            elif key is None:
                # Raw key input unavailable — degrade to the simple prompt.
                return prompt_team_selection_fallback(teams)

    except KeyboardInterrupt:
        Console().clear()
        click.echo("\n❌ Team selection cancelled.")
        return None
    except Exception:
        # Any rendering/terminal failure: degrade rather than crash.
        return prompt_team_selection_fallback(teams)
|
||||
|
||||
|
||||
def prompt_team_selection_fallback(
    teams: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Numeric-prompt team selection for non-interactive environments."""
    if not teams:
        return None

    while True:
        try:
            raw = click.prompt(
                "\nSelect a team by entering the index number (or 'skip' to continue without a team)",
                type=str,
            ).strip()
        except KeyboardInterrupt:
            click.echo("\n❌ Team selection cancelled.")
            return None

        if raw.lower() == "skip":
            return None

        try:
            index = int(raw) - 1
        except ValueError:
            click.echo("❌ Invalid input. Please enter a number or 'skip'")
            continue

        if 0 <= index < len(teams):
            chosen = teams[index]
            click.echo(
                f"\n✅ Selected team: {chosen.get('team_alias', 'N/A')} ({chosen.get('team_id')})"
            )
            return chosen

        click.echo(
            f"❌ Invalid selection. Please enter a number between 1 and {len(teams)}"
        )
|
||||
|
||||
|
||||
# Polling-based authentication - no local server needed
|
||||
def _poll_for_ready_data(
    url: str,
    *,
    total_timeout: int = 300,
    poll_interval: int = 2,
    request_timeout: int = 10,
    pending_message: Optional[str] = None,
    pending_log_every: int = 10,
    other_status_message: Optional[str] = None,
    other_status_log_every: int = 10,
    http_error_log_every: int = 10,
    connection_error_log_every: int = 10,
) -> Optional[Dict[str, Any]]:
    """Poll *url* until it reports ``{"status": "ready"}`` or time runs out.

    Progress/error messages are throttled: a ``*_log_every`` value of N means
    "log on every Nth attempt" (0 disables that message). Returns the ready
    payload dict, or ``None`` on timeout.
    """

    def _should_log(every: int, attempt: int) -> bool:
        # Throttle: only log on every Nth attempt; 0 disables logging.
        return every > 0 and attempt % every == 0

    for attempt in range(total_timeout // poll_interval):
        try:
            response = requests.get(url, timeout=request_timeout)
            if response.status_code == 200:
                data = response.json()
                status = data.get("status")
                if status == "ready":
                    return data
                if status == "pending":
                    if pending_message and _should_log(pending_log_every, attempt):
                        click.echo(pending_message)
                elif other_status_message and _should_log(
                    other_status_log_every, attempt
                ):
                    click.echo(other_status_message)
            elif _should_log(http_error_log_every, attempt):
                click.echo(f"Polling error: HTTP {response.status_code}")
        except requests.RequestException as e:
            # Transient network failures are expected while the user completes
            # the browser flow — log occasionally and keep retrying.
            if _should_log(connection_error_log_every, attempt):
                click.echo(f"Connection error (will retry): {e}")
        time.sleep(poll_interval)
    return None
|
||||
|
||||
|
||||
def _normalize_teams(teams, team_details):
|
||||
"""If team_details are a
|
||||
|
||||
Args:
|
||||
teams (_type_): _description_
|
||||
team_details (_type_): _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
if isinstance(team_details, list) and team_details:
|
||||
return [
|
||||
{
|
||||
"team_id": i.get("team_id") or i.get("id"),
|
||||
"team_alias": i.get("team_alias"),
|
||||
}
|
||||
for i in team_details
|
||||
if isinstance(i, dict) and (i.get("team_id") or i.get("id"))
|
||||
]
|
||||
if isinstance(teams, list):
|
||||
return [{"team_id": str(t), "team_alias": None} for t in teams]
|
||||
return []
|
||||
|
||||
|
||||
def _poll_for_authentication(base_url: str, key_id: str) -> Optional[dict]:
    """
    Poll the server for authentication completion and handle team selection.

    Returns:
        Dictionary with authentication data if successful, None otherwise
    """
    data = _poll_for_ready_data(
        f"{base_url}/sso/cli/poll/{key_id}",
        pending_message="Still waiting for authentication...",
    )
    if not data:
        return None

    if data.get("requires_team_selection"):
        teams = data.get("teams", [])
        user_id = data.get("user_id")
        normalized_teams: List[Dict[str, Any]] = _normalize_teams(
            teams, data.get("team_details")
        )
        if not normalized_teams:
            click.echo("⚠️ No teams available for selection.")
            return None

        # User has multiple teams - let them select
        jwt_with_team = _handle_team_selection_during_polling(
            base_url=base_url,
            key_id=key_id,
            teams=normalized_teams,
        )
        if not jwt_with_team:
            click.echo("❌ Team selection cancelled or JWT generation failed.")
            return None

        # Use the team-specific JWT issued after selection.
        return {
            "api_key": jwt_with_team,
            "user_id": user_id,
            "teams": teams,
            "team_id": None,  # Set by server in JWT
        }

    # JWT is ready (single team or team already selected)
    api_key = data.get("key")
    teams = data.get("teams", [])
    team_id = data.get("team_id")

    # Show which team was assigned automatically.
    if team_id and len(teams) == 1:
        click.echo(f"\n✅ Automatically assigned to team: {team_id}")

    if not api_key:
        return None
    return {
        "api_key": api_key,
        "user_id": data.get("user_id"),
        "teams": teams,
        "team_id": team_id,
    }
|
||||
|
||||
|
||||
def _handle_team_selection_during_polling(
    base_url: str, key_id: str, teams: List[Dict[str, Any]]
) -> Optional[str]:
    """
    Handle team selection and re-poll with selected team_id.

    Args:
        base_url: Base URL of the LiteLLM proxy server.
        key_id: Session key identifier used by the ``/sso/cli/poll`` endpoint.
        teams: Normalized team dicts with ``team_id``/``team_alias`` keys.

    Returns:
        The JWT token with the selected team, or None if selection was skipped
    """
    if not teams:
        click.echo(
            "ℹ️ No teams found. You can create or join teams using the web interface."
        )
        return None

    click.echo("\n" + "=" * 60)
    click.echo("📋 Select a team for your CLI session...")

    team_id = _render_and_prompt_for_team_selection(teams)
    if not team_id:
        click.echo("ℹ️ No team selected.")
        return None

    click.echo(f"\n🔄 Generating JWT for team: {team_id}")

    # Re-poll with the chosen team so the server issues a team-scoped JWT.
    data = _poll_for_ready_data(
        f"{base_url}/sso/cli/poll/{key_id}?team_id={team_id}",
        pending_message="Still waiting for team authentication...",
        other_status_message="Waiting for team authentication to complete...",
        http_error_log_every=10,
    )
    if not data:
        return None

    jwt_token = data.get("key")
    if jwt_token:
        click.echo(f"✅ Successfully generated JWT for team: {team_id}")
        return jwt_token

    return None
|
||||
|
||||
|
||||
def _render_and_prompt_for_team_selection(teams: List[Dict[str, Any]]) -> Optional[str]:
    """Render teams table and prompt user for a team selection.

    Returns the selected team_id as a string, or None if selection was
    cancelled or skipped without any teams available.
    """
    # Show aliases where available while keeping the underlying IDs visible.
    console = Console()
    table = Table(title="Available Teams")
    table.add_column("Index", style="cyan", no_wrap=True)
    table.add_column("Team Name", style="magenta")
    table.add_column("Team ID", style="green")

    for position, team in enumerate(teams, start=1):
        tid = str(team.get("team_id"))
        table.add_row(str(position), team.get("team_alias") or tid, tid)

    console.print(table)

    # Simple numeric selection loop.
    while True:
        try:
            raw = click.prompt(
                "\nSelect a team by entering the index number (or 'skip' to use first team)",
                type=str,
            ).strip()
        except KeyboardInterrupt:
            click.echo("\n❌ Team selection cancelled.")
            return None

        if raw.lower() == "skip":
            # Default to the first team's ID when the user skips an
            # explicit selection.
            return str(teams[0].get("team_id")) if teams else None

        try:
            index = int(raw) - 1
        except ValueError:
            click.echo("❌ Invalid input. Please enter a number or 'skip'")
            continue

        if 0 <= index < len(teams):
            chosen = teams[index]
            tid = str(chosen.get("team_id"))
            alias = chosen.get("team_alias") or tid
            click.echo(f"\n✅ Selected team: {alias} ({tid})")
            return tid

        click.echo(
            f"❌ Invalid selection. Please enter a number between 1 and {len(teams)}"
        )
|
||||
|
||||
|
||||
@click.command(name="login")
@click.pass_context
def login(ctx: click.Context):
    """Login to LiteLLM proxy using SSO authentication"""
    # Local imports keep module import time low and avoid pulling in the full
    # litellm package for commands that never call login.
    from litellm._uuid import uuid
    from litellm.constants import LITELLM_CLI_SOURCE_IDENTIFIER
    from litellm.proxy.client.cli.interface import show_commands

    base_url = ctx.obj["base_url"]

    # Check if we have an existing key to regenerate
    existing_key = get_stored_api_key()

    # Generate a unique key ID for this login session; the server will
    # associate the SSO-generated credential with this ID so we can poll for it.
    key_id = f"sk-{str(uuid.uuid4())}"

    try:
        # Construct SSO login URL with CLI source and pre-generated key
        sso_url = f"{base_url}/sso/key/generate?source={LITELLM_CLI_SOURCE_IDENTIFIER}&key={key_id}"

        # If we have an existing key, include it as a parameter to the login endpoint.
        # The server will encode it in the OAuth state parameter for the SSO flow.
        # NOTE(review): the existing key is passed in the query string here —
        # confirm the server treats this as sensitive (no logging of full URLs).
        if existing_key:
            sso_url += f"&existing_key={existing_key}"

        click.echo(f"Opening browser to: {sso_url}")
        click.echo("Please complete the SSO authentication in your browser...")
        click.echo(f"Session ID: {key_id}")

        # Hand the flow off to the user's default browser.
        webbrowser.open(sso_url)

        # Poll for authentication completion
        click.echo("Waiting for authentication...")

        auth_result = _poll_for_authentication(base_url=base_url, key_id=key_id)

        if auth_result:
            api_key = auth_result["api_key"]
            user_id = auth_result["user_id"]

            # Save token data (simplified for CLI - we just need the key).
            # Email/role are placeholders; only `key` is required by later calls.
            save_token(
                {
                    "key": api_key,
                    "user_id": user_id or "cli-user",
                    "user_email": "unknown",
                    "user_role": "cli",
                    "auth_header_name": "Authorization",
                    "jwt_token": "",
                    "timestamp": time.time(),
                }
            )

            click.echo("\n✅ Login successful!")
            # Only a prefix of the key is echoed, to avoid leaking it to the terminal.
            click.echo(f"JWT Token: {api_key[:20]}...")
            click.echo("You can now use the CLI without specifying --api-key")

            # Show available commands after successful login
            click.echo("\n" + "=" * 60)
            show_commands()
            return
        else:
            # _poll_for_authentication returned falsy: treated as a timeout.
            click.echo("❌ Authentication timed out. Please try again.")
            return

    except KeyboardInterrupt:
        click.echo("\n❌ Authentication cancelled by user.")
        return
    except Exception as e:
        # Broad catch is intentional at this CLI boundary: report and exit cleanly.
        click.echo(f"❌ Authentication failed: {e}")
        return
|
||||
|
||||
|
||||
@click.command(name="logout")
def logout():
    """Logout and clear stored authentication"""
    # Drop the persisted credential first, then confirm to the user.
    clear_token()
    confirmation = "✅ Logged out successfully. Authentication token cleared."
    click.echo(confirmation)
|
||||
|
||||
|
||||
@click.command(name="whoami")
def whoami():
    """Show current authentication status"""
    stored = load_token()

    if not stored:
        click.echo("❌ Not authenticated. Run 'litellm-proxy login' to authenticate.")
        return

    click.echo("✅ Authenticated")
    # Print the identity fields stored at login time.
    for label, field in (
        ("User Email", "user_email"),
        ("User ID", "user_id"),
        ("User Role", "user_role"),
    ):
        click.echo(f"{label}: {stored.get(field, 'Unknown')}")

    # Rough freshness check based on the timestamp recorded at login.
    hours_old = (time.time() - stored.get("timestamp", 0)) / 3600
    click.echo(f"Token age: {hours_old:.1f} hours")

    if hours_old > CLI_JWT_EXPIRATION_HOURS:
        click.echo(
            f"⚠️ Warning: Token is more than {CLI_JWT_EXPIRATION_HOURS} hours old and may have expired."
        )
|
||||
|
||||
|
||||
# Export functions for use by other CLI commands.
# NOTE(review): `prompt_team_selection` is exported here but not defined in
# this view — presumably defined earlier in this module; confirm it exists.
__all__ = ["login", "logout", "whoami", "prompt_team_selection"]

# Export individual commands instead of grouping them:
# login, logout, and whoami will be added as top-level commands.
||||
@@ -0,0 +1,406 @@
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import click
|
||||
import requests
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt
|
||||
from rich.table import Table
|
||||
|
||||
from ... import Client
|
||||
from ...chat import ChatClient
|
||||
|
||||
|
||||
def _get_available_models(ctx: click.Context) -> List[Dict[str, Any]]:
    """Fetch the proxy's model list, keeping only dictionary entries.

    Any failure (network, auth, malformed payload) is reported as a warning on
    stderr and an empty list is returned so callers can fall back gracefully.
    """
    try:
        proxy_client = Client(base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"])
        raw_models = proxy_client.models.list()
    except Exception as e:
        click.echo(f"Warning: Could not fetch models list: {e}", err=True)
        return []

    if not isinstance(raw_models, list):
        return []
    # Filter out anything that is not a dict so downstream code can use .get().
    return [entry for entry in raw_models if isinstance(entry, dict)]
|
||||
|
||||
|
||||
def _select_model(
    console: Console, available_models: List[Dict[str, Any]]
) -> Optional[str]:
    """Interactive model selection.

    Shows up to 200 models in a table and accepts either an index into the
    full list or a free-form model name. Returns the chosen model id/name, or
    None when the user cancels or supplies nothing.
    """
    if not available_models:
        # Nothing to choose from — fall back to a free-form prompt.
        console.print(
            "[yellow]No models available or could not fetch models list.[/yellow]"
        )
        model_name = Prompt.ask("Please enter a model name")
        return model_name if model_name.strip() else None

    # Display available models in a table
    table = Table(title="Available Models")
    table.add_column("Index", style="cyan", no_wrap=True)
    table.add_column("Model ID", style="green")
    table.add_column("Owned By", style="yellow")
    MAX_MODELS_TO_DISPLAY = 200

    models_to_display: List[Dict[str, Any]] = available_models[:MAX_MODELS_TO_DISPLAY]
    for i, model in enumerate(models_to_display):  # Limit to first 200 models
        table.add_row(
            str(i + 1), str(model.get("id", "")), str(model.get("owned_by", ""))
        )

    if len(available_models) > MAX_MODELS_TO_DISPLAY:
        # NOTE: this notice is printed before the table, so it appears above it.
        console.print(
            f"\n[dim]... and {len(available_models) - MAX_MODELS_TO_DISPLAY} more models[/dim]"
        )

    console.print(table)

    while True:
        try:
            choice = Prompt.ask(
                "\nSelect a model by entering the index number (or type a model name directly)",
                default="1",
            ).strip()

            # Try to parse as index. Indices beyond the 200 displayed rows are
            # still accepted — the bound is the full list, not the table.
            try:
                index = int(choice) - 1
                if 0 <= index < len(available_models):
                    return available_models[index]["id"]
                else:
                    console.print(
                        f"[red]Invalid index. Please enter a number between 1 and {len(available_models)}[/red]"
                    )
                    continue
            except ValueError:
                # Not a number, treat as model name
                if choice:
                    return choice
                else:
                    console.print("[red]Please enter a valid model name or index[/red]")
                    continue

        except KeyboardInterrupt:
            console.print("\n[yellow]Model selection cancelled.[/yellow]")
            return None
|
||||
|
||||
|
||||
@click.command()
@click.argument("model", required=False)
@click.option(
    "--temperature",
    "-t",
    type=float,
    default=0.7,
    help="Sampling temperature between 0 and 2 (default: 0.7)",
)
@click.option(
    "--max-tokens",
    type=int,
    help="Maximum number of tokens to generate",
)
@click.option(
    "--system",
    "-s",
    type=str,
    help="System message to set the behavior of the assistant",
)
@click.pass_context
def chat(
    ctx: click.Context,
    model: Optional[str],
    temperature: float,
    max_tokens: Optional[int] = None,
    system: Optional[str] = None,
):
    """Interactive chat with streaming responses

    Examples:

        # Chat with a specific model
        litellm-proxy chat gpt-4

        # Chat without specifying model (will show model selection)
        litellm-proxy chat

        # Chat with custom settings
        litellm-proxy chat gpt-4 --temperature 0.9 --system "You are a helpful coding assistant"
    """
    console = Console()

    # If no model specified, show model selection
    if not model:
        available_models = _get_available_models(ctx)
        model = _select_model(console, available_models)
        if not model:
            console.print("[red]No model selected. Exiting.[/red]")
            return

    client = ChatClient(ctx.obj["base_url"], ctx.obj["api_key"])

    # Initialize conversation history
    messages: List[Dict[str, Any]] = []

    # Add system message if provided
    if system:
        messages.append({"role": "system", "content": system})

    # Display welcome banner summarising the session settings.
    console.print(
        Panel.fit(
            f"[bold blue]LiteLLM Interactive Chat[/bold blue]\n"
            f"Model: [green]{model}[/green]\n"
            f"Temperature: [yellow]{temperature}[/yellow]\n"
            f"Max Tokens: [yellow]{max_tokens or 'unlimited'}[/yellow]\n\n"
            f"Type your messages and press Enter. Type '/quit' or '/exit' to end the session.\n"
            f"Type '/help' for more commands.",
            title="🤖 Chat Session",
        )
    )

    try:
        while True:
            # Get user input
            try:
                user_input = console.input("\n[bold cyan]You:[/bold cyan] ").strip()
            except (EOFError, KeyboardInterrupt):
                console.print("\n[yellow]Chat session ended.[/yellow]")
                break

            # Handle special commands; the helper may replace the history
            # (for /clear and /load) or select a new model (for /model).
            should_exit, messages, new_model = _handle_special_commands(
                console, user_input, messages, system, ctx
            )

            if should_exit:
                break
            if new_model:
                model = new_model

            # Skip the completion round-trip when the input was a special
            # command (already handled above) or empty.
            # NOTE(review): this is a prefix check, so any normal message that
            # merely *starts* with one of these tokens (e.g. "/quota info")
            # is silently dropped — confirm this is acceptable.
            if (
                user_input.lower().startswith(
                    (
                        "/quit",
                        "/exit",
                        "/q",
                        "/help",
                        "/clear",
                        "/history",
                        "/save",
                        "/load",
                        "/model",
                    )
                )
                or not user_input
            ):
                continue

            # Add user message to conversation
            messages.append({"role": "user", "content": user_input})

            # Display assistant label
            console.print("\n[bold green]Assistant:[/bold green]")

            # Stream the response token-by-token to the terminal.
            assistant_content = _stream_response(
                console=console,
                client=client,
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            # Add assistant message to conversation history so the next turn
            # has full context; on failure the user message stays in history.
            if assistant_content:
                messages.append({"role": "assistant", "content": assistant_content})
            else:
                console.print("[red]Error: No content received from the model[/red]")

    except KeyboardInterrupt:
        console.print("\n[yellow]Chat session interrupted.[/yellow]")
|
||||
|
||||
|
||||
def _show_help(console: Console):
    """Show help for interactive chat commands.

    The text uses rich console markup ([bold], [cyan], ...) and is rendered
    inside a Panel.
    """
    help_text = """
[bold]Interactive Chat Commands:[/bold]

[cyan]/help[/cyan] - Show this help message
[cyan]/quit[/cyan] - Exit the chat session (also /exit, /q)
[cyan]/clear[/cyan] - Clear conversation history
[cyan]/history[/cyan] - Show conversation history
[cyan]/model[/cyan] - Switch to a different model
[cyan]/save <name>[/cyan] - Save conversation to file
[cyan]/load <name>[/cyan] - Load conversation from file

[bold]Tips:[/bold]
- Your conversation history is maintained during the session
- Use Ctrl+C to interrupt at any time
- Responses are streamed in real-time
- You can switch models mid-conversation with /model
"""
    console.print(Panel(help_text, title="Help"))
|
||||
|
||||
|
||||
def _show_history(console: Console, messages: List[Dict[str, Any]]):
    """Render the running conversation, one numbered line per message."""
    if not messages:
        console.print("[yellow]No conversation history.[/yellow]")
        return

    console.print(Panel.fit("[bold]Conversation History[/bold]", title="History"))

    for position, entry in enumerate(messages, 1):
        kind = entry["role"]
        text = entry["content"]

        if kind == "system":
            console.print(
                f"[dim]{position}. [bold magenta]System:[/bold magenta] {text}[/dim]"
            )
        elif kind == "user":
            console.print(f"{position}. [bold cyan]You:[/bold cyan] {text}")
        elif kind == "assistant":
            # Assistant replies can be long; show only the first 100 characters.
            console.print(
                f"{position}. [bold green]Assistant:[/bold green] {text[:100]}{'...' if len(text) > 100 else ''}"
            )
|
||||
|
||||
|
||||
def _save_conversation(console: Console, messages: List[Dict[str, Any]], command: str):
    """Save the conversation history to a JSON file.

    ``command`` is the raw user input (e.g. ``"/save myconvo"``); the second
    token is used as the filename, with a ``.json`` suffix appended when
    missing. Errors are reported on the console rather than raised.
    """
    parts = command.split()
    if len(parts) < 2:
        console.print("[red]Usage: /save <filename>[/red]")
        return

    filename = parts[1]
    if not filename.endswith(".json"):
        filename += ".json"

    try:
        with open(filename, "w") as f:
            json.dump(messages, f, indent=2)
        # Bug fix: the f-string previously had no placeholder, so the actual
        # filename was never shown. Include the resolved path in the message.
        console.print(f"[green]Conversation saved to {filename}[/green]")
    except Exception as e:
        console.print(f"[red]Error saving conversation: {e}[/red]")
|
||||
|
||||
|
||||
def _load_conversation(
    console: Console, command: str, system: Optional[str]
) -> List[Dict[str, Any]]:
    """Load a conversation history from a JSON file.

    ``command`` is the raw user input (e.g. ``"/load myconvo"``); a ``.json``
    suffix is appended to the filename when missing. On success the parsed
    message list is returned. On any failure a fresh history is returned
    instead: just the system message (when one was provided) or an empty list.
    """
    parts = command.split()
    if len(parts) < 2:
        console.print("[red]Usage: /load <filename>[/red]")
        return []

    filename = parts[1]
    if not filename.endswith(".json"):
        filename += ".json"

    try:
        with open(filename, "r") as f:
            messages = json.load(f)
        # Bug fix: these f-strings previously had no placeholder, so neither
        # the loaded path nor the missing path was ever shown to the user.
        console.print(f"[green]Conversation loaded from {filename}[/green]")
        return messages
    except FileNotFoundError:
        console.print(f"[red]File not found: {filename}[/red]")
    except Exception as e:
        console.print(f"[red]Error loading conversation: {e}[/red]")

    # Load failed: start over with just the system prompt, if any.
    if system:
        return [{"role": "system", "content": system}]
    return []
|
||||
|
||||
|
||||
def _handle_special_commands(
    console: Console,
    user_input: str,
    messages: List[Dict[str, Any]],
    system: Optional[str],
    ctx: click.Context,
) -> tuple[bool, List[Dict[str, Any]], Optional[str]]:
    """Handle special chat commands.

    Returns (should_exit, updated_messages, updated_model). For non-command
    input the arguments are returned unchanged; the caller decides whether to
    forward the input to the model.
    """
    if user_input.lower() in ["/quit", "/exit", "/q"]:
        console.print("[yellow]Chat session ended.[/yellow]")
        return True, messages, None
    elif user_input.lower() == "/help":
        _show_help(console)
        return False, messages, None
    elif user_input.lower() == "/clear":
        # Reset history, re-seeding the system prompt when one was given.
        new_messages = []
        if system:
            new_messages.append({"role": "system", "content": system})
        console.print("[green]Conversation history cleared.[/green]")
        return False, new_messages, None
    elif user_input.lower() == "/history":
        _show_history(console, messages)
        return False, messages, None
    elif user_input.lower().startswith("/save"):
        # /save takes the filename from the raw input; history is unchanged.
        _save_conversation(console, messages, user_input)
        return False, messages, None
    elif user_input.lower().startswith("/load"):
        # /load replaces the history wholesale with the file contents.
        new_messages = _load_conversation(console, user_input, system)
        return False, new_messages, None
    elif user_input.lower() == "/model":
        # Re-run model selection; only report a switch when one was chosen.
        available_models = _get_available_models(ctx)
        new_model = _select_model(console, available_models)
        if new_model:
            console.print(f"[green]Switched to model: {new_model}[/green]")
            return False, messages, new_model
        return False, messages, None
    elif not user_input:
        return False, messages, None

    # Not a special command
    return False, messages, None
|
||||
|
||||
|
||||
def _stream_response(
    console: Console,
    client: ChatClient,
    model: str,
    messages: List[Dict[str, Any]],
    temperature: float,
    max_tokens: Optional[int],
) -> Optional[str]:
    """Stream the model response to the terminal and return the full text.

    Returns None when no content was produced or an error occurred (errors
    are printed to the console rather than raised).
    """
    try:
        assistant_content = ""
        # Chunks are assumed to be OpenAI-style streaming deltas:
        # {"choices": [{"delta": {"content": ...}}]}.
        for chunk in client.completions_stream(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        ):
            if "choices" in chunk and len(chunk["choices"]) > 0:
                delta = chunk["choices"][0].get("delta", {})
                content = delta.get("content", "")
                if content:
                    assistant_content += content
                    # Print incrementally with no newline, flushing so the
                    # text appears as it streams.
                    console.print(content, end="")
                    sys.stdout.flush()

        console.print()  # Add newline after streaming
        return assistant_content if assistant_content else None

    except requests.exceptions.HTTPError as e:
        console.print(f"\n[red]Error: HTTP {e.response.status_code}[/red]")
        try:
            # Prefer the server's structured error message when available.
            error_body = e.response.json()
            console.print(
                f"[red]{error_body.get('error', {}).get('message', 'Unknown error')}[/red]"
            )
        except json.JSONDecodeError:
            console.print(f"[red]{e.response.text}[/red]")
        return None
    except Exception as e:
        # Broad catch is intentional: a streaming failure should not kill the
        # chat session; the caller sees None and keeps going.
        console.print(f"\n[red]Error: {str(e)}[/red]")
        return None
|
||||
@@ -0,0 +1,116 @@
|
||||
import json
|
||||
from typing import Literal
|
||||
|
||||
import click
|
||||
import rich
|
||||
import requests
|
||||
from rich.table import Table
|
||||
|
||||
from ...credentials import CredentialsManagementClient
|
||||
|
||||
|
||||
@click.group()
def credentials():
    """Manage credentials for the LiteLLM proxy server"""
    # Container group only; the actual work happens in the subcommands below.
|
||||
|
||||
|
||||
@credentials.command()
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.pass_context
def list(ctx: click.Context, output_format: Literal["table", "json"]):
    """List all credentials"""
    client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    response = client.list()
    assert isinstance(response, dict)

    if output_format == "json":
        rich.print_json(data=response)
        return

    # Table output: one row per credential, showing its provider.
    table = Table(title="Credentials")
    table.add_column("Credential Name", style="cyan")
    table.add_column("Custom LLM Provider", style="green")
    for credential in response.get("credentials", []):
        credential_info = credential.get("credential_info", {})
        table.add_row(
            str(credential.get("credential_name", "")),
            str(credential_info.get("custom_llm_provider", "")),
        )
    rich.print(table)
|
||||
|
||||
|
||||
@credentials.command()
@click.argument("credential_name")
@click.option(
    "--info",
    type=str,
    help="JSON string containing credential info",
    required=True,
)
@click.option(
    "--values",
    type=str,
    help="JSON string containing credential values",
    required=True,
)
@click.pass_context
def create(ctx: click.Context, credential_name: str, info: str, values: str):
    """Create a new credential"""
    client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # Validate both JSON inputs up front so a bad argument fails fast with a
    # usage error instead of a server round-trip.
    try:
        credential_info = json.loads(info)
        credential_values = json.loads(values)
    except json.JSONDecodeError as e:
        raise click.BadParameter(f"Invalid JSON: {str(e)}")

    try:
        response = client.create(credential_name, credential_info, credential_values)
        rich.print_json(data=response)
    except requests.exceptions.HTTPError as e:
        # Report the status code, then the response body (pretty JSON when
        # possible, raw text otherwise), and abort with a non-zero exit code.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except json.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
|
||||
|
||||
|
||||
@credentials.command()
@click.argument("credential_name")
@click.pass_context
def delete(ctx: click.Context, credential_name: str):
    """Delete a credential by name"""
    client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    try:
        result = client.delete(credential_name)
    except requests.exceptions.HTTPError as http_err:
        # Report the status code, then the body (pretty JSON when possible),
        # and abort with a non-zero exit code.
        click.echo(f"Error: HTTP {http_err.response.status_code}", err=True)
        try:
            rich.print_json(data=http_err.response.json())
        except json.JSONDecodeError:
            click.echo(http_err.response.text, err=True)
        raise click.Abort()
    rich.print_json(data=result)
|
||||
|
||||
|
||||
@credentials.command()
@click.argument("credential_name")
@click.pass_context
def get(ctx: click.Context, credential_name: str):
    """Get a credential by name"""
    # No error wrapping here: HTTP failures propagate to the CLI runner.
    client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    rich.print_json(data=client.get(credential_name))
|
||||
@@ -0,0 +1,102 @@
|
||||
import json as json_lib
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import rich
|
||||
import requests
|
||||
|
||||
from ...http_client import HTTPClient
|
||||
|
||||
|
||||
@click.group()
def http():
    """Make HTTP requests to the LiteLLM proxy server"""
    # Pure command group; see the `request` subcommand for the implementation.
|
||||
|
||||
|
||||
@http.command()
@click.argument("method")
@click.argument("uri")
@click.option(
    "--data",
    "-d",
    type=str,
    help="Data to send in the request body (as JSON string)",
)
@click.option(
    "--json",
    "-j",
    type=str,
    help="JSON data to send in the request body (as JSON string)",
)
@click.option(
    "--header",
    "-H",
    multiple=True,
    help="HTTP headers in 'key:value' format. Can be specified multiple times.",
)
@click.pass_context
def request(
    ctx: click.Context,
    method: str,
    uri: str,
    data: Optional[str] = None,
    json: Optional[str] = None,
    header: tuple[str, ...] = (),
):
    """Make an HTTP request to the LiteLLM proxy server

    METHOD: HTTP method (GET, POST, PUT, DELETE, etc.)
    URI: URI path (will be appended to base_url)

    Examples:
        litellm http request GET /models
        litellm http request POST /chat/completions -j '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}'
        litellm http request GET /health/test_connection -H "X-Custom-Header:value"
    """
    # NOTE: the `json` parameter shadows the json module name here, which is
    # why this module imports the stdlib as `json_lib`.

    # Parse headers from key:value format; only the first ':' splits, so
    # values may themselves contain colons.
    headers = {}
    for h in header:
        try:
            key, value = h.split(":", 1)
            headers[key.strip()] = value.strip()
        except ValueError:
            raise click.BadParameter(
                f"Invalid header format: {h}. Expected format: 'key:value'"
            )

    # --json must be valid JSON; reject anything else up front.
    json_data = None
    if json:
        try:
            json_data = json_lib.loads(json)
        except ValueError as e:
            raise click.BadParameter(f"Invalid JSON format: {e}")

    # --data is best-effort: parsed as JSON when possible, otherwise passed
    # through as a raw string body.
    request_data = None
    if data:
        try:
            request_data = json_lib.loads(data)
        except ValueError:
            # If not JSON, use as raw data
            request_data = data

    client = HTTPClient(ctx.obj["base_url"], ctx.obj["api_key"])
    try:
        response = client.request(
            method=method,
            uri=uri,
            data=request_data,
            json=json_data,
            headers=headers,
        )
        rich.print_json(data=response)
    except requests.exceptions.HTTPError as e:
        # Report the status code, then the body (pretty JSON when possible),
        # and abort with a non-zero exit code.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except json_lib.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
|
||||
@@ -0,0 +1,415 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional, List, Dict, Any
|
||||
|
||||
import click
|
||||
import rich
|
||||
import requests
|
||||
from rich.table import Table
|
||||
|
||||
from ...keys import KeysManagementClient
|
||||
|
||||
|
||||
@click.group()
def keys():
    """Manage API keys for the LiteLLM proxy server"""
    # Grouping command only; the actual work happens in the subcommands below.
|
||||
|
||||
|
||||
@keys.command()
@click.option("--page", type=int, help="Page number for pagination")
@click.option("--size", type=int, help="Number of items per page")
@click.option("--user-id", type=str, help="Filter keys by user ID")
@click.option("--team-id", type=str, help="Filter keys by team ID")
@click.option("--organization-id", type=str, help="Filter keys by organization ID")
@click.option("--key-hash", type=str, help="Filter by specific key hash")
@click.option("--key-alias", type=str, help="Filter by key alias")
@click.option(
    "--return-full-object",
    is_flag=True,
    default=True,
    help="Return the full key object",
)
@click.option(
    "--include-team-keys", is_flag=True, help="Include team keys in the response"
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.pass_context
def list(
    ctx: click.Context,
    page: Optional[int],
    size: Optional[int],
    user_id: Optional[str],
    team_id: Optional[str],
    organization_id: Optional[str],
    key_hash: Optional[str],
    key_alias: Optional[str],
    include_team_keys: bool,
    output_format: Literal["table", "json"],
    return_full_object: bool,
):
    """List all API keys"""
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # All filters are forwarded verbatim; None values mean "no filter".
    response = client.list(
        page=page,
        size=size,
        user_id=user_id,
        team_id=team_id,
        organization_id=organization_id,
        key_hash=key_hash,
        key_alias=key_alias,
        return_full_object=return_full_object,
        include_team_keys=include_team_keys,
    )
    # client.list() returns a dict when not asked for the raw request object.
    assert isinstance(response, dict)

    if output_format == "json":
        rich.print_json(data=response)
    else:
        # Summary line first, then one table row per key.
        rich.print(
            f"Showing {len(response.get('keys', []))} keys out of {response.get('total_count', 0)}"
        )
        table = Table(title="API Keys")
        table.add_column("Key Hash", style="cyan")
        table.add_column("Alias", style="green")
        table.add_column("User ID", style="magenta")
        table.add_column("Team ID", style="yellow")
        table.add_column("Spend", style="red")
        for key in response.get("keys", []):
            table.add_row(
                str(key.get("token", "")),
                str(key.get("key_alias", "")),
                str(key.get("user_id", "")),
                str(key.get("team_id", "")),
                str(key.get("spend", "")),
            )
        rich.print(table)
|
||||
|
||||
|
||||
@keys.command()
@click.option("--models", type=str, help="Comma-separated list of allowed models")
@click.option("--aliases", type=str, help="JSON string of model alias mappings")
@click.option("--spend", type=float, help="Maximum spend limit for this key")
@click.option(
    "--duration",
    type=str,
    help="Duration for which the key is valid (e.g. '24h', '7d')",
)
@click.option("--key-alias", type=str, help="Alias/name for the key")
@click.option("--team-id", type=str, help="Team ID to associate the key with")
@click.option("--user-id", type=str, help="User ID to associate the key with")
@click.option("--budget-id", type=str, help="Budget ID to associate the key with")
@click.option(
    "--config", type=str, help="JSON string of additional configuration parameters"
)
@click.pass_context
def generate(
    ctx: click.Context,
    models: Optional[str],
    aliases: Optional[str],
    spend: Optional[float],
    duration: Optional[str],
    key_alias: Optional[str],
    team_id: Optional[str],
    user_id: Optional[str],
    budget_id: Optional[str],
    config: Optional[str],
):
    """Generate a new API key"""
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # Convert CLI string inputs to their structured forms before the request:
    # --models is comma-separated; --aliases and --config must be valid JSON.
    try:
        models_list = [m.strip() for m in models.split(",")] if models else None
        aliases_dict = json.loads(aliases) if aliases else None
        config_dict = json.loads(config) if config else None
    except json.JSONDecodeError as e:
        raise click.BadParameter(f"Invalid JSON: {str(e)}")
    try:
        response = client.generate(
            models=models_list,
            aliases=aliases_dict,
            spend=spend,
            duration=duration,
            key_alias=key_alias,
            team_id=team_id,
            user_id=user_id,
            budget_id=budget_id,
            config=config_dict,
        )
        rich.print_json(data=response)
    except requests.exceptions.HTTPError as e:
        # Report the status code, then the body (pretty JSON when possible),
        # and abort with a non-zero exit code.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except json.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
|
||||
|
||||
|
||||
@keys.command()
@click.option("--keys", type=str, help="Comma-separated list of API keys to delete")
@click.option(
    "--key-aliases", type=str, help="Comma-separated list of key aliases to delete"
)
@click.pass_context
def delete(ctx: click.Context, keys: Optional[str], key_aliases: Optional[str]):
    """Delete API keys by key or alias"""
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # Split the comma-separated CLI values into trimmed lists; None when absent.
    keys_list = None if not keys else [token.strip() for token in keys.split(",")]
    aliases_list = (
        None
        if not key_aliases
        else [alias.strip() for alias in key_aliases.split(",")]
    )
    try:
        deletion_result = client.delete(keys=keys_list, key_aliases=aliases_list)
    except requests.exceptions.HTTPError as http_err:
        # Report the status code, then the body (pretty JSON when possible),
        # and abort with a non-zero exit code.
        click.echo(f"Error: HTTP {http_err.response.status_code}", err=True)
        try:
            rich.print_json(data=http_err.response.json())
        except json.JSONDecodeError:
            click.echo(http_err.response.text, err=True)
        raise click.Abort()
    rich.print_json(data=deletion_result)
|
||||
|
||||
|
||||
def _parse_created_since_filter(created_since: Optional[str]) -> Optional[datetime]:
    """Parse the created-since filter into a datetime, or None when unset.

    Accepts ``YYYY-MM-DD_HH:MM`` or ``YYYY-MM-DD``; aborts the CLI run with an
    error message on any other format.
    """
    if not created_since:
        return None

    # The presence of '_' distinguishes the date-time form from the date-only form.
    expected_format = "%Y-%m-%d_%H:%M" if "_" in created_since else "%Y-%m-%d"
    try:
        return datetime.strptime(created_since, expected_format)
    except ValueError:
        click.echo(
            f"Error: Invalid date format '{created_since}'. Use YYYY-MM-DD_HH:MM or YYYY-MM-DD",
            err=True,
        )
        raise click.Abort()
|
||||
|
||||
|
||||
def _fetch_all_keys_with_pagination(
    source_client: KeysManagementClient, source_base_url: str
) -> List[Dict[str, Any]]:
    """Fetch all keys from source instance using pagination."""
    click.echo(f"Fetching keys from source server: {source_base_url}")
    page_size = 100  # Use a larger page size to minimize API calls
    collected: List[Dict[str, Any]] = []
    page = 1

    while True:
        response = source_client.list(
            return_full_object=True, page=page, size=page_size
        )
        # source_client.list() returns Dict[str, Any] when return_request is False (default)
        assert isinstance(response, dict), "Expected dict response from list API"
        batch = response.get("keys", [])
        if not batch:
            break

        collected.extend(batch)
        click.echo(f"Fetched page {page}: {len(batch)} keys")

        # A short page means the server has nothing left to return.
        if len(batch) < page_size:
            break
        page += 1

    return collected
|
||||
|
||||
|
||||
def _filter_keys_by_created_since(
|
||||
source_keys: List[Dict[str, Any]],
|
||||
created_since_dt: Optional[datetime],
|
||||
created_since: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Filter keys by created_since date if specified."""
|
||||
if not created_since_dt:
|
||||
return source_keys
|
||||
|
||||
filtered_keys = []
|
||||
for key in source_keys:
|
||||
key_created_at = key.get("created_at")
|
||||
if key_created_at:
|
||||
# Parse the key's created_at timestamp
|
||||
if isinstance(key_created_at, str):
|
||||
if "T" in key_created_at:
|
||||
key_dt = datetime.fromisoformat(
|
||||
key_created_at.replace("Z", "+00:00")
|
||||
)
|
||||
else:
|
||||
key_dt = datetime.fromisoformat(key_created_at)
|
||||
|
||||
# Convert to naive datetime for comparison (assuming UTC)
|
||||
if key_dt.tzinfo:
|
||||
key_dt = key_dt.replace(tzinfo=None)
|
||||
|
||||
if key_dt >= created_since_dt:
|
||||
filtered_keys.append(key)
|
||||
|
||||
click.echo(
|
||||
f"Filtered {len(source_keys)} keys to {len(filtered_keys)} keys created since {created_since}"
|
||||
)
|
||||
return filtered_keys
|
||||
|
||||
|
||||
def _display_dry_run_table(source_keys: List[Dict[str, Any]]) -> None:
    """Display a table of keys that would be imported in dry-run mode.

    Args:
        source_keys: Key dicts from the source instance; only ``key_alias``,
            ``user_id`` and ``created_at`` are rendered.
    """
    click.echo("\n--- DRY RUN MODE ---")
    table = Table(title="Keys that would be imported")
    table.add_column("Key Alias", style="green")
    table.add_column("User ID", style="magenta")
    table.add_column("Created", style="cyan")

    for key in source_keys:
        created_at = key.get("created_at", "")
        # Reformat ISO timestamps to minute resolution for readability.
        if created_at and isinstance(created_at, str) and "T" in created_at:
            try:
                dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
                created_at = dt.strftime("%Y-%m-%d %H:%M")
            except ValueError:
                # Bug fix: a malformed timestamp previously raised and killed
                # this display-only helper; show the raw value instead.
                pass

        table.add_row(
            str(key.get("key_alias", "")), str(key.get("user_id", "")), str(created_at)
        )
    rich.print(table)
|
||||
|
||||
|
||||
def _prepare_key_import_data(key: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Prepare key data for import by extracting relevant fields."""
|
||||
import_data = {}
|
||||
|
||||
# Copy relevant fields if they exist
|
||||
for field in [
|
||||
"models",
|
||||
"aliases",
|
||||
"spend",
|
||||
"key_alias",
|
||||
"team_id",
|
||||
"user_id",
|
||||
"budget_id",
|
||||
"config",
|
||||
]:
|
||||
if key.get(field):
|
||||
import_data[field] = key[field]
|
||||
|
||||
return import_data
|
||||
|
||||
|
||||
def _import_keys_to_destination(
    source_keys: List[Dict[str, Any]], dest_client: KeysManagementClient
) -> tuple[int, int]:
    """Import each key to the destination instance and return counts."""
    imported_count = 0
    failed_count = 0

    for key in source_keys:
        key_alias = key.get("key_alias", "N/A")
        try:
            # Re-create the key on the destination with the copied settings.
            response = dest_client.generate(**_prepare_key_import_data(key))
            click.echo(f"Generated key: {response}")
            # The generate method returns JSON data directly, not a Response object
            imported_count += 1
            click.echo(f"✓ Imported key: {key_alias}")
        except Exception as e:
            failed_count += 1
            click.echo(f"✗ Failed to import key {key_alias}: {str(e)}", err=True)

    return imported_count, failed_count
|
||||
|
||||
|
||||
@keys.command(name="import")
@click.option(
    "--source-base-url",
    required=True,
    help="Base URL of the source LiteLLM proxy server to import keys from",
)
@click.option(
    "--source-api-key", help="API key for authentication to the source server"
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Show what would be imported without actually importing",
)
@click.option(
    "--created-since",
    help="Only import keys created after this date/time (format: YYYY-MM-DD_HH:MM or YYYY-MM-DD)",
)
@click.pass_context
def import_keys(
    ctx: click.Context,
    source_base_url: str,
    source_api_key: Optional[str],
    dry_run: bool,
    created_since: Optional[str],
):
    """Import API keys from another LiteLLM instance"""
    # Parse created_since filter if provided (aborts on an invalid format)
    created_since_dt = _parse_created_since_filter(created_since)

    # Create clients for both source and destination
    # (destination credentials come from this CLI's own context)
    source_client = KeysManagementClient(source_base_url, source_api_key)
    dest_client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])

    try:
        # Get all keys from source instance with pagination
        source_keys = _fetch_all_keys_with_pagination(source_client, source_base_url)

        # Filter keys by created_since if specified
        if created_since:
            source_keys = _filter_keys_by_created_since(
                source_keys, created_since_dt, created_since
            )

        if not source_keys:
            click.echo("No keys found in source instance.")
            return

        click.echo(f"Found {len(source_keys)} keys in source instance.")

        if dry_run:
            # Preview only: render the table and stop before any writes.
            _display_dry_run_table(source_keys)
            return

        # Import each key
        imported_count, failed_count = _import_keys_to_destination(
            source_keys, dest_client
        )

        # Summary
        click.echo("\nImport completed:")
        click.echo(f"  Successfully imported: {imported_count}")
        click.echo(f"  Failed to import: {failed_count}")
        click.echo(f"  Total keys processed: {len(source_keys)}")

    except requests.exceptions.HTTPError as e:
        # HTTP failure: show the status, then the body (JSON when possible).
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except json.JSONDecodeError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()
|
||||
@@ -0,0 +1,485 @@
|
||||
# stdlib imports
|
||||
from datetime import datetime
|
||||
import re
|
||||
from typing import Optional, Literal, Any
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
|
||||
# third party imports
|
||||
import click
|
||||
import rich
|
||||
|
||||
# local imports
|
||||
from ... import Client
|
||||
|
||||
|
||||
@dataclass
class ModelYamlInfo:
    """One model entry parsed from a proxy config YAML file."""

    model_name: str  # public-facing name (model_list[].model_name)
    model_params: dict[str, Any]  # litellm_params forwarded to the proxy
    model_info: dict[str, Any]  # optional model_info metadata block
    model_id: str  # upstream identifier, e.g. "openai/gpt-4"
    access_groups: list[str]  # access groups from model_info, may be empty
    provider: str  # prefix of model_id before the first "/"

    @property
    def access_groups_str(self) -> str:
        """Access groups joined for display; empty string when there are none."""
        if not self.access_groups:
            return ""
        return ", ".join(self.access_groups)
|
||||
|
||||
|
||||
def _get_model_info_obj_from_yaml(model: dict[str, Any]) -> ModelYamlInfo:
    """Extract model info from a model dict and return as ModelYamlInfo dataclass."""
    params: dict[str, Any] = model["litellm_params"]
    info: dict[str, Any] = model.get("model_info", {})
    upstream: str = params["model"]
    # Provider is the prefix before "/" ("openai/gpt-4" -> "openai");
    # an id with no slash is treated as its own provider.
    return ModelYamlInfo(
        model_name=model["model_name"],
        model_params=params,
        model_info=info,
        model_id=upstream,
        access_groups=info.get("access_groups", []),
        provider=upstream.split("/", 1)[0] if "/" in upstream else upstream,
    )
|
||||
|
||||
|
||||
def format_iso_datetime_str(iso_datetime_str: Optional[str]) -> str:
    """Format an ISO format datetime string to human-readable date with minute resolution."""
    if not iso_datetime_str:
        return ""
    try:
        # "Z"-suffixed UTC stamps are normalized so fromisoformat accepts them.
        parsed = datetime.fromisoformat(iso_datetime_str.replace("Z", "+00:00"))
    except (TypeError, ValueError):
        # Not ISO formatted — fall back to the raw value.
        return str(iso_datetime_str)
    return parsed.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
|
||||
def format_timestamp(timestamp: Optional[int]) -> str:
    """Format a Unix timestamp (integer) to human-readable date with minute resolution."""
    if timestamp is None:
        return ""
    try:
        # Uses the local timezone (datetime.fromtimestamp default).
        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M")
    except (TypeError, ValueError):
        # Non-numeric or out-of-range input — fall back to the raw value.
        return str(timestamp)
|
||||
|
||||
|
||||
def format_cost_per_1k_tokens(cost: Optional[float]) -> str:
    """Format a per-token cost to cost per 1000 tokens."""
    if cost is None:
        return ""
    try:
        # The API may deliver the cost as a string; coerce before scaling.
        per_thousand = float(cost) * 1000
    except (TypeError, ValueError):
        return str(cost)
    return f"${per_thousand:.4f}"
|
||||
|
||||
|
||||
def create_client(ctx: click.Context) -> Client:
    """Helper function to create a client from context."""
    # base_url/api_key are placed on ctx.obj by the CLI entry point.
    return Client(
        base_url=ctx.obj["base_url"],
        api_key=ctx.obj["api_key"],
    )
|
||||
|
||||
|
||||
@click.group()
def models() -> None:
    """Manage models on your LiteLLM proxy server"""
    # Container group only; subcommands are registered via @models.command below.
    pass
|
||||
|
||||
|
||||
@models.command("list")
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.pass_context
def list_models(ctx: click.Context, output_format: Literal["table", "json"]) -> None:
    """List all available models"""
    client = create_client(ctx)
    models_list = client.models.list()
    # models.list() is expected to return an array of model dicts.
    assert isinstance(models_list, list)

    if output_format == "json":
        rich.print_json(data=models_list)
    else:  # table format
        table = rich.table.Table(title="Available Models")

        # Add columns based on the data structure
        table.add_column("ID", style="cyan")
        table.add_column("Object", style="green")
        table.add_column("Created", style="magenta")
        table.add_column("Owned By", style="yellow")

        # Add rows
        for model in models_list:
            created = model.get("created")
            # Convert string timestamp to integer if needed
            if isinstance(created, str) and created.isdigit():
                created = int(created)

            table.add_row(
                str(model.get("id", "")),
                str(model.get("object", "model")),
                # Unix epoch ints and ISO strings are formatted differently.
                format_timestamp(created)
                if isinstance(created, int)
                else format_iso_datetime_str(created),
                str(model.get("owned_by", "")),
            )

        rich.print(table)
|
||||
|
||||
|
||||
@models.command("add")
@click.argument("model-name")
@click.option(
    "--param",
    "-p",
    multiple=True,
    help="Model parameters in key=value format (can be specified multiple times)",
)
@click.option(
    "--info",
    "-i",
    multiple=True,
    help="Model info in key=value format (can be specified multiple times)",
)
@click.pass_context
def add_model(
    ctx: click.Context, model_name: str, param: tuple[str, ...], info: tuple[str, ...]
) -> None:
    """Add a new model to the proxy"""
    # Validate the key=value entries up front: previously "-p foo" (no "=")
    # crashed with an opaque ValueError from dict unpacking.
    for entry in (*param, *info):
        if "=" not in entry:
            raise click.UsageError(
                f"Invalid key=value pair: '{entry}' (expected key=value)"
            )
    # Convert parameters from key=value format to dict
    model_params = dict(p.split("=", 1) for p in param)
    model_info = dict(i.split("=", 1) for i in info) if info else None

    client = create_client(ctx)
    result = client.models.new(
        model_name=model_name,
        model_params=model_params,
        model_info=model_info,
    )
    rich.print_json(data=result)
|
||||
|
||||
|
||||
@models.command("delete")
@click.argument("model-id")
@click.pass_context
def delete_model(ctx: click.Context, model_id: str) -> None:
    """Delete a model from the proxy"""
    # Issue the delete and echo the server's JSON response verbatim.
    rich.print_json(data=create_client(ctx).models.delete(model_id=model_id))
|
||||
|
||||
|
||||
@models.command("get")
@click.option("--id", "model_id", help="ID of the model to retrieve")
@click.option("--name", "model_name", help="Name of the model to retrieve")
@click.pass_context
def get_model(
    ctx: click.Context, model_id: Optional[str], model_name: Optional[str]
) -> None:
    """Get information about a specific model"""
    # At least one non-empty selector is required; the API accepts either.
    if not model_id and not model_name:
        raise click.UsageError("Either --id or --name must be provided")

    result = create_client(ctx).models.get(model_id=model_id, model_name=model_name)
    rich.print_json(data=result)
|
||||
|
||||
|
||||
@models.command("info")
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.option(
    "--columns",
    "columns",
    default="public_model,upstream_model,updated_at",
    help="Comma-separated list of columns to display. Valid columns: public_model, upstream_model, credential_name, created_at, updated_at, id, input_cost, output_cost. Default: public_model,upstream_model,updated_at",
)
@click.pass_context
def get_models_info(
    ctx: click.Context, output_format: Literal["table", "json"], columns: str
) -> None:
    """Get detailed information about all models"""
    client = create_client(ctx)
    models_info = client.models.info()
    assert isinstance(models_info, list)

    if output_format == "json":
        rich.print_json(data=models_info)
        return

    # Table output. Small accessors keep the per-column getters readable.
    def _params(m):
        return m.get("litellm_params", {})

    def _minfo(m):
        return m.get("model_info", {})

    # Each known column: (header, style, justify, value getter).
    column_configs: dict[str, tuple] = {
        "public_model": (
            "Public Model", "cyan", "left",
            lambda m: str(m.get("model_name", "")),
        ),
        "upstream_model": (
            "Upstream Model", "green", "left",
            lambda m: str(_params(m).get("model", "")),
        ),
        "credential_name": (
            "Credential Name", "yellow", "left",
            lambda m: str(_params(m).get("litellm_credential_name", "")),
        ),
        "created_at": (
            "Created At", "magenta", "left",
            lambda m: format_iso_datetime_str(_minfo(m).get("created_at")),
        ),
        "updated_at": (
            "Updated At", "magenta", "left",
            lambda m: format_iso_datetime_str(_minfo(m).get("updated_at")),
        ),
        "id": (
            "ID", "blue", "left",
            lambda m: str(_minfo(m).get("id", "")),
        ),
        "input_cost": (
            "Input Cost", "green", "right",
            lambda m: format_cost_per_1k_tokens(_minfo(m).get("input_cost_per_token")),
        ),
        "output_cost": (
            "Output Cost", "green", "right",
            lambda m: format_cost_per_1k_tokens(_minfo(m).get("output_cost_per_token")),
        ),
    }

    table = rich.table.Table(title="Models Information")
    requested_columns = [col.strip() for col in columns.split(",")]
    for col_name in requested_columns:
        if col_name not in column_configs:
            # Unknown names are warned about and then silently skipped in rows.
            click.echo(f"Warning: Unknown column '{col_name}'", err=True)
            continue
        header, style, justify, _getter = column_configs[col_name]
        table.add_column(header, style=style, justify=justify)

    for model in models_info:
        row_values = [
            column_configs[c][3](model)
            for c in requested_columns
            if c in column_configs
        ]
        if row_values:
            table.add_row(*row_values)

    rich.print(table)
|
||||
|
||||
|
||||
@models.command("update")
@click.argument("model-id")
@click.option(
    "--param",
    "-p",
    multiple=True,
    help="Model parameters in key=value format (can be specified multiple times)",
)
@click.option(
    "--info",
    "-i",
    multiple=True,
    help="Model info in key=value format (can be specified multiple times)",
)
@click.pass_context
def update_model(
    ctx: click.Context, model_id: str, param: tuple[str, ...], info: tuple[str, ...]
) -> None:
    """Update an existing model's configuration"""
    # Validate the key=value entries up front: previously "-p foo" (no "=")
    # crashed with an opaque ValueError from dict unpacking.
    for entry in (*param, *info):
        if "=" not in entry:
            raise click.UsageError(
                f"Invalid key=value pair: '{entry}' (expected key=value)"
            )
    # Convert parameters from key=value format to dict
    model_params = dict(p.split("=", 1) for p in param)
    model_info = dict(i.split("=", 1) for i in info) if info else None

    client = create_client(ctx)
    result = client.models.update(
        model_id=model_id,
        model_params=model_params,
        model_info=model_info,
    )
    rich.print_json(data=result)
|
||||
|
||||
|
||||
def _filter_model(model, model_regex, access_group_regex):
|
||||
model_name = model.get("model_name")
|
||||
model_params = model.get("litellm_params")
|
||||
model_info = model.get("model_info", {})
|
||||
if not model_name or not model_params:
|
||||
return False
|
||||
model_id = model_params.get("model")
|
||||
if not model_id or not isinstance(model_id, str):
|
||||
return False
|
||||
if model_regex and not model_regex.search(model_id):
|
||||
return False
|
||||
access_groups = model_info.get("access_groups", [])
|
||||
if access_group_regex:
|
||||
if not isinstance(access_groups, list):
|
||||
return False
|
||||
if not any(
|
||||
isinstance(group, str) and access_group_regex.search(group)
|
||||
for group in access_groups
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _print_models_table(added_models: list[ModelYamlInfo], table_title: str):
    """Render the imported models as a rich table; no-op for an empty list."""
    if not added_models:
        return
    table = rich.table.Table(title=table_title)
    for header, style in (
        ("Model Name", "cyan"),
        ("Upstream Model", "green"),
        ("Access Groups", "magenta"),
    ):
        table.add_column(header, style=style)
    for entry in added_models:
        table.add_row(entry.model_name, entry.model_id, entry.access_groups_str)
    rich.print(table)
|
||||
|
||||
|
||||
def _print_summary_table(provider_counts):
    """Print a per-provider import count table with a bold grand-total row."""
    summary = rich.table.Table(title="Model Import Summary")
    summary.add_column("Provider", style="cyan")
    summary.add_column("Count", style="green")

    total = 0
    for provider, count in provider_counts.items():
        summary.add_row(str(provider), str(count))
        total += count

    summary.add_row("[bold]Total[/bold]", f"[bold]{total}[/bold]")

    rich.print(summary)
|
||||
|
||||
|
||||
def get_model_list_from_yaml_file(yaml_file: str) -> list[dict[str, Any]]:
    """Load and validate the model list from a YAML file."""
    with open(yaml_file, "r") as f:
        data = yaml.safe_load(f)

    # The proxy config schema requires a top-level model_list sequence.
    if not data or "model_list" not in data:
        raise click.ClickException(
            "YAML file must contain a 'model_list' key with a list of models."
        )
    model_list = data["model_list"]
    if not isinstance(model_list, list):
        raise click.ClickException("'model_list' must be a list of model definitions.")
    return model_list
|
||||
|
||||
|
||||
def _get_filtered_model_list(
    model_list, only_models_matching_regex, only_access_groups_matching_regex
):
    """Return a list of models that pass the filter criteria."""
    # Compile each optional pattern once; None disables that filter.
    model_regex = None
    if only_models_matching_regex:
        model_regex = re.compile(only_models_matching_regex)

    access_group_regex = None
    if only_access_groups_matching_regex:
        access_group_regex = re.compile(only_access_groups_matching_regex)

    return [
        m for m in model_list if _filter_model(m, model_regex, access_group_regex)
    ]
|
||||
|
||||
|
||||
def _import_models_get_table_title(dry_run: bool) -> str:
|
||||
if dry_run:
|
||||
return "Models that would be imported if [yellow]--dry-run[/yellow] was not provided"
|
||||
else:
|
||||
return "Models Imported"
|
||||
|
||||
|
||||
@models.command("import")
@click.argument(
    "yaml_file", type=click.Path(exists=True, dir_okay=False, readable=True)
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Show what would be imported without making any changes.",
)
@click.option(
    "--only-models-matching-regex",
    default=None,
    help="Only import models where litellm_params.model matches the given regex.",
)
@click.option(
    "--only-access-groups-matching-regex",
    default=None,
    help="Only import models where at least one item in model_info.access_groups matches the given regex.",
)
@click.pass_context
def import_models(
    ctx: click.Context,
    yaml_file: str,
    dry_run: bool,
    only_models_matching_regex: Optional[str],
    only_access_groups_matching_regex: Optional[str],
) -> None:
    """Import models from a YAML file and add them to the proxy."""
    # Per-provider tally used for the summary table at the end.
    provider_counts: dict[str, int] = defaultdict(int)
    added_models: list[ModelYamlInfo] = []
    model_list = get_model_list_from_yaml_file(yaml_file)
    filtered_model_list = _get_filtered_model_list(
        model_list, only_models_matching_regex, only_access_groups_matching_regex
    )

    # Only talk to the proxy when actually importing.
    if not dry_run:
        client = create_client(ctx)

    for model in filtered_model_list:
        model_info_obj = _get_model_info_obj_from_yaml(model)
        if not dry_run:
            try:
                client.models.new(
                    model_name=model_info_obj.model_name,
                    model_params=model_info_obj.model_params,
                    model_info=model_info_obj.model_info,
                )
            except Exception:
                # Deliberate best-effort: a failed add still shows up in the
                # tables below as if imported.
                pass  # For summary, ignore errors
        added_models.append(model_info_obj)
        provider_counts[model_info_obj.provider] += 1

    table_title = _import_models_get_table_title(dry_run)
    _print_models_table(added_models, table_title)
    _print_summary_table(provider_counts)
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Team management commands for LiteLLM CLI."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import click
|
||||
import requests
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from litellm.proxy.client import Client
|
||||
|
||||
|
||||
@click.group()
def teams():
    """Manage teams and team assignments"""
    # Container group only; subcommands (list/available/assign_key) attach below.
    pass
|
||||
|
||||
|
||||
def display_teams_table(teams: List[Dict[str, Any]]) -> None:
    """Display teams in a formatted table.

    Renders index, alias, id, model access, budget, and role columns for each
    team dict as returned by the team-list APIs.
    """
    console = Console()

    if not teams:
        console.print("❌ No teams found for your user.")
        return

    table = Table(title="Available Teams")
    table.add_column("Index", style="cyan", no_wrap=True)
    table.add_column("Team Alias", style="magenta")
    table.add_column("Team ID", style="green")
    table.add_column("Models", style="yellow")
    table.add_column("Max Budget", style="blue")
    table.add_column("Role", style="red")

    for i, team in enumerate(teams):
        team_alias = team.get("team_alias") or "N/A"
        team_id = team.get("team_id", "N/A")
        models = team.get("models", [])
        max_budget = team.get("max_budget")

        # Format models list; an empty list means unrestricted model access.
        if models:
            if len(models) > 3:
                models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
            else:
                models_str = ", ".join(models)
        else:
            models_str = "All models"

        # Bug fix: a budget of 0 previously rendered as "Unlimited" because of
        # a truthiness check; only a missing budget is unlimited.
        budget_str = f"${max_budget}" if max_budget is not None else "Unlimited"

        # Role detection from members_with_roles is not implemented yet
        # (the original branch was a no-op); every row shows the default.
        role = "Member"

        table.add_row(str(i + 1), team_alias, team_id, models_str, budget_str, role)

    console.print(table)
|
||||
|
||||
|
||||
@teams.command()
@click.pass_context
def list(ctx: click.Context):
    """List teams that you belong to"""
    client = Client(ctx.obj["base_url"], ctx.obj["api_key"])

    try:
        # Use list() for simpler response structure (returns array directly)
        teams = client.teams.list()
        display_teams_table(teams)
    except requests.exceptions.HTTPError as e:
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        # Bug fix: the error body is not guaranteed to be JSON; fall back to
        # the raw text instead of raising a secondary decode error.
        # (requests raises a ValueError subclass on invalid JSON.)
        try:
            error_body = e.response.json()
            click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
        except ValueError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()
|
||||
|
||||
|
||||
@teams.command()
@click.pass_context
def available(ctx: click.Context):
    """List teams that are available to join"""
    client = Client(ctx.obj["base_url"], ctx.obj["api_key"])

    try:
        teams = client.teams.get_available()
        if teams:
            console = Console()
            console.print("\n🎯 Available Teams to Join:")
            display_teams_table(teams)
        else:
            click.echo("ℹ️ No available teams to join.")
    except requests.exceptions.HTTPError as e:
        # NOTE(review): unlike `list`, this branch does not abort on HTTP
        # errors — confirm before changing the exit code.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        # Bug fix: non-JSON error bodies previously raised a secondary decode
        # error; fall back to the raw text.
        try:
            error_body = e.response.json()
            click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
        except ValueError:
            click.echo(e.response.text, err=True)
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()
|
||||
|
||||
|
||||
@teams.command()
@click.option("--team-id", type=str, help="Team ID to assign the key to")
@click.pass_context
def assign_key(ctx: click.Context, team_id: Optional[str]):
    """Assign your current CLI key to a team"""
    client = Client(ctx.obj["base_url"], ctx.obj["api_key"])
    api_key = ctx.obj["api_key"]

    # The CLI's own key is what gets reassigned, so one must exist.
    if not api_key:
        click.echo("❌ No API key found. Please login first using 'litellm login'")
        raise click.Abort()

    try:
        # If no team_id provided, show teams and let user select
        if not team_id:
            teams = client.teams.list()

            if not teams:
                click.echo("❌ No teams found for your user.")
                return

            # Use interactive selection from auth module
            # (lazy import — presumably to avoid an import cycle; confirm)
            from .auth import prompt_team_selection

            selected_team = prompt_team_selection(teams)

            if selected_team:
                team_id = selected_team.get("team_id")
            else:
                click.echo("❌ Operation cancelled.")
                return

        # Update the key with the selected team
        if team_id:
            click.echo(f"\n🔄 Assigning your key to team: {team_id}")
            client.keys.update(key=api_key, team_id=team_id)
            click.echo(f"✅ Successfully assigned key to team: {team_id}")

            # Show team details if available
            # (re-fetched so the confirmation reflects server state)
            teams = client.teams.list()
            for team in teams:
                if team.get("team_id") == team_id:
                    models = team.get("models", [])
                    if models:
                        click.echo(f"🎯 You can now access models: {', '.join(models)}")
                    else:
                        click.echo("🎯 You can now access all available models")
                    break

    except requests.exceptions.HTTPError as e:
        # NOTE(review): e.response.json() raises if the body is not JSON;
        # sibling commands in the keys CLI guard this — confirm and align.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        error_body = e.response.json()
        click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
        raise click.Abort()
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()
|
||||
@@ -0,0 +1,91 @@
|
||||
import click
|
||||
import rich
|
||||
from ... import UsersManagementClient
|
||||
|
||||
|
||||
@click.group()
def users():
    """Manage users on your LiteLLM proxy server"""
    # Container group only; subcommands are registered via @users.command below.
    pass
|
||||
|
||||
|
||||
@users.command("list")
@click.pass_context
def list_users(ctx: click.Context):
    """List all users"""
    client = UsersManagementClient(
        base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
    )
    user_records = client.list_users()
    # Some server versions wrap the array in a {"users": [...]} envelope.
    if isinstance(user_records, dict) and "users" in user_records:
        user_records = user_records["users"]
    if not user_records:
        click.echo("No users found.")
        return

    from rich.console import Console
    from rich.table import Table

    table = Table(title="Users")
    for header, style in (
        ("User ID", "cyan"),
        ("Email", "green"),
        ("Role", "magenta"),
        ("Teams", "yellow"),
    ):
        table.add_column(header, style=style)

    for record in user_records:
        table.add_row(
            str(record.get("user_id", "")),
            str(record.get("user_email", "")),
            str(record.get("user_role", "")),
            ", ".join(record.get("teams", []) or []),
        )
    Console().print(table)
|
||||
|
||||
|
||||
@users.command("get")
@click.option("--id", "user_id", help="ID of the user to retrieve")
@click.pass_context
def get_user(ctx: click.Context, user_id: str):
    """Get information about a specific user"""
    # Fetch the record and pretty-print the server's JSON response verbatim.
    client = UsersManagementClient(
        base_url=ctx.obj["base_url"],
        api_key=ctx.obj["api_key"],
    )
    rich.print_json(data=client.get_user(user_id=user_id))
|
||||
|
||||
|
||||
@users.command("create")
@click.option("--email", required=True, help="User email")
@click.option("--role", default="internal_user", help="User role")
@click.option("--alias", default=None, help="User alias")
@click.option("--team", multiple=True, help="Team IDs (can specify multiple)")
@click.option("--max-budget", type=float, default=None, help="Max budget for user")
@click.pass_context
def create_user(ctx: click.Context, email, role, alias, team, max_budget):
    """Create a new user"""
    client = UsersManagementClient(
        base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
    )
    # Required fields first; optional ones only when explicitly provided.
    user_data = {"user_email": email, "user_role": role}
    if alias:
        user_data["user_alias"] = alias
    if team:
        user_data["teams"] = list(team)
    if max_budget is not None:
        user_data["max_budget"] = max_budget
    rich.print_json(data=client.create_user(user_data))
|
||||
|
||||
|
||||
@users.command("delete")
|
||||
@click.argument("user_ids", nargs=-1)
|
||||
@click.pass_context
|
||||
def delete_user(ctx: click.Context, user_ids):
|
||||
"""Delete one or more users by user_id"""
|
||||
client = UsersManagementClient(
|
||||
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
|
||||
)
|
||||
result = client.delete_user(list(user_ids))
|
||||
rich.print_json(data=result)
|
||||
@@ -0,0 +1,207 @@
|
||||
# stdlib imports
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# third party imports
|
||||
import click
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def styled_prompt():
    """Render a blue box and read one line of user input from inside it.

    Draws a fixed-width (80-column) box with Unicode box-drawing characters,
    then repositions the cursor using raw ANSI escape sequences so the prompt
    appears inside the box.

    Returns:
        str: The stripped line the user typed.

    Raises:
        KeyboardInterrupt, EOFError: Re-raised after cursor cleanup so the
            caller can end the session.
    """

    # Get terminal height to ensure we have enough space
    try:
        terminal_height = os.get_terminal_size().lines
        # Ensure we have at least 5 lines of space (for the box + some buffer)
        if terminal_height < 10:
            # If terminal is too small, just add some newlines to push content up
            click.echo("\n" * 3)
    except Exception as e:
        # Fallback if we can't get terminal size (e.g. output is not a TTY)
        verbose_logger.debug(f"Error getting terminal size: {e}")
        click.echo("\n" * 3)

    # Unicode box drawing characters
    top_left = "┌"
    top_right = "┐"
    bottom_left = "└"
    bottom_right = "┘"
    horizontal = "─"
    vertical = "│"

    # Create the box with increased width
    width = 80
    top_line = top_left + horizontal * (width - 2) + top_right
    bottom_line = bottom_left + horizontal * (width - 2) + bottom_right

    # Create styled elements
    left_border = click.style(vertical, fg="blue", bold=True)
    right_border = click.style(vertical, fg="blue", bold=True)
    prompt_text = click.style("> ", fg="cyan", bold=True)

    # Display the complete box structure first to reserve space
    click.echo(click.style(top_line, fg="blue", bold=True))

    # Create empty space in the box for input
    empty_space = " " * (width - 4)
    click.echo(f"{left_border} {empty_space} {right_border}")

    # Display bottom border to complete the box
    click.echo(click.style(bottom_line, fg="blue", bold=True))

    # Now move cursor up to the input line and get input.
    # NOTE(review): raw ANSI escapes assume a VT100-compatible terminal; on
    # Windows this relies on click/colorama enabling VT processing — confirm.
    click.echo("\033[2A", nl=False)  # Move cursor up 2 lines
    click.echo(
        f"\r{left_border} {prompt_text}", nl=False
    )  # Position at start of input line

    try:
        # Get user input
        user_input = input().strip()

        # Move cursor down to after the box
        click.echo("\033[1B")  # Move cursor down 1 line
        click.echo("")  # Add some space after

    except (KeyboardInterrupt, EOFError):
        # Move cursor down and add space before propagating the interrupt
        click.echo("\033[1B")
        click.echo("")
        raise

    return user_input
|
||||
|
||||
|
||||
def show_commands():
    """Display available commands."""
    command_summaries = (
        ("login", "Authenticate with the LiteLLM proxy server"),
        ("logout", "Clear stored authentication"),
        ("whoami", "Show current authentication status"),
        ("models", "Manage and view model configurations"),
        ("credentials", "Manage API credentials"),
        ("chat", "Interactive streaming chat with models"),
        ("http", "Make HTTP requests to the proxy"),
        ("keys", "Manage API keys"),
        ("teams", "Manage teams and team assignments"),
        ("users", "Manage users"),
        ("version", "Show version information"),
        ("help", "Show this help message"),
        ("quit", "Exit the interactive session"),
    )

    click.echo("Available commands:")
    for name, summary in command_summaries:
        # Pad the command name so descriptions line up in a column.
        click.echo(f"  {name:<20} {summary}")
    click.echo()
|
||||
|
||||
|
||||
def setup_shell(ctx: click.Context):
    """Set up the interactive shell with banner and initial info."""
    # Imported lazily to keep module import cheap and avoid cycles.
    from litellm.proxy.common_utils.banner import show_banner

    show_banner()

    # Tell the user which proxy the shell is pointed at.
    base_url = ctx.obj.get("base_url")
    click.secho(f"Connected to LiteLLM server: {base_url}\n", fg="green")

    show_commands()
|
||||
|
||||
|
||||
def handle_special_commands(user_input: str) -> bool:
    """Handle special commands like exit, help, clear. Returns True if command was handled."""
    normalized = user_input.lower()

    if normalized in ("exit", "quit"):
        click.echo("Goodbye!")
        return True

    if normalized == "help":
        click.echo("")  # blank line before the help listing
        show_commands()
        return True

    if normalized == "clear":
        click.clear()
        from litellm.proxy.common_utils.banner import show_banner

        # Re-draw the banner and command list on a fresh screen.
        show_banner()
        show_commands()
        return True

    return False
|
||||
|
||||
|
||||
def execute_command(user_input: str, ctx: click.Context):
    """Parse and execute one command line inside the interactive shell.

    Args:
        user_input: Raw line typed by the user, e.g. "models list --json".
        ctx: Parent click context carrying base_url/api_key in ``ctx.obj``.
    """
    # Parse command and arguments
    parts = user_input.split()
    command = parts[0]
    args = parts[1:] if len(parts) > 1 else []

    # Import cli here to avoid circular import
    from . import main

    cli = main.cli

    # Check if command exists
    if command not in cli.commands:
        click.echo(f"Unknown command: {command}")
        click.echo("Type 'help' to see available commands.")
        return

    # BUGFIX: sys.argv was previously overwritten and never restored, leaking
    # state into the rest of the interactive session. Save it and restore in
    # a finally block so every exit path cleans up.
    saved_argv = sys.argv
    try:
        # Some subcommands may inspect sys.argv; present a plausible value.
        sys.argv = ["litellm-proxy"] + [command] + args

        # Get the command object and invoke it
        cmd = cli.commands[command]

        # Create a new context for the subcommand
        with ctx.scope():
            # standalone_mode=False keeps click from calling sys.exit() itself.
            cmd.main(args, parent=ctx, standalone_mode=False)

    except click.ClickException as e:
        e.show()
    except click.Abort:
        click.echo("Command aborted.")
    except SystemExit:
        # Prevent the interactive shell from exiting on command errors
        pass
    except Exception as e:
        click.echo(f"Error executing command: {e}")
    finally:
        sys.argv = saved_argv
|
||||
|
||||
|
||||
def interactive_shell(ctx: click.Context):
    """Run the interactive shell."""
    setup_shell(ctx)

    while True:
        try:
            # Extra spacing so the input box renders cleanly.
            click.echo("\n")

            user_input = styled_prompt()
            if not user_input:
                continue

            # Special commands (help/clear/exit/quit) short-circuit here.
            if handle_special_commands(user_input):
                if user_input.lower() in ["exit", "quit"]:
                    break
                continue

            execute_command(user_input, ctx)

        except (KeyboardInterrupt, EOFError):
            click.echo("\nGoodbye!")
            break
        except Exception as e:
            click.echo(f"Error: {e}")
|
||||
@@ -0,0 +1,115 @@
|
||||
# stdlib imports
|
||||
from typing import Optional
|
||||
|
||||
# third party imports
|
||||
import click
|
||||
|
||||
from litellm._version import version as litellm_version
|
||||
from litellm.proxy.client.health import HealthManagementClient
|
||||
|
||||
from .commands.auth import get_stored_api_key, login, logout, whoami
|
||||
from .commands.chat import chat
|
||||
from .commands.credentials import credentials
|
||||
from .commands.http import http
|
||||
from .commands.keys import keys
|
||||
|
||||
# local imports
|
||||
from .commands.models import models
|
||||
from .commands.teams import teams
|
||||
from .commands.users import users
|
||||
from .interface import interactive_shell
|
||||
|
||||
|
||||
def print_version(base_url: str, api_key: Optional[str]):
    """Print CLI and server version info."""
    click.echo(f"LiteLLM Proxy CLI Version: {litellm_version}")
    if not base_url:
        return

    click.echo(f"LiteLLM Proxy Server URL: {base_url}")
    try:
        # Ask the proxy's readiness endpoint for its version.
        health_client = HealthManagementClient(base_url=base_url, api_key=api_key)
        server_version = health_client.get_server_version()
        if server_version:
            click.echo(f"LiteLLM Proxy Server Version: {server_version}")
        else:
            click.echo("LiteLLM Proxy Server Version: (unavailable)")
    except Exception as e:
        # A dead or unreachable server shouldn't crash the version command.
        click.echo(f"Could not retrieve server version: {e}")
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
@click.option(
    "--version",
    "-v",
    is_flag=True,
    is_eager=True,
    expose_value=False,
    help="Show the LiteLLM Proxy CLI and server version and exit.",
    # Eager callback: print version info and stop before the group body runs.
    # NOTE(review): at eager-parse time ctx.params may not yet contain
    # --base-url/--api-key (they are declared after --version), so the
    # fallbacks below likely always apply — confirm against click's
    # parameter-processing order.
    callback=lambda ctx, param, value: (
        print_version(
            ctx.params.get("base_url") or "http://localhost:4000",
            ctx.params.get("api_key"),
        )
        or ctx.exit()
    )
    if value and not ctx.resilient_parsing
    else None,
)
@click.option(
    "--base-url",
    envvar="LITELLM_PROXY_URL",
    show_envvar=True,
    default="http://localhost:4000",
    help="Base URL of the LiteLLM proxy server",
)
@click.option(
    "--api-key",
    envvar="LITELLM_PROXY_API_KEY",
    show_envvar=True,
    help="API key for authentication",
)
@click.pass_context
def cli(ctx: click.Context, base_url: str, api_key: Optional[str]) -> None:
    """LiteLLM Proxy CLI - Manage your LiteLLM proxy server"""
    ctx.ensure_object(dict)

    # If no API key provided via flag or environment variable, try to load from saved token
    if api_key is None:
        api_key = get_stored_api_key()

    # Subcommands read the connection settings from ctx.obj.
    ctx.obj["base_url"] = base_url
    ctx.obj["api_key"] = api_key

    # If no subcommand was invoked, start interactive mode
    if ctx.invoked_subcommand is None:
        interactive_shell(ctx)
|
||||
|
||||
|
||||
@cli.command()
@click.pass_context
def version(ctx: click.Context):
    """Show the LiteLLM Proxy CLI and server version."""
    # Reuse the shared printer with the connection settings from the group.
    base_url = ctx.obj.get("base_url")
    api_key = ctx.obj.get("api_key")
    print_version(base_url, api_key)
|
||||
|
||||
|
||||
# Register subcommands on the top-level CLI group.
# Add authentication commands as top-level commands
cli.add_command(login)
cli.add_command(logout)
cli.add_command(whoami)
# Add the models command group
cli.add_command(models)
# Add the credentials command group
cli.add_command(credentials)
# Add the chat command group
cli.add_command(chat)
# Add the http command group
cli.add_command(http)
# Add the keys command group
cli.add_command(keys)
# Add the teams command group
cli.add_command(teams)
# Add the users command group
cli.add_command(users)


if __name__ == "__main__":
    cli()
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.litellm_core_utils.cli_token_utils import get_litellm_gateway_api_key
|
||||
|
||||
from .chat import ChatClient
|
||||
from .credentials import CredentialsManagementClient
|
||||
from .http_client import HTTPClient
|
||||
from .keys import KeysManagementClient
|
||||
from .model_groups import ModelGroupsManagementClient
|
||||
from .models import ModelsManagementClient
|
||||
from .teams import TeamsManagementClient
|
||||
|
||||
|
||||
class Client:
    """Main client for interacting with the LiteLLM proxy API."""

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        timeout: int = 30,
    ):
        """
        Initialize the LiteLLM proxy client.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:4000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
            timeout: Request timeout in seconds (default: 30)
        """
        self._base_url = base_url.rstrip("/")  # Remove trailing slash if present
        # A locally stored gateway key (if any) takes precedence over the
        # explicitly supplied api_key.
        self._api_key = get_litellm_gateway_api_key() or api_key

        # Initialize resource clients.
        # BUGFIX: HTTPClient previously received the raw, un-normalized
        # base_url and the raw api_key, so it ignored the stored gateway key
        # and kept any trailing slash — inconsistent with every other
        # resource client below. Use the normalized values everywhere.
        self.http = HTTPClient(
            base_url=self._base_url, api_key=self._api_key, timeout=timeout
        )
        self.models = ModelsManagementClient(
            base_url=self._base_url, api_key=self._api_key
        )
        self.model_groups = ModelGroupsManagementClient(
            base_url=self._base_url, api_key=self._api_key
        )
        self.chat = ChatClient(base_url=self._base_url, api_key=self._api_key)
        self.keys = KeysManagementClient(base_url=self._base_url, api_key=self._api_key)
        self.credentials = CredentialsManagementClient(
            base_url=self._base_url, api_key=self._api_key
        )
        self.teams = TeamsManagementClient(
            base_url=self._base_url, api_key=self._api_key
        )
|
||||
@@ -0,0 +1,185 @@
|
||||
import requests
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
from .exceptions import UnauthorizedError
|
||||
|
||||
|
||||
class CredentialsManagementClient:
    """Client for the /credentials management endpoints of the LiteLLM proxy.

    Every endpoint method can alternatively return the prepared
    ``requests.Request`` (``return_request=True``) so callers can inspect or
    customize the request before sending it themselves.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the CredentialsManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # Remove trailing slash if present
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """
        Get the headers for API requests, including authorization if api_key is set.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    def _send(self, request: requests.Request) -> Dict[str, Any]:
        """Send a prepared request and return the parsed JSON body.

        REFACTOR: list/create/delete/get previously each duplicated this
        send + raise_for_status + 401-translation logic; it now lives in one
        place.

        Args:
            request (requests.Request): The request to prepare and send.

        Returns:
            Dict[str, Any]: The JSON-decoded response body.

        Raises:
            UnauthorizedError: If the server responds with HTTP 401.
            requests.exceptions.RequestException: For any other HTTP failure.
        """
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise

    def list(
        self,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        List all credentials.

        Args:
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/credentials"

        request = requests.Request("GET", url, headers=self._get_headers())

        if return_request:
            return request
        return self._send(request)

    def create(
        self,
        credential_name: str,
        credential_info: Dict[str, Any],
        credential_values: Dict[str, Any],
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Create a new credential.

        Args:
            credential_name (str): Name of the credential
            credential_info (Dict[str, Any]): Additional information about the credential
            credential_values (Dict[str, Any]): Values for the credential
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/credentials"

        data = {
            "credential_name": credential_name,
            "credential_info": credential_info,
            "credential_values": credential_values,
        }

        request = requests.Request("POST", url, headers=self._get_headers(), json=data)

        if return_request:
            return request
        return self._send(request)

    def delete(
        self,
        credential_name: str,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Delete a credential by name.

        Args:
            credential_name (str): Name of the credential to delete
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/credentials/{credential_name}"

        request = requests.Request("DELETE", url, headers=self._get_headers())

        if return_request:
            return request
        return self._send(request)

    def get(
        self,
        credential_name: str,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Get a credential by name.

        Args:
            credential_name (str): Name of the credential to retrieve
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        # NOTE: lookup-by-name uses a distinct /by_name/ path on the server.
        url = f"{self._base_url}/credentials/by_name/{credential_name}"

        request = requests.Request("GET", url, headers=self._get_headers())

        if return_request:
            return request
        return self._send(request)
|
||||
@@ -0,0 +1,19 @@
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class UnauthorizedError(Exception):
    """Exception raised when the API returns a 401 Unauthorized response."""

    def __init__(self, orig_exception: Union[requests.exceptions.HTTPError, str]):
        # Keep the underlying HTTPError (or message) so callers can inspect
        # the original response.
        self.orig_exception = orig_exception
        super().__init__(str(orig_exception))
|
||||
|
||||
|
||||
class NotFoundError(Exception):
    """Exception raised when the API returns a 404 Not Found response or indicates a resource was not found."""

    def __init__(self, orig_exception: Union[requests.exceptions.HTTPError, str]):
        # Keep the underlying HTTPError (or message) so callers can inspect
        # the original response.
        self.orig_exception = orig_exception
        super().__init__(str(orig_exception))
|
||||
@@ -0,0 +1,42 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from .http_client import HTTPClient
|
||||
|
||||
|
||||
class HealthManagementClient:
    """
    Client for interacting with the health endpoints of the LiteLLM proxy server.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None, timeout: int = 30):
        """
        Initialize the HealthManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:4000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
            timeout (int): Request timeout in seconds (default: 30)
        """
        # All health calls are delegated to the generic HTTP client.
        self._http = HTTPClient(base_url=base_url, api_key=api_key, timeout=timeout)

    def get_readiness(self) -> Dict[str, Any]:
        """
        Check the readiness of the LiteLLM proxy server.

        Returns:
            Dict[str, Any]: The readiness status and details from the server.

        Raises:
            requests.exceptions.RequestException: If the request fails
            ValueError: If the response is not valid JSON
        """
        return self._http.request("GET", "/health/readiness")

    def get_server_version(self) -> Optional[str]:
        """
        Get the LiteLLM server version from the readiness endpoint.

        Returns:
            Optional[str]: The server version if available, otherwise None.
        """
        readiness = self.get_readiness()
        # The readiness payload carries the version under "litellm_version".
        return readiness.get("litellm_version")
|
||||
@@ -0,0 +1,95 @@
|
||||
"""HTTP client for making requests to the LiteLLM proxy server."""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import requests
|
||||
|
||||
|
||||
class HTTPClient:
    """HTTP client for making requests to the LiteLLM proxy server."""

    def __init__(self, base_url: str, api_key: Optional[str] = None, timeout: int = 30):
        """Initialize the HTTP client.

        Args:
            base_url: Base URL of the LiteLLM proxy server
            api_key: Optional API key for authentication
            timeout: Request timeout in seconds (default: 30)
        """
        # Normalize: a trailing slash would otherwise produce "//" in URLs.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

    def request(
        self,
        method: str,
        uri: str,
        *,
        data: Optional[Union[Dict[str, Any], list, bytes]] = None,
        json: Optional[Union[Dict[str, Any], list]] = None,
        headers: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Make an HTTP request to the LiteLLM proxy server.

        Generic escape hatch for endpoints that have no dedicated client
        or method.

        Args:
            method: HTTP method (GET, POST, PUT, DELETE, etc.)
            uri: URI path appended to base_url (e.g., "/credentials")
            data: (optional) Dictionary, list of tuples, bytes, or file-like
                object to send in the body of the request.
            json: (optional) A JSON serializable Python object to send in the
                body of the request.
            headers: (optional) Dictionary of HTTP headers to send with the request.
            **kwargs: Additional keyword arguments forwarded to requests.

        Returns:
            Parsed JSON response from the server

        Raises:
            requests.exceptions.RequestException: If the request fails
            ValueError: If the response is not valid JSON
        """
        # Join base URL and path without doubling the separator.
        url = f"{self._base_url}/{uri.lstrip('/')}"

        # Caller headers first; the Authorization header always wins.
        request_headers: Dict[str, str] = dict(headers) if headers else {}
        if self._api_key:
            request_headers["Authorization"] = f"Bearer {self._api_key}"

        response = requests.request(
            method=method,
            url=url,
            data=data,
            json=json,
            headers=request_headers,
            timeout=self._timeout,
            **kwargs,
        )

        # Turn 4xx/5xx into exceptions, then decode the JSON body.
        response.raise_for_status()
        return response.json()
|
||||
@@ -0,0 +1,319 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from .exceptions import UnauthorizedError
|
||||
|
||||
|
||||
class KeysManagementClient:
|
||||
    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the KeysManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # Remove trailing slash if present
        self._api_key = api_key
|
||||
|
||||
    def _get_headers(self) -> Dict[str, str]:
        """
        Get the headers for API requests, including authorization if api_key is set.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            # The key is sent as a standard Bearer token.
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers
|
||||
|
||||
    def list(
        self,
        page: Optional[int] = None,
        size: Optional[int] = None,
        user_id: Optional[str] = None,
        team_id: Optional[str] = None,
        organization_id: Optional[str] = None,
        key_hash: Optional[str] = None,
        key_alias: Optional[str] = None,
        return_full_object: Optional[bool] = None,
        include_team_keys: Optional[bool] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        List all API keys with optional filtering and pagination.

        Args:
            page (Optional[int]): Page number for pagination
            size (Optional[int]): Number of items per page
            user_id (Optional[str]): Filter keys by user ID
            team_id (Optional[str]): Filter keys by team ID
            organization_id (Optional[str]): Filter keys by organization ID
            key_hash (Optional[str]): Filter by specific key hash
            key_alias (Optional[str]): Filter by key alias
            return_full_object (Optional[bool]): Whether to return the full key object
            include_team_keys (Optional[bool]): Whether to include team keys in the response
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True. The response contains a list
            of API keys with their configurations.

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/key/list"
        params: Dict[str, Any] = {}

        # Add optional query parameters — only those explicitly provided.
        if page is not None:
            params["page"] = page
        if size is not None:
            params["size"] = size
        if user_id is not None:
            params["user_id"] = user_id
        if team_id is not None:
            params["team_id"] = team_id
        if organization_id is not None:
            params["organization_id"] = organization_id
        if key_hash is not None:
            params["key_hash"] = key_hash
        if key_alias is not None:
            params["key_alias"] = key_alias
        if return_full_object is not None:
            # Booleans go on the query string as lowercase "true"/"false".
            params["return_full_object"] = str(return_full_object).lower()
        if include_team_keys is not None:
            params["include_team_keys"] = str(include_team_keys).lower()

        request = requests.Request(
            "GET", url, headers=self._get_headers(), params=params
        )

        # Let callers inspect/customize the request without sending it.
        if return_request:
            return request

        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            # Surface auth failures as a typed exception; re-raise the rest.
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise
|
||||
|
||||
    def generate(
        self,
        models: Optional[List[str]] = None,
        aliases: Optional[Dict[str, str]] = None,
        spend: Optional[float] = None,
        duration: Optional[str] = None,
        key_alias: Optional[str] = None,
        team_id: Optional[str] = None,
        user_id: Optional[str] = None,
        budget_id: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Generate an API key based on the provided data.

        Docs: https://docs.litellm.ai/docs/proxy/virtual_keys

        Args:
            models (Optional[List[str]]): List of allowed models for this key
            aliases (Optional[Dict[str, str]]): Model alias mappings
            spend (Optional[float]): Maximum spend limit for this key
            duration (Optional[str]): Duration for which the key is valid (e.g. "24h", "7d")
            key_alias (Optional[str]): Alias/name for the key for easier identification
            team_id (Optional[str]): Team ID to associate the key with
            user_id (Optional[str]): User ID to associate the key with
            budget_id (Optional[str]): Budget ID to associate the key with
            config (Optional[Dict[str, Any]]): Additional configuration parameters
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/key/generate"

        # Build the JSON body from only the fields the caller supplied, so the
        # server applies its own defaults for everything omitted.
        data: Dict[str, Any] = {}
        if models is not None:
            data["models"] = models
        if aliases is not None:
            data["aliases"] = aliases
        if spend is not None:
            data["spend"] = spend
        if duration is not None:
            data["duration"] = duration
        if key_alias is not None:
            data["key_alias"] = key_alias
        if team_id is not None:
            data["team_id"] = team_id
        if user_id is not None:
            data["user_id"] = user_id
        if budget_id is not None:
            data["budget_id"] = budget_id
        if config is not None:
            data["config"] = config

        request = requests.Request("POST", url, headers=self._get_headers(), json=data)

        # Let callers inspect/customize the request without sending it.
        if return_request:
            return request

        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            # Surface auth failures as a typed exception; re-raise the rest.
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise
|
||||
|
||||
def delete(
    self,
    keys: Optional[List[str]] = None,
    key_aliases: Optional[List[str]] = None,
    return_request: bool = False,
) -> Union[Dict[str, Any], requests.Request]:
    """
    Delete existing API keys by key value or by key alias.

    Args:
        keys (Optional[List[str]]): List of API keys to delete
        key_aliases (Optional[List[str]]): List of key aliases to delete
        return_request (bool): If True, returns the prepared request object instead of executing it

    Returns:
        Union[Dict[str, Any], requests.Request]: Either the response from the server or
            a prepared request object if return_request is True

    Raises:
        UnauthorizedError: If the request fails with a 401 status code
        requests.exceptions.RequestException: If the request fails with any other error
    """
    endpoint = f"{self._base_url}/key/delete"

    # Both fields are always sent, even when unset (serialized as null).
    payload = {
        "keys": keys,
        "key_aliases": key_aliases,
    }

    req = requests.Request("POST", endpoint, headers=self._get_headers(), json=payload)
    if return_request:
        return req

    http_session = requests.Session()
    try:
        resp = http_session.send(req.prepare())
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.HTTPError as http_err:
        if http_err.response.status_code == 401:
            raise UnauthorizedError(http_err)
        raise
|
||||
|
||||
def update(
    self,
    key: str,
    models: Optional[List[str]] = None,
    aliases: Optional[Dict[str, str]] = None,
    spend: Optional[float] = None,
    duration: Optional[str] = None,
    key_alias: Optional[str] = None,
    team_id: Optional[str] = None,
    user_id: Optional[str] = None,
) -> Union[Dict[str, Any], requests.Request]:
    """
    Update an existing API key's parameters.

    Only the fields explicitly provided (not None) are included in the
    request body; omitted fields are not sent.

    Args:
        key (str): The API key (or key hash) to update
        models (Optional[List[str]]): Model names the key is allowed to call
        aliases (Optional[Dict[str, str]]): Model alias mappings for the key
        spend (Optional[float]): Spend value to set for the key
        duration (Optional[str]): Duration for which the key is valid (e.g. "24h", "7d")
        key_alias (Optional[str]): Alias/name for the key for easier identification
        team_id (Optional[str]): Team ID to associate the key with
        user_id (Optional[str]): User ID to associate the key with

    Returns:
        Dict[str, Any]: The JSON response from the server

    Raises:
        Exception: If the request fails for any reason. The message includes
            the raw response body when one was received, and the original
            error is chained as ``__cause__``.
    """
    url = f"{self._base_url}/key/update"

    data: Dict[str, Any] = {"key": key}

    if key_alias is not None:
        data["key_alias"] = key_alias
    if user_id is not None:
        data["user_id"] = user_id
    if team_id is not None:
        data["team_id"] = team_id
    if models is not None:
        data["models"] = models
    if spend is not None:
        data["spend"] = spend
    if duration is not None:
        data["duration"] = duration
    if aliases is not None:
        data["aliases"] = aliases

    request = requests.Request("POST", url, headers=self._get_headers(), json=data)
    session = requests.Session()
    # Capture the raw body so the error message can include it even after
    # raise_for_status() fires; stays None if sending itself failed.
    response_text: Optional[str] = None
    try:
        response = session.send(request.prepare())
        response_text = response.text
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Chain the original exception so the underlying cause (connection
        # error, HTTP status, JSON decode failure, ...) is not lost.
        raise Exception(f"Error updating key: {response_text}") from e
|
||||
|
||||
def info(
    self, key: str, return_request: bool = False
) -> Union[Dict[str, Any], requests.Request]:
    """
    Get information about an API key.

    Args:
        key (str): The key hash to get information about
        return_request (bool): If True, returns the prepared request object instead of executing it

    Returns:
        Union[Dict[str, Any], requests.Request]: Either the response from the server or a prepared request object if return_request is True

    Raises:
        UnauthorizedError: If the request fails with a 401 status code
        requests.exceptions.RequestException: If the request fails with any other error
    """
    url = f"{self._base_url}/key/info"
    # Pass the key via `params` so requests percent-encodes it, instead of
    # interpolating it into the query string verbatim (a key containing
    # reserved URL characters would otherwise corrupt the request).
    request = requests.Request(
        "GET", url, headers=self._get_headers(), params={"key": key}
    )

    if return_request:
        return request

    session = requests.Session()
    try:
        response = session.send(request.prepare())
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 401:
            raise UnauthorizedError(e)
        raise
|
||||
@@ -0,0 +1,62 @@
|
||||
import requests
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from .exceptions import UnauthorizedError
|
||||
|
||||
|
||||
class ModelGroupsManagementClient:
    """Client for reading model-group information from a LiteLLM proxy."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the ModelGroupsManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        # Normalize the base URL so path f-strings below can append safely.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """
        Build the headers for API requests, adding a Bearer token when an
        API key is configured.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        if not self._api_key:
            return {}
        return {"Authorization": f"Bearer {self._api_key}"}

    def info(
        self, return_request: bool = False
    ) -> Union[List[Dict[str, Any]], requests.Request]:
        """
        Get detailed information about all model groups from the server.

        Args:
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[List[Dict[str, Any]], requests.Request]: Either a list of model group information dictionaries
                or a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        request = requests.Request(
            "GET",
            f"{self._base_url}/model_group/info",
            headers=self._get_headers(),
        )

        if return_request:
            return request

        # Execute the request and unwrap the "data" envelope.
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
        except requests.exceptions.HTTPError as http_err:
            if http_err.response.status_code == 401:
                raise UnauthorizedError(http_err)
            raise
        return response.json()["data"]
|
||||
@@ -0,0 +1,298 @@
|
||||
import requests
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from .exceptions import UnauthorizedError, NotFoundError
|
||||
|
||||
|
||||
class ModelsManagementClient:
    """Client for managing models on a LiteLLM proxy (/model/* and /models routes)."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the ModelsManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:8000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # Remove trailing slash if present
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """
        Get the headers for API requests, including authorization if api_key is set.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        headers = {}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    def list(
        self, return_request: bool = False
    ) -> Union[List[Dict[str, Any]], requests.Request]:
        """
        Get the list of models supported by the server.

        Args:
            return_request (bool): If True, returns the prepared request object instead of executing it.
                Useful for inspection or modification before sending.

        Returns:
            Union[List[Dict[str, Any]], requests.Request]: Either a list of model information dictionaries
                or a prepared request object if return_request is True.

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/models"
        request = requests.Request("GET", url, headers=self._get_headers())

        if return_request:
            return request

        # Prepare and send the request
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()["data"]
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise

    def new(
        self,
        model_name: str,
        model_params: Dict[str, Any],
        model_info: Optional[Dict[str, Any]] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Add a new model to the proxy.

        Args:
            model_name (str): Name of the model to add
            model_params (Dict[str, Any]): Parameters for the model (e.g., model type, api_base, api_key)
            model_info (Optional[Dict[str, Any]]): Additional information about the model
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/model/new"

        # The server expects model parameters under the "litellm_params" key.
        data = {
            "model_name": model_name,
            "litellm_params": model_params,
        }
        if model_info:
            data["model_info"] = model_info

        request = requests.Request("POST", url, headers=self._get_headers(), json=data)

        if return_request:
            return request

        # Prepare and send the request
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise

    def delete(
        self, model_id: str, return_request: bool = False
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Delete a model from the proxy.

        Args:
            model_id (str): ID of the model to delete (e.g., "2f23364f-4579-4d79-a43a-2d48dd551c2e")
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            NotFoundError: If the request fails with a 404 status code or indicates the model was not found
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/model/delete"
        data = {"id": model_id}

        request = requests.Request("POST", url, headers=self._get_headers(), json=data)

        if return_request:
            return request

        # Prepare and send the request
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            # Some server versions signal "not found" in the body rather than
            # with a 404 status, so check both.
            if e.response.status_code == 404 or "not found" in e.response.text.lower():
                raise NotFoundError(e)
            raise

    def get(
        self,
        model_id: Optional[str] = None,
        model_name: Optional[str] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Get information about a specific model by its ID or name.

        Note: filtering happens client-side; the full model list is fetched
        via info() and searched locally.

        Args:
            model_id (Optional[str]): ID of the model to retrieve
            model_name (Optional[str]): Name of the model to retrieve
            return_request (bool): If True, returns the prepared request object instead of executing it
                (the request fetches ALL models; filtering would happen after execution)

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the model information from the server or
                a prepared request object if return_request is True

        Raises:
            ValueError: If neither model_id nor model_name is provided, or if both are provided
            UnauthorizedError: If the request fails with a 401 status code
            NotFoundError: If the model is not found
            requests.exceptions.RequestException: If the request fails with any other error
        """
        if (model_id is None and model_name is None) or (
            model_id is not None and model_name is not None
        ):
            raise ValueError("Exactly one of model_id or model_name must be provided")

        # If return_request is True, delegate to info
        if return_request:
            result = self.info(return_request=True)
            assert isinstance(result, requests.Request)  # narrow the Union for type-checkers
            return result

        # Get all models and filter client-side.
        # NOTE: use builtin `list` here — `isinstance(x, typing.List)` is
        # deprecated and the builtin is the correct runtime check.
        models = self.info()
        assert isinstance(models, list)  # narrow the Union for type-checkers

        # Find the matching model
        for model in models:
            if (model_id and model.get("model_info", {}).get("id") == model_id) or (
                model_name and model.get("model_name") == model_name
            ):
                return model

        # If we get here, no model was found
        if model_id:
            msg = f"Model with id={model_id} not found"
        elif model_name:
            msg = f"Model with model_name={model_name} not found"
        else:
            msg = "Unknown error trying to find model"
        raise NotFoundError(
            requests.exceptions.HTTPError(
                msg,
                response=requests.Response(),  # Empty response since we didn't make a direct request
            )
        )

    def info(
        self, return_request: bool = False
    ) -> Union[List[Dict[str, Any]], requests.Request]:
        """
        Get detailed information about all models from the server.

        Args:
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[List[Dict[str, Any]], requests.Request]: Either a list of model information dictionaries
                or a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/v1/model/info"
        request = requests.Request("GET", url, headers=self._get_headers())

        if return_request:
            return request

        # Prepare and send the request
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()["data"]
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            raise

    def update(
        self,
        model_id: str,
        model_params: Dict[str, Any],
        model_info: Optional[Dict[str, Any]] = None,
        return_request: bool = False,
    ) -> Union[Dict[str, Any], requests.Request]:
        """
        Update an existing model's configuration.

        Args:
            model_id (str): ID of the model to update
            model_params (Dict[str, Any]): New parameters for the model (e.g., model type, api_base, api_key)
            model_info (Optional[Dict[str, Any]]): Additional information about the model
            return_request (bool): If True, returns the prepared request object instead of executing it

        Returns:
            Union[Dict[str, Any], requests.Request]: Either the response from the server or
                a prepared request object if return_request is True

        Raises:
            UnauthorizedError: If the request fails with a 401 status code
            NotFoundError: If the model is not found
            requests.exceptions.RequestException: If the request fails with any other error
        """
        url = f"{self._base_url}/model/update"

        data = {
            "id": model_id,
            "litellm_params": model_params,
        }
        if model_info:
            data["model_info"] = model_info

        request = requests.Request("POST", url, headers=self._get_headers(), json=data)

        if return_request:
            return request

        # Prepare and send the request
        session = requests.Session()
        try:
            response = session.send(request.prepare())
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise UnauthorizedError(e)
            if e.response.status_code == 404 or "not found" in e.response.text.lower():
                raise NotFoundError(e)
            raise
|
||||
@@ -0,0 +1,146 @@
|
||||
"""Teams management client for LiteLLM proxy."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from .exceptions import UnauthorizedError
|
||||
|
||||
|
||||
class TeamsManagementClient:
    """Client for managing teams in LiteLLM proxy."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the TeamsManagementClient.

        Args:
            base_url (str): The base URL of the LiteLLM proxy server (e.g., "http://localhost:4000")
            api_key (Optional[str]): API key for authentication. If provided, it will be sent as a Bearer token.
        """
        self._base_url = base_url.rstrip("/")  # drop any trailing slash
        self._api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """
        Build the headers for API requests, adding a Bearer token when an
        API key is configured.

        Returns:
            Dict[str, str]: Headers to use for API requests
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    def _fetch_json(
        self, path: str, params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Issue a GET request against `path` and return the parsed JSON body.

        Raises:
            UnauthorizedError: If the server responds with a 401 status code
            requests.exceptions.HTTPError: For any other error status
        """
        response = requests.get(
            f"{self._base_url}{path}", headers=self._get_headers(), params=params
        )

        if response.status_code == 401:
            raise UnauthorizedError("Authentication failed. Check your API key.")

        response.raise_for_status()
        return response.json()

    def list(
        self,
        user_id: Optional[str] = None,
        organization_id: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        List teams that the user belongs to.

        Args:
            user_id (Optional[str]): Only return teams which this user belongs to
            organization_id (Optional[str]): Only return teams which belong to this organization

        Returns:
            List[Dict[str, Any]]: List of team objects

        Raises:
            requests.exceptions.HTTPError: If the request fails
            UnauthorizedError: If authentication fails
        """
        params = {}
        if user_id:
            params["user_id"] = user_id
        if organization_id:
            params["organization_id"] = organization_id

        return self._fetch_json("/team/list", params)

    def list_v2(
        self,
        user_id: Optional[str] = None,
        organization_id: Optional[str] = None,
        team_id: Optional[str] = None,
        team_alias: Optional[str] = None,
        page: int = 1,
        page_size: int = 10,
        sort_by: Optional[str] = None,
        sort_order: str = "asc",
    ) -> Dict[str, Any]:
        """
        Get a paginated list of teams with filtering and sorting options.

        Args:
            user_id (Optional[str]): Only return teams which this user belongs to
            organization_id (Optional[str]): Only return teams which belong to this organization
            team_id (Optional[str]): Filter teams by exact team_id match
            team_alias (Optional[str]): Filter teams by partial team_alias match
            page (int): Page number for pagination
            page_size (int): Number of teams per page
            sort_by (Optional[str]): Column to sort by (e.g. 'team_id', 'team_alias', 'created_at')
            sort_order (str): Sort order ('asc' or 'desc')

        Returns:
            Dict[str, Any]: Paginated response containing teams and pagination info

        Raises:
            requests.exceptions.HTTPError: If the request fails
            UnauthorizedError: If authentication fails
        """
        # Pagination/sort parameters are always sent; filters only when set.
        params: Dict[str, Union[str, int]] = {
            "page": page,
            "page_size": page_size,
            "sort_order": sort_order,
        }
        if user_id:
            params["user_id"] = user_id
        if organization_id:
            params["organization_id"] = organization_id
        if team_id:
            params["team_id"] = team_id
        if team_alias:
            params["team_alias"] = team_alias
        if sort_by:
            params["sort_by"] = sort_by

        return self._fetch_json("/v2/team/list", params)

    def get_available(self) -> List[Dict[str, Any]]:
        """
        Get list of available teams that the user can join.

        Returns:
            List[Dict[str, Any]]: List of available team objects

        Raises:
            requests.exceptions.HTTPError: If the request fails
            UnauthorizedError: If authentication fails
        """
        return self._fetch_json("/team/available")
||||
@@ -0,0 +1,58 @@
|
||||
import requests
|
||||
from typing import List, Dict, Any, Optional
|
||||
from .exceptions import UnauthorizedError, NotFoundError
|
||||
|
||||
|
||||
class UsersManagementClient:
    """Client for managing users on a LiteLLM proxy (/user/* routes)."""

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Args:
            base_url (str): Base URL of the LiteLLM proxy server.
            api_key (Optional[str]): API key sent as a Bearer token when set.
        """
        self.base_url = base_url.rstrip("/")  # drop any trailing slash
        self.api_key = api_key

    def _get_headers(self) -> Dict[str, str]:
        """Build JSON request headers, adding a Bearer token when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def list_users(
        self, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """List users (GET /user/list)"""
        url = f"{self.base_url}/user/list"
        response = requests.get(url, headers=self._get_headers(), params=params)
        if response.status_code == 401:
            raise UnauthorizedError(response.text)
        response.raise_for_status()
        # Parse the body exactly once (the original parsed it twice).
        # Supports both a {"users": [...]} envelope and a bare list payload.
        body = response.json()
        if isinstance(body, dict):
            return body.get("users", body)
        return body

    def get_user(self, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Get user info (GET /user/info)"""
        url = f"{self.base_url}/user/info"
        params = {"user_id": user_id} if user_id else {}
        response = requests.get(url, headers=self._get_headers(), params=params)
        if response.status_code == 401:
            raise UnauthorizedError(response.text)
        if response.status_code == 404:
            raise NotFoundError(response.text)
        response.raise_for_status()
        return response.json()

    def create_user(self, user_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new user (POST /user/new)"""
        url = f"{self.base_url}/user/new"
        response = requests.post(url, headers=self._get_headers(), json=user_data)
        if response.status_code == 401:
            raise UnauthorizedError(response.text)
        response.raise_for_status()
        return response.json()

    def delete_user(self, user_ids: List[str]) -> Dict[str, Any]:
        """Delete users (POST /user/delete)"""
        url = f"{self.base_url}/user/delete"
        response = requests.post(
            url, headers=self._get_headers(), json={"user_ids": user_ids}
        )
        if response.status_code == 401:
            raise UnauthorizedError(response.text)
        response.raise_for_status()
        return response.json()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,169 @@
|
||||
def show_missing_vars_in_env():
    """Return an HTML help page when required env vars are missing, else None.

    Inspects the proxy's globals: `prisma_client` (backed by DATABASE_URL)
    and `master_key` (backed by LITELLM_MASTER_KEY). Any unset value is
    reported on the rendered setup page.
    """
    from fastapi.responses import HTMLResponse

    from litellm.proxy.proxy_server import master_key, prisma_client

    # Collect the names of every missing variable, then render them together.
    missing = []
    if prisma_client is None:
        missing.append("DATABASE_URL")
    if master_key is None:
        missing.append("LITELLM_MASTER_KEY")

    if missing:
        return HTMLResponse(
            content=missing_keys_form(missing_key_names=", ".join(missing)),
            status_code=200,
        )
    return None
|
||||
|
||||
|
||||
def missing_keys_form(missing_key_names: str):
    """Render a static HTML setup page listing missing environment variables.

    Args:
        missing_key_names (str): Comma-separated names of the missing
            environment variables, interpolated into the page body.

    Returns:
        str: A complete HTML document with setup instructions.
    """
    # Double braces ({{ }}) escape literal CSS braces so that the
    # str.format() call at the end only substitutes {missing_keys}.
    missing_keys_html_form = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Environment Setup Instructions</title>
</head>
<body>
<div class="container">
<h1>Environment Setup Instructions</h1>
<p>Please add the following variables to your environment variables:</p>
<pre>
<span class="env-var">LITELLM_MASTER_KEY="sk-1234"</span> <span class="comment"># Your master key for the proxy server. Can use this to send /chat/completion requests etc</span>
<span class="env-var">LITELLM_SALT_KEY="sk-XXXXXXXX"</span> <span class="comment"># Can NOT CHANGE THIS ONCE SET - It is used to encrypt/decrypt credentials stored in DB. If value of 'LITELLM_SALT_KEY' changes your models cannot be retrieved from DB</span>
<span class="env-var">DATABASE_URL="postgres://..."</span> <span class="comment"># Need a postgres database? (Check out Supabase, Neon, etc)</span>
<span class="comment">## OPTIONAL ##</span>
<span class="env-var">PORT=4000</span> <span class="comment"># DO THIS FOR RENDER/RAILWAY</span>
<span class="env-var">STORE_MODEL_IN_DB="True"</span> <span class="comment"># Allow storing models in db</span>
</pre>
<h1>Missing Environment Variables</h1>
<p>{missing_keys}</p>
</div>

<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
    return missing_keys_html_form.format(missing_keys=missing_key_names)
|
||||
|
||||
|
||||
def admin_ui_disabled():
    """Return an HTML page explaining that the Admin UI has been disabled.

    Rendered when the admin UI is turned off via DISABLE_ADMIN_UI; tells the
    operator how to re-enable it.

    Returns:
        fastapi.responses.HTMLResponse: The rendered page with status 200.
    """
    from fastapi.responses import HTMLResponse

    # NOTE: unlike missing_keys_form(), this template is returned verbatim —
    # str.format() is never called — so CSS braces must be SINGLE braces.
    # The previous version doubled them ({{ }}), which shipped literal double
    # braces to the browser and broke the page's CSS.
    ui_disabled_html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}
.container {
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}
h1 {
font-size: 24px;
margin-bottom: 20px;
}
pre {
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}
.env-var {
font-weight: normal;
}
.comment {
font-weight: normal;
color: #777;
}
</style>
<title>Admin UI Disabled</title>
</head>
<body>
<div class="container">
<h1>Admin UI is Disabled</h1>
<p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p>
<pre>
<span class="env-var">DISABLE_ADMIN_UI="False"</span> <span class="comment"># Set this to "False" to enable the Admin UI.</span>
</pre>
<p>After making this change, restart the application for it to take effect.</p>
</div>

<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""

    return HTMLResponse(
        content=ui_disabled_html,
        status_code=200,
    )
|
||||
@@ -0,0 +1,17 @@
|
||||
# LiteLLM ASCII banner (Unicode box-drawing art); printed by show_banner()
# when the CLI starts.
LITELLM_BANNER = """ ██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗
██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║
██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║
██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║
███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║
╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝"""
|
||||
|
||||
|
||||
def show_banner():
    """Display the LiteLLM CLI banner."""
    try:
        import click
    except ImportError:
        # click is optional; without it we cannot echo the banner.
        print("\n")  # noqa: T201
        return
    click.echo(f"\n{LITELLM_BANNER}\n")
|
||||
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Event-driven cache coordinator to prevent cache stampede.
|
||||
|
||||
Use this when many requests can miss the same cache key at once (e.g. after
|
||||
expiry or restart). Without coordination, they would all run the expensive
|
||||
load (DB query, API call) in parallel and overload the backend.
|
||||
|
||||
This module ensures only one request performs the load; the rest wait for a
|
||||
signal and then read the freshly cached value. Reuse it for any cache-aside
|
||||
pattern: global spend, feature flags, config, or other shared read-through data.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AsyncCacheProtocol(Protocol):
    """Protocol for cache backends used by EventDrivenCacheCoordinator.

    Any object exposing these two async methods (e.g. litellm's DualCache or
    RedisCache — assumption based on names; confirm against callers) can be
    passed to the coordinator. Values are opaque to the coordinator.
    """

    async def async_get_cache(self, key: str, **kwargs: Any) -> Any:
        # Return the cached value for `key`, or None on a miss.
        ...

    async def async_set_cache(self, key: str, value: Any, **kwargs: Any) -> Any:
        # Store `value` under `key`; return value is ignored by the coordinator.
        ...
|
||||
|
||||
|
||||
class EventDrivenCacheCoordinator:
    """
    Coordinates a single in-flight load per logical resource to prevent cache stampede.

    Pattern:
    - First request: loads data (e.g. DB query), caches it, then signals waiters.
    - Other requests: wait for the signal, then read from cache.

    Create one instance per resource (e.g. one for global spend, one for feature flags).

    NOTE: this coordinator is scoped to a single event loop / process — the
    asyncio lock and event do not coordinate across workers.
    """

    def __init__(self, log_prefix: str = "[CACHE]"):
        # Guards the _query_in_progress / _event state transitions below.
        self._lock = asyncio.Lock()
        # Event the current loader signals when its load finishes; None when idle.
        self._event: Optional[asyncio.Event] = None
        # True while some task owns the loader role.
        self._query_in_progress = False
        # Prefix for debug logs; falsy prefix disables coordinator logging.
        self._log_prefix = log_prefix

    async def _get_cached(
        self, cache_key: str, cache: AsyncCacheProtocol
    ) -> Optional[Any]:
        """Return value from cache if present, else None."""
        return await cache.async_get_cache(key=cache_key)

    def _log_cache_hit(self, value: T) -> None:
        # Lazy %-style args so repr(value) is only computed when debug is on.
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Cache hit, value: %s", self._log_prefix, value
            )

    def _log_cache_miss(self) -> None:
        if self._log_prefix:
            verbose_proxy_logger.debug("%s Cache miss", self._log_prefix)

    async def _claim_role(self) -> Optional[asyncio.Event]:
        """
        Under lock: return event to wait on if load is in progress, else set us as loader and return None.
        """
        async with self._lock:
            # Someone else is already loading — hand back their completion event.
            if self._query_in_progress and self._event is not None:
                if self._log_prefix:
                    verbose_proxy_logger.debug(
                        "%s Load in flight, waiting for signal", self._log_prefix
                    )
                return self._event
            # We become the loader: mark in-progress and create the event that
            # later arrivals will wait on. Both writes happen under the lock.
            self._query_in_progress = True
            self._event = asyncio.Event()
            if self._log_prefix:
                verbose_proxy_logger.debug(
                    "%s Starting load (will signal others when done)",
                    self._log_prefix,
                )
            return None

    async def _wait_for_signal_and_get(
        self,
        event: asyncio.Event,
        cache_key: str,
        cache: AsyncCacheProtocol,
    ) -> Optional[T]:
        """Wait for loader to finish, then read from cache."""
        await event.wait()
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Signal received, reading from cache", self._log_prefix
            )
        # May still be None if the loader's load_fn raised — callers must
        # tolerate an empty cache after the signal.
        value: Optional[T] = await cache.async_get_cache(key=cache_key)
        if value is not None and self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Cache filled by other request, value: %s",
                self._log_prefix,
                value,
            )
        elif value is None and self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Signal received but cache still empty", self._log_prefix
            )
        return value

    async def _load_and_cache(
        self,
        cache_key: str,
        cache: AsyncCacheProtocol,
        load_fn: Callable[[], Awaitable[T]],
    ) -> Optional[T]:
        """Double-check cache, run load_fn, set cache, return value. Caller must call _signal_done in finally."""
        # Double-check: another task may have populated the cache between our
        # miss and our claiming the loader role.
        value = await cache.async_get_cache(key=cache_key)
        if value is not None:
            if self._log_prefix:
                verbose_proxy_logger.debug(
                    "%s Cache filled while acquiring lock, value: %s",
                    self._log_prefix,
                    value,
                )
            return value

        if self._log_prefix:
            verbose_proxy_logger.debug("%s Running load", self._log_prefix)
        # Time the expensive load for debug visibility.
        start = time.perf_counter()
        value = await load_fn()
        elapsed_ms = (time.perf_counter() - start) * 1000
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Load completed in %.2fms, result: %s",
                self._log_prefix,
                elapsed_ms,
                value,
            )

        # Cache before the caller's finally runs _signal_done, so waiters that
        # wake on the signal find the fresh value.
        await cache.async_set_cache(key=cache_key, value=value)
        if self._log_prefix:
            verbose_proxy_logger.debug("%s Result cached", self._log_prefix)
        return value

    async def _signal_done(self) -> None:
        """Reset loader state and signal all waiters."""
        async with self._lock:
            self._query_in_progress = False
            if self._event is not None:
                if self._log_prefix:
                    verbose_proxy_logger.debug(
                        "%s Signaling all waiting requests", self._log_prefix
                    )
                # Wake every waiter, then drop the event so the next miss
                # creates a fresh one.
                self._event.set()
                self._event = None

    async def get_or_load(
        self,
        cache_key: str,
        cache: AsyncCacheProtocol,
        load_fn: Callable[[], Awaitable[T]],
    ) -> Optional[T]:
        """
        Return cached value or load it once and signal waiters.

        - cache_key: Key to read/write in the cache.
        - cache: Object with async_get_cache(key) and async_set_cache(key, value).
        - load_fn: Async callable that performs the load (e.g. DB query). No args.
          Return value is cached and returned. If it raises, waiters are
          still signaled so they can retry or handle empty cache.

        Returns the value from cache or from load_fn, or None if load failed or
        cache was still empty after waiting.
        """
        # Fast path: cache hit, no coordination needed.
        value = await self._get_cached(cache_key, cache)
        if value is not None:
            self._log_cache_hit(value)
            return value

        self._log_cache_miss()
        # Either we become the loader (None) or we get an event to wait on.
        event_to_wait = await self._claim_role()

        if event_to_wait is not None:
            return await self._wait_for_signal_and_get(event_to_wait, cache_key, cache)

        # We are the loader: always signal waiters, even if load_fn raises.
        try:
            result = await self._load_and_cache(cache_key, cache, load_fn)
            return result
        finally:
            await self._signal_done()
|
||||
@@ -0,0 +1,526 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
|
||||
from litellm.proxy.types_utils.utils import get_instance_fn
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
blue_color_code = "\033[94m"
|
||||
reset_color_code = "\033[0m"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
|
||||
|
||||
def initialize_callbacks_on_proxy(  # noqa: PLR0915
    value: Any,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
    callback_specific_params: Optional[dict] = None,
):
    """
    Resolve proxy-config callback entries into initialized callback objects and
    register them on ``litellm.callbacks``.

    Args:
        value: A single callback (string name, import path, or CustomLogger
            instance) or a list of them.
        premium_user: Whether this deployment has an enterprise license;
            enterprise-only callbacks raise when this is False.
        config_file_path: Path of the proxy config file, used by
            ``get_instance_fn`` to resolve custom-callback import paths.
        litellm_settings: The ``litellm_settings`` block of the proxy config
            (source of per-callback settings like ``prompt_injection_params``).
        callback_specific_params: Optional per-callback init kwargs keyed by
            callback name (e.g. ``{"presidio": {...}}``). Defaults to an empty
            dict.

    Raises:
        Exception: when an enterprise callback is requested without the
            enterprise package installed or without a premium license.
    """
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.litellm_core_utils.logging_callback_manager import (
        LoggingCallbackManager,
    )
    from litellm.proxy.proxy_server import prisma_client

    # Avoid the mutable-default-argument pitfall: normalize None -> {} here.
    if callback_specific_params is None:
        callback_specific_params = {}

    verbose_proxy_logger.debug(
        f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}"
    )
    if isinstance(value, list):
        imported_list: List[Any] = []
        for callback in value:  # ["presidio", <my-custom-callback>]
            # check if callback is a custom logger compatible callback
            if isinstance(callback, str):
                callback = LoggingCallbackManager._add_custom_callback_generic_api_str(
                    callback
                )
            if (
                isinstance(callback, str)
                and callback in litellm._known_custom_logger_compatible_callbacks
            ):
                imported_list.append(callback)
            elif isinstance(callback, str) and callback == "presidio":
                from litellm.proxy.guardrails.guardrail_hooks.presidio import (
                    _OPTIONAL_PresidioPIIMasking,
                )

                presidio_logging_only: Optional[bool] = litellm_settings.get(
                    "presidio_logging_only", None
                )
                if presidio_logging_only is not None:
                    presidio_logging_only = bool(
                        presidio_logging_only
                    )  # validate boolean given

                _presidio_params = {}
                if "presidio" in callback_specific_params and isinstance(
                    callback_specific_params["presidio"], dict
                ):
                    _presidio_params = callback_specific_params["presidio"]

                params: Dict[str, Any] = {
                    "logging_only": presidio_logging_only,
                    **_presidio_params,
                }
                pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
                imported_list.append(pii_masking_object)
            elif isinstance(callback, str) and callback == "llamaguard_moderations":
                try:
                    from litellm_enterprise.enterprise_callbacks.llama_guard import (
                        _ENTERPRISE_LlamaGuard,
                    )
                except ImportError:
                    # Fixed garbled error message ("MissingTrying to use ...").
                    raise Exception(
                        "Trying to use Llama Guard"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llama Guard"
                        + CommonProxyErrors.not_premium_user.value
                    )

                llama_guard_object = _ENTERPRISE_LlamaGuard()
                imported_list.append(llama_guard_object)
            elif isinstance(callback, str) and callback == "hide_secrets":
                try:
                    from litellm_enterprise.enterprise_callbacks.secret_detection import (
                        _ENTERPRISE_SecretDetection,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Secret Detection"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use secret hiding"
                        + CommonProxyErrors.not_premium_user.value
                    )

                _secret_detection_object = _ENTERPRISE_SecretDetection()
                imported_list.append(_secret_detection_object)
            elif isinstance(callback, str) and callback == "openai_moderations":
                try:
                    from enterprise.enterprise_hooks.openai_moderation import (
                        _ENTERPRISE_OpenAI_Moderation,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check,"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check"
                        + CommonProxyErrors.not_premium_user.value
                    )

                openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
                imported_list.append(openai_moderations_object)
            elif isinstance(callback, str) and callback == "lakera_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
                    lakeraAI_Moderation,
                )

                init_params = {}
                if "lakera_prompt_injection" in callback_specific_params:
                    init_params = callback_specific_params["lakera_prompt_injection"]
                lakera_moderations_object = lakeraAI_Moderation(**init_params)
                imported_list.append(lakera_moderations_object)
            elif isinstance(callback, str) and callback == "aporia_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.aporia_ai.aporia_ai import (
                    AporiaGuardrail,
                )

                aporia_guardrail_object = AporiaGuardrail()
                imported_list.append(aporia_guardrail_object)
            elif isinstance(callback, str) and callback == "google_text_moderation":
                try:
                    from enterprise.enterprise_hooks.google_text_moderation import (
                        _ENTERPRISE_GoogleTextModeration,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Google Text Moderation,"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Google Text Moderation"
                        + CommonProxyErrors.not_premium_user.value
                    )

                google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
                imported_list.append(google_text_moderation_obj)
            elif isinstance(callback, str) and callback == "llmguard_moderations":
                try:
                    from litellm_enterprise.enterprise_callbacks.llm_guard import (
                        _ENTERPRISE_LLMGuard,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Llm Guard"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llm Guard"
                        + CommonProxyErrors.not_premium_user.value
                    )

                llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
                imported_list.append(llm_guard_moderation_obj)
            elif isinstance(callback, str) and callback == "blocked_user_check":
                try:
                    from enterprise.enterprise_hooks.blocked_user_list import (
                        _ENTERPRISE_BlockedUserList,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Blocked User List"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BlockedUser"
                        + CommonProxyErrors.not_premium_user.value
                    )

                blocked_user_list = _ENTERPRISE_BlockedUserList(
                    prisma_client=prisma_client
                )
                imported_list.append(blocked_user_list)
            elif isinstance(callback, str) and callback == "banned_keywords":
                try:
                    from enterprise.enterprise_hooks.banned_keywords import (
                        _ENTERPRISE_BannedKeywords,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Banned Keywords"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BannedKeyword"
                        + CommonProxyErrors.not_premium_user.value
                    )

                banned_keywords_obj = _ENTERPRISE_BannedKeywords()
                imported_list.append(banned_keywords_obj)
            elif isinstance(callback, str) and callback == "detect_prompt_injection":
                from litellm.proxy.hooks.prompt_injection_detection import (
                    _OPTIONAL_PromptInjectionDetection,
                )

                prompt_injection_params = None
                if "prompt_injection_params" in litellm_settings:
                    prompt_injection_params_in_config = litellm_settings[
                        "prompt_injection_params"
                    ]
                    prompt_injection_params = LiteLLMPromptInjectionParams(
                        **prompt_injection_params_in_config
                    )

                prompt_injection_detection_obj = _OPTIONAL_PromptInjectionDetection(
                    prompt_injection_params=prompt_injection_params,
                )
                imported_list.append(prompt_injection_detection_obj)
            elif isinstance(callback, str) and callback == "batch_redis_requests":
                from litellm.proxy.hooks.batch_redis_get import (
                    _PROXY_BatchRedisRequests,
                )

                batch_redis_obj = _PROXY_BatchRedisRequests()
                imported_list.append(batch_redis_obj)
            elif isinstance(callback, str) and callback == "azure_content_safety":
                from litellm.proxy.hooks.azure_content_safety import (
                    _PROXY_AzureContentSafety,
                )

                azure_content_safety_params = litellm_settings[
                    "azure_content_safety_params"
                ]
                # Resolve "os.environ/..." references to actual secret values.
                for k, v in azure_content_safety_params.items():
                    if (
                        v is not None
                        and isinstance(v, str)
                        and v.startswith("os.environ/")
                    ):
                        azure_content_safety_params[k] = get_secret(v)

                azure_content_safety_obj = _PROXY_AzureContentSafety(
                    **azure_content_safety_params,
                )
                imported_list.append(azure_content_safety_obj)
            elif isinstance(callback, str) and callback == "websearch_interception":
                from litellm.integrations.websearch_interception.handler import (
                    WebSearchInterceptionLogger,
                )

                websearch_interception_obj = (
                    WebSearchInterceptionLogger.initialize_from_proxy_config(
                        litellm_settings=litellm_settings,
                        callback_specific_params=callback_specific_params,
                    )
                )
                imported_list.append(websearch_interception_obj)
            elif isinstance(callback, str) and callback == "datadog_cost_management":
                from litellm.integrations.datadog.datadog_cost_management import (
                    DatadogCostManagementLogger,
                )

                datadog_cost_management_obj = DatadogCostManagementLogger()
                imported_list.append(datadog_cost_management_obj)
            elif isinstance(callback, CustomLogger):
                imported_list.append(callback)
            else:
                verbose_proxy_logger.debug(
                    f"{blue_color_code} attempting to import custom callback={callback} {reset_color_code}"
                )
                imported_list.append(
                    get_instance_fn(
                        value=callback,
                        config_file_path=config_file_path,
                    )
                )
        if isinstance(litellm.callbacks, list):
            litellm.callbacks.extend(imported_list)
        else:
            litellm.callbacks = imported_list  # type: ignore

        if "prometheus" in value:
            from litellm.integrations.prometheus import PrometheusLogger

            PrometheusLogger._mount_metrics_endpoint()
    else:
        litellm.callbacks = [
            get_instance_fn(
                value=value,
                config_file_path=config_file_path,
            )
        ]
    verbose_proxy_logger.debug(
        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
    )
|
||||
|
||||
|
||||
def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
    """Return the model_group stored in the litellm_params metadata, if any."""
    litellm_params = kwargs.get("litellm_params", None) or {}
    metadata_key = get_metadata_variable_name_from_kwargs(kwargs)
    metadata = litellm_params.get(metadata_key) or {}
    # dict.get already yields None when the key is absent.
    return metadata.get("model_group", None)
|
||||
|
||||
|
||||
def get_model_group_from_request_data(data: dict) -> Optional[str]:
    """Return the model_group from the request's metadata, or None if unset."""
    metadata = data.get("metadata", None) or {}
    # A missing key falls through to None, matching the explicit check it replaces.
    return metadata.get("model_group", None)
|
||||
|
||||
|
||||
def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
    """
    Helper function to return x-litellm-key-remaining-tokens-{model_group} and x-litellm-key-remaining-requests-{model_group}

    Returns {} when api_key + model rpm/tpm limit is not set

    """
    headers: Dict[str, str] = {}
    request_metadata = data.get("metadata", None) or {}
    model_group = get_model_group_from_request_data(data)

    # The h11 package considers "/" or ":" invalid and raise a LocalProtocolError
    safe_model_group = (
        model_group.replace("/", "-").replace(":", "-") if model_group else None
    )

    # Emit one header per limit kind when the metadata carries a value.
    for limit_kind in ("requests", "tokens"):
        metadata_key = f"litellm-key-remaining-{limit_kind}-{model_group}"
        remaining = request_metadata.get(metadata_key, None)
        if remaining:
            header_name = f"x-litellm-key-remaining-{limit_kind}-{safe_model_group}"
            headers[header_name] = remaining

    return headers
|
||||
|
||||
|
||||
def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
    """
    Build x-litellm-* response headers describing guardrail/policy activity.

    Reads the request's `metadata` (falling back to `litellm_metadata`) and
    emits headers for applied guardrails, applied policies, policy sources,
    semantic similarity, and Pillar guardrail results. Returns an empty dict
    when none of those are present.
    """
    meta = request_data.get("metadata", None)
    if not meta:
        meta = request_data.get("litellm_metadata", None)
    if not isinstance(meta, dict):
        meta = {}

    headers: Dict = {}

    if "applied_guardrails" in meta:
        headers["x-litellm-applied-guardrails"] = ",".join(meta["applied_guardrails"])

    if "applied_policies" in meta:
        headers["x-litellm-applied-policies"] = ",".join(meta["applied_policies"])

    if "policy_sources" in meta:
        sources = meta["policy_sources"]
        if isinstance(sources, dict) and sources:
            # Use ';' as delimiter — matched_via reasons may contain commas
            headers["x-litellm-policy-sources"] = "; ".join(
                f"{policy_name}={matched_via}"
                for policy_name, matched_via in sources.items()
            )

    if "semantic-similarity" in meta:
        headers["x-litellm-semantic-similarity"] = str(meta["semantic-similarity"])

    pillar_headers = meta.get("pillar_response_headers")
    if isinstance(pillar_headers, dict):
        headers.update(pillar_headers)
    elif "pillar_flagged" in meta:
        headers["x-pillar-flagged"] = str(meta["pillar_flagged"]).lower()

    return headers
|
||||
|
||||
|
||||
def add_guardrail_to_applied_guardrails_header(
    request_data: Dict, guardrail_name: Optional[str]
):
    """
    Record that a guardrail ran on this request, feeding the
    x-litellm-applied-guardrails response header.

    No-op when guardrail_name is None. De-duplicates entries so a guardrail
    that fires multiple times (e.g. pre- and post-call hooks) is recorded
    once — consistent with add_policy_to_applied_policies_header.
    """
    if guardrail_name is None:
        return
    _metadata = request_data.get("metadata", None) or {}
    if "applied_guardrails" in _metadata:
        # Consistency fix: skip duplicates, matching the applied_policies helper.
        if guardrail_name not in _metadata["applied_guardrails"]:
            _metadata["applied_guardrails"].append(guardrail_name)
    else:
        _metadata["applied_guardrails"] = [guardrail_name]
    # Ensure metadata is set back to request_data (important when metadata didn't exist)
    request_data["metadata"] = _metadata
|
||||
|
||||
|
||||
def add_policy_to_applied_policies_header(
    request_data: Dict, policy_name: Optional[str]
):
    """
    Add a policy name to the applied_policies list in request metadata.

    This is used to track which policies were applied to a request,
    similar to how applied_guardrails tracks guardrails.
    """
    if policy_name is None:
        return
    metadata = request_data.get("metadata", None) or {}
    # setdefault covers both the existing-list and first-policy cases.
    applied = metadata.setdefault("applied_policies", [])
    if policy_name not in applied:
        applied.append(policy_name)
    # Write metadata back so a freshly created dict is attached to the request.
    request_data["metadata"] = metadata
|
||||
|
||||
|
||||
def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str, str]):
    """
    Store policy match reasons in metadata for x-litellm-policy-sources header.

    Args:
        request_data: The request data dict
        policy_sources: Map of policy_name -> matched_via reason
    """
    if not policy_sources:
        return
    metadata = request_data.get("metadata", None) or {}
    current = metadata.get("policy_sources", {})
    # Discard any malformed (non-dict) value rather than failing.
    if not isinstance(current, dict):
        current = {}
    current.update(policy_sources)
    metadata["policy_sources"] = current
    request_data["metadata"] = metadata
|
||||
|
||||
|
||||
def add_guardrail_response_to_standard_logging_object(
    litellm_logging_obj: Optional["LiteLLMLogging"],
    guardrail_response: StandardLoggingGuardrailInformation,
):
    """
    Append a guardrail result to the standard logging payload's
    guardrail_information list.

    Returns the updated payload, or None when there is no logging object or
    no standard_logging_object on it.
    """
    if litellm_logging_obj is None:
        return
    payload: Optional[
        StandardLoggingPayload
    ] = litellm_logging_obj.model_call_details.get("standard_logging_object")
    if payload is None:
        return
    info_list = payload.get("guardrail_information", [])
    if info_list is None:
        info_list = []
    info_list.append(guardrail_response)
    payload["guardrail_information"] = info_list

    return payload
|
||||
|
||||
|
||||
def get_metadata_variable_name_from_kwargs(
    kwargs: dict,
) -> Literal["metadata", "litellm_metadata"]:
    """
    Helper to return what the "metadata" field should be called in the request data

    - New endpoints return `litellm_metadata`
    - Old endpoints return `metadata`

    Context:
    - LiteLLM used `metadata` as an internal field for storing metadata
    - OpenAI then started using this field for their metadata
    - LiteLLM is now moving to using `litellm_metadata` for our metadata
    """
    if "litellm_metadata" in kwargs:
        return "litellm_metadata"
    return "metadata"
|
||||
|
||||
|
||||
def process_callback(
    _callback: str, callback_type: str, environment_variables: dict
) -> dict:
    """Process a single callback and return its data with environment variables.

    Args:
        _callback: Callback name to look up env vars for.
        callback_type: Category label echoed back in the result (e.g. "success").
        environment_variables: Mapping of currently-set environment variables.

    Returns:
        {"name": ..., "variables": {var: value-or-None}, "type": ...}
    """
    env_vars = CustomLogger.get_callback_env_vars(_callback)

    # .get() returns None for unset variables, which is exactly what the
    # previous explicit if/else produced — both branches assigned the same value.
    env_vars_dict: dict[str, str | None] = {
        _var: environment_variables.get(_var, None) for _var in env_vars
    }

    return {"name": _callback, "variables": env_vars_dict, "type": callback_type}
|
||||
|
||||
|
||||
def normalize_callback_names(callbacks: Optional[Iterable[Any]]) -> List[Any]:
    """Lowercase string callback names; pass non-string entries through.

    Accepts None (annotation fixed to Optional — the body already handled it)
    and returns an empty list in that case.
    """
    if callbacks is None:
        return []
    return [c.lower() if isinstance(c, str) else c for c in callbacks]
|
||||
@@ -0,0 +1,437 @@
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
class CustomOpenAPISpec:
    """
    Handler for customizing OpenAPI specifications with Pydantic models
    for documentation purposes without runtime validation.
    """

    # OpenAI-compatible chat-completion routes (including Azure-style
    # deployment paths) whose request bodies get the expanded schema treatment.
    CHAT_COMPLETION_PATHS = [
        "/v1/chat/completions",
        "/chat/completions",
        "/engines/{model}/chat/completions",
        "/openai/deployments/{model}/chat/completions",
    ]

    # Embedding routes, mirroring the chat-completion path variants.
    EMBEDDING_PATHS = [
        "/v1/embeddings",
        "/embeddings",
        "/engines/{model}/embeddings",
        "/openai/deployments/{model}/embeddings",
    ]

    # Responses API routes.
    RESPONSES_API_PATHS = ["/v1/responses", "/responses"]
|
||||
|
||||
@staticmethod
|
||||
def get_pydantic_schema(model_class) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get JSON schema from a Pydantic model, handling both v1 and v2 APIs.
|
||||
|
||||
Args:
|
||||
model_class: Pydantic model class
|
||||
|
||||
Returns:
|
||||
JSON schema dict or None if failed
|
||||
"""
|
||||
try:
|
||||
# Try Pydantic v2 method first
|
||||
return model_class.model_json_schema() # type: ignore
|
||||
except AttributeError:
|
||||
try:
|
||||
# Fallback to Pydantic v1 method
|
||||
return model_class.schema() # type: ignore
|
||||
except AttributeError:
|
||||
# If both methods fail, return None
|
||||
return None
|
||||
except Exception as e:
|
||||
# FastAPI 0.120+ may fail schema generation for certain types (e.g., openai.Timeout)
|
||||
# Log the error and return None to skip schema generation for this model
|
||||
verbose_proxy_logger.debug(
|
||||
f"Failed to generate schema for {model_class}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def add_schema_to_components(
|
||||
openapi_schema: Dict[str, Any], schema_name: str, schema_def: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Add a schema definition to the OpenAPI components/schemas section.
|
||||
|
||||
Args:
|
||||
openapi_schema: The OpenAPI schema dict to modify
|
||||
schema_name: Name for the schema component
|
||||
schema_def: The schema definition
|
||||
"""
|
||||
# Ensure components/schemas structure exists
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
if "schemas" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["schemas"] = {}
|
||||
|
||||
# Add the schema
|
||||
CustomOpenAPISpec._move_defs_to_components(
|
||||
openapi_schema, {schema_name: schema_def}
|
||||
)
|
||||
|
||||
    @staticmethod
    def add_request_body_to_paths(
        openapi_schema: Dict[str, Any], paths: List[str], schema_ref: str
    ) -> None:
        """
        Add request body with expanded form fields for better Swagger UI display.
        This keeps the request body but expands it to show individual fields in the UI.

        Mutates `openapi_schema` in place: for each listed path that has a POST
        operation, replaces its requestBody with an inline object schema copied
        from the referenced component, and strips conflicting query parameters.

        Args:
            openapi_schema: The OpenAPI schema dict to modify
            paths: List of paths to update
            schema_ref: Reference to the schema component (e.g., "#/components/schemas/ModelName")
        """
        for path in paths:
            # Only touch paths that exist in the spec and expose a POST operation.
            if (
                path in openapi_schema.get("paths", {})
                and "post" in openapi_schema["paths"][path]
            ):
                # Get the actual schema to extract ALL field definitions
                schema_name = schema_ref.split("/")[
                    -1
                ]  # Extract "ProxyChatCompletionRequest" from the ref
                actual_schema = (
                    openapi_schema.get("components", {})
                    .get("schemas", {})
                    .get(schema_name, {})
                )
                schema_properties = actual_schema.get("properties", {})
                required_fields = actual_schema.get("required", [])

                # Extract $defs and add them to components/schemas
                # This fixes Pydantic v2 $defs not being resolvable in Swagger/OpenAPI
                if "$defs" in actual_schema:
                    CustomOpenAPISpec._move_defs_to_components(
                        openapi_schema, actual_schema["$defs"]
                    )

                # Create an expanded inline schema instead of just a $ref
                # This makes Swagger UI show all individual fields in the request body editor
                expanded_schema = {
                    "type": "object",
                    "required": required_fields,
                    "properties": {},
                }

                # Add all properties with their full definitions
                for field_name, field_def in schema_properties.items():
                    # NOTE(review): _expand_field_definition is defined elsewhere
                    # in this class (not visible here); presumably it deep-copies
                    # or flattens the field definition — confirm before relying
                    # on aliasing behavior.
                    expanded_field = CustomOpenAPISpec._expand_field_definition(
                        field_def
                    )

                    # Rewrite $defs references to use components/schemas instead
                    expanded_field = CustomOpenAPISpec._rewrite_defs_refs(
                        expanded_field
                    )

                    # Add a simple example for the messages field
                    if field_name == "messages":
                        expanded_field["example"] = [
                            {"role": "user", "content": "Hello, how are you?"}
                        ]

                    expanded_schema["properties"][field_name] = expanded_field

                # Set the request body with the expanded schema
                openapi_schema["paths"][path]["post"]["requestBody"] = {
                    "required": True,
                    "content": {"application/json": {"schema": expanded_schema}},
                }

                # Keep any existing parameters (like path parameters) but remove conflicting query params
                if "parameters" in openapi_schema["paths"][path]["post"]:
                    existing_params = openapi_schema["paths"][path]["post"][
                        "parameters"
                    ]
                    # Only keep path parameters, remove query params that conflict with request body
                    filtered_params = [
                        param for param in existing_params if param.get("in") == "path"
                    ]
                    openapi_schema["paths"][path]["post"][
                        "parameters"
                    ] = filtered_params
|
||||
|
||||
@staticmethod
|
||||
def _move_defs_to_components(
|
||||
openapi_schema: Dict[str, Any], defs: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Move $defs from Pydantic v2 schema to OpenAPI components/schemas.
|
||||
This makes the definitions resolvable in Swagger/OpenAPI viewers.
|
||||
|
||||
Args:
|
||||
openapi_schema: The OpenAPI schema dict to modify
|
||||
defs: The $defs dictionary from Pydantic schema
|
||||
"""
|
||||
if not defs:
|
||||
return
|
||||
|
||||
# Ensure components/schemas exists
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
if "schemas" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["schemas"] = {}
|
||||
|
||||
# Add each definition to components/schemas
|
||||
for def_name, def_schema in defs.items():
|
||||
# Recursively rewrite any nested $defs references within this definition
|
||||
rewritten_def = CustomOpenAPISpec._rewrite_defs_refs(def_schema)
|
||||
openapi_schema["components"]["schemas"][def_name] = rewritten_def
|
||||
|
||||
# If this definition also has $defs, process them recursively
|
||||
if "$defs" in def_schema:
|
||||
CustomOpenAPISpec._move_defs_to_components(
|
||||
openapi_schema, def_schema["$defs"]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _rewrite_defs_refs(schema: Any) -> Any:
|
||||
"""
|
||||
Recursively rewrite $ref values from #/$defs/... to #/components/schemas/...
|
||||
This converts Pydantic v2 references to OpenAPI-compatible references.
|
||||
|
||||
Args:
|
||||
schema: Schema object to process (can be dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Schema with rewritten references
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
result = {}
|
||||
for key, value in schema.items():
|
||||
if (
|
||||
key == "$ref"
|
||||
and isinstance(value, str)
|
||||
and value.startswith("#/$defs/")
|
||||
):
|
||||
# Rewrite the reference to use components/schemas
|
||||
def_name = value.replace("#/$defs/", "")
|
||||
result[key] = f"#/components/schemas/{def_name}"
|
||||
elif key == "$defs":
|
||||
# Remove $defs from the schema since they're moved to components
|
||||
continue
|
||||
else:
|
||||
# Recursively process nested structures
|
||||
result[key] = CustomOpenAPISpec._rewrite_defs_refs(value)
|
||||
return result
|
||||
elif isinstance(schema, list):
|
||||
return [CustomOpenAPISpec._rewrite_defs_refs(item) for item in schema]
|
||||
else:
|
||||
return schema
|
||||
|
||||
@staticmethod
|
||||
def _extract_field_schema(field_def: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract a simple schema from a Pydantic field definition for parameter display.
|
||||
|
||||
Args:
|
||||
field_def: Pydantic field definition
|
||||
|
||||
Returns:
|
||||
Simplified schema for OpenAPI parameter
|
||||
"""
|
||||
# Handle simple types
|
||||
if "type" in field_def:
|
||||
return {"type": field_def["type"]}
|
||||
|
||||
# Handle anyOf (Optional fields in Pydantic v2)
|
||||
if "anyOf" in field_def:
|
||||
any_of = field_def["anyOf"]
|
||||
# Find the non-null type
|
||||
for option in any_of:
|
||||
if option.get("type") != "null":
|
||||
return option
|
||||
# Fallback to string if all else fails
|
||||
return {"type": "string"}
|
||||
|
||||
# Default fallback
|
||||
return {"type": "string"}
|
||||
|
||||
@staticmethod
|
||||
def _expand_field_definition(field_def: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Expand a Pydantic field definition for inline use in OpenAPI schema.
|
||||
This creates a full field definition that Swagger UI can render as individual form fields.
|
||||
|
||||
Args:
|
||||
field_def: Pydantic field definition
|
||||
|
||||
Returns:
|
||||
Expanded field definition for OpenAPI schema
|
||||
"""
|
||||
# Return the field definition as-is since Pydantic already provides proper schemas
|
||||
return field_def.copy()
|
||||
|
||||
@staticmethod
def add_request_schema(
    openapi_schema: Dict[str, Any],
    model_class: Type,
    schema_name: str,
    paths: List[str],
    operation_name: str,
) -> Dict[str, Any]:
    """
    Install a Pydantic model's schema into the OpenAPI spec for a set of paths.

    Best-effort: any failure is logged at debug level and the spec is
    returned unmodified rather than breaking docs generation.

    Args:
        openapi_schema: The OpenAPI schema dict to modify.
        model_class: The Pydantic model class to get schema from.
        schema_name: Name for the schema component.
        paths: List of paths to add the request body to.
        operation_name: Operation label used in log messages
            (e.g. "chat completion", "embedding").

    Returns:
        The (possibly modified) OpenAPI schema.
    """
    try:
        schema = CustomOpenAPISpec.get_pydantic_schema(model_class)
        if schema is None:
            verbose_proxy_logger.debug(f"Could not get schema for {schema_name}")
        else:
            # Register the component, then point each endpoint's request
            # body at it.
            CustomOpenAPISpec.add_schema_to_components(
                openapi_schema, schema_name, schema
            )
            CustomOpenAPISpec.add_request_body_to_paths(
                openapi_schema, paths, f"#/components/schemas/{schema_name}"
            )
            verbose_proxy_logger.debug(
                f"Successfully added {schema_name} schema to OpenAPI spec"
            )
    except Exception as e:
        # Documentation-only feature: swallow and log rather than fail.
        verbose_proxy_logger.debug(
            f"Failed to add {operation_name} request schema: {str(e)}"
        )

    return openapi_schema
|
||||
|
||||
@staticmethod
def add_chat_completion_request_schema(
    openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Document the chat-completion request body in the OpenAPI spec.

    Attaches the ProxyChatCompletionRequest schema to the chat completion
    endpoints for Swagger display only — no runtime validation is added.

    Args:
        openapi_schema: The OpenAPI schema dict to modify.

    Returns:
        The (possibly modified) OpenAPI schema.
    """
    try:
        # Imported lazily; the type may be unavailable in stripped installs.
        from litellm.proxy._types import ProxyChatCompletionRequest
    except ImportError as e:
        verbose_proxy_logger.debug(
            f"Failed to import ProxyChatCompletionRequest: {str(e)}"
        )
        return openapi_schema

    return CustomOpenAPISpec.add_request_schema(
        openapi_schema=openapi_schema,
        model_class=ProxyChatCompletionRequest,
        schema_name="ProxyChatCompletionRequest",
        paths=CustomOpenAPISpec.CHAT_COMPLETION_PATHS,
        operation_name="chat completion",
    )
|
||||
|
||||
@staticmethod
def add_embedding_request_schema(openapi_schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Document the embedding request body in the OpenAPI spec.

    Attaches the EmbeddingRequest schema to the embedding endpoints for
    Swagger display only — no runtime validation is added.

    Args:
        openapi_schema: The OpenAPI schema dict to modify.

    Returns:
        The (possibly modified) OpenAPI schema.
    """
    try:
        # Imported lazily; the type may be unavailable in stripped installs.
        from litellm.types.embedding import EmbeddingRequest
    except ImportError as e:
        verbose_proxy_logger.debug(f"Failed to import EmbeddingRequest: {str(e)}")
        return openapi_schema

    return CustomOpenAPISpec.add_request_schema(
        openapi_schema=openapi_schema,
        model_class=EmbeddingRequest,
        schema_name="EmbeddingRequest",
        paths=CustomOpenAPISpec.EMBEDDING_PATHS,
        operation_name="embedding",
    )
|
||||
|
||||
@staticmethod
def add_responses_api_request_schema(
    openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Document the Responses API request body in the OpenAPI spec.

    Attaches the ResponsesAPIRequestParams schema to the responses API
    endpoints for Swagger display only — no runtime validation is added.

    Args:
        openapi_schema: The OpenAPI schema dict to modify.

    Returns:
        The (possibly modified) OpenAPI schema.
    """
    try:
        # Imported lazily; the type may be unavailable in stripped installs.
        from litellm.types.llms.openai import ResponsesAPIRequestParams
    except ImportError as e:
        verbose_proxy_logger.debug(
            f"Failed to import ResponsesAPIRequestParams: {str(e)}"
        )
        return openapi_schema

    return CustomOpenAPISpec.add_request_schema(
        openapi_schema=openapi_schema,
        model_class=ResponsesAPIRequestParams,
        schema_name="ResponsesAPIRequestParams",
        paths=CustomOpenAPISpec.RESPONSES_API_PATHS,
        operation_name="responses API",
    )
|
||||
|
||||
@staticmethod
def add_llm_api_request_schema_body(
    openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Add all LLM API request body schemas to the OpenAPI specification.

    Runs each per-endpoint installer (chat completion, embedding,
    responses API) in turn; each is best-effort and returns the schema.

    Args:
        openapi_schema: The base OpenAPI schema.

    Returns:
        OpenAPI schema with added request body schemas.
    """
    installers = (
        CustomOpenAPISpec.add_chat_completion_request_schema,
        CustomOpenAPISpec.add_embedding_request_schema,
        CustomOpenAPISpec.add_responses_api_request_schema,
    )
    for installer in installers:
        openapi_schema = installer(openapi_schema)
    return openapi_schema
|
||||
@@ -0,0 +1,832 @@
|
||||
# Start tracing memory allocations
|
||||
import asyncio
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tracemalloc
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from litellm import get_secret_str
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import PYTHON_GC_THRESHOLD
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Configure garbage collection thresholds from the PYTHON_GC_THRESHOLD setting
def configure_gc_thresholds():
    """Apply the ``PYTHON_GC_THRESHOLD`` setting to Python's garbage collector.

    Reads the ``PYTHON_GC_THRESHOLD`` constant from ``litellm.constants``
    (presumably sourced from an environment variable there — confirm) and,
    when set, expects a comma-separated triple like ``"1000,50,50"`` giving
    the gen0/gen1/gen2 collection thresholds. Malformed values are logged
    and ignored, leaving the interpreter defaults in place. Always logs the
    thresholds that are in effect afterwards.
    """
    gc_threshold_env = PYTHON_GC_THRESHOLD
    if gc_threshold_env:
        try:
            # Parse threshold string like "1000,50,50"
            thresholds = [int(x.strip()) for x in gc_threshold_env.split(",")]
            if len(thresholds) == 3:
                gc.set_threshold(*thresholds)
                verbose_proxy_logger.info(f"GC thresholds set to: {thresholds}")
            else:
                # Wrong arity — keep the interpreter defaults untouched.
                verbose_proxy_logger.warning(
                    f"GC threshold not set: {gc_threshold_env}. Expected format: 'gen0,gen1,gen2'"
                )
        except ValueError as e:
            # Non-integer component — keep defaults rather than crash startup.
            verbose_proxy_logger.warning(
                f"Failed to parse GC threshold: {gc_threshold_env}. Error: {e}"
            )

    # Log current thresholds (whether or not they were just changed)
    current_thresholds = gc.get_threshold()
    verbose_proxy_logger.info(
        f"Current GC thresholds: gen0={current_thresholds[0]}, gen1={current_thresholds[1]}, gen2={current_thresholds[2]}"
    )
|
||||
|
||||
|
||||
# Initialize GC configuration
|
||||
configure_gc_thresholds()
|
||||
|
||||
|
||||
@router.get("/debug/asyncio-tasks")
async def get_active_tasks_stats():
    """
    Snapshot of live asyncio tasks in this worker's event loop.

    Returns:
        total_active_tasks: int
        by_name: { coroutine_name: count }
    """
    # Hard cap so a pathological task explosion cannot stall this endpoint.
    max_tasks_to_inspect = 5000

    # asyncio.all_tasks() includes this endpoint's own task; keep only
    # those still pending.
    pending_tasks = [task for task in asyncio.all_tasks() if not task.done()]

    # Group active tasks by a human-readable coroutine name.
    name_counts: Counter = Counter()
    for task in pending_tasks[:max_tasks_to_inspect]:
        coroutine = task.get_coro()
        label = (
            getattr(coroutine, "__qualname__", None)
            or getattr(coroutine, "__name__", None)
            or repr(coroutine)
        )
        name_counts[label] += 1

    return {
        "total_active_tasks": len(pending_tasks),
        "by_name": dict(name_counts),
    }
|
||||
|
||||
|
||||
# Optional leak hunting: only active when LITELLM_PROFILE=true is exported.
if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
    try:
        import objgraph  # type: ignore

        # Print object-growth and most-common-type tables at import time so a
        # leak investigation starts from a baseline snapshot.
        print("growth of objects")  # noqa
        objgraph.show_growth()
        print("\n\nMost common types")  # noqa
        objgraph.show_most_common_types()
        roots = objgraph.get_leaking_objects()
        print("\n\nLeaking objects")  # noqa
        objgraph.show_most_common_types(objects=roots)
    except ImportError:
        # Profiling was explicitly requested, so a missing dependency is fatal.
        raise ImportError(
            "objgraph not found. Please install objgraph to use this feature."
        )

    # Record allocation tracebacks up to 10 frames deep; consumed by the
    # /memory-usage endpoint via tracemalloc.take_snapshot().
    tracemalloc.start(10)
|
||||
|
||||
@router.get("/memory-usage", include_in_schema=False)
async def memory_usage():
    """Report the top 50 allocation sites recorded by tracemalloc.

    Requires tracemalloc to have been started (done at import time when
    profiling is enabled).

    Returns:
        dict with key "top_50_memory_usage": list of human-readable strings,
        one per allocation site, largest first.
    """
    # Take a snapshot of the current memory usage
    snapshot = tracemalloc.take_snapshot()
    top_stats = snapshot.statistics("lineno")
    verbose_proxy_logger.debug("TOP STATS: %s", top_stats)

    # Report the top 50 allocation sites.
    result = []
    for stat in top_stats[:50]:
        # Fix: Traceback.format() returns a *list* of lines; join them so the
        # output reads as a traceback instead of a Python list repr.
        formatted_traceback = "\n".join(stat.traceback.format(limit=10))
        result.append(f"{formatted_traceback}: {stat.size / 1024} KiB")

    return {"top_50_memory_usage": result}
|
||||
|
||||
|
||||
@router.get("/memory-usage-in-mem-cache", include_in_schema=False)
async def memory_usage_in_mem_cache(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    # returns the size of all in-memory caches on the proxy server
    """
    Report item counts for the proxy's in-memory caches.

    Caches counted:
    1. user_api_key_cache
    2. router_cache
    3. proxy_logging_cache
    4. internal_usage_cache
    """
    # Imported lazily from the running server module (presumably to avoid a
    # circular import at module load — confirm).
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Router is optional — the proxy can run without one configured.
    if llm_router is None:
        num_items_in_llm_router_cache = 0
    else:
        # Each count is live entries plus their TTL bookkeeping entries.
        num_items_in_llm_router_cache = len(
            llm_router.cache.in_memory_cache.cache_dict
        ) + len(llm_router.cache.in_memory_cache.ttl_dict)

    num_items_in_user_api_key_cache = len(
        user_api_key_cache.in_memory_cache.cache_dict
    ) + len(user_api_key_cache.in_memory_cache.ttl_dict)

    num_items_in_proxy_logging_obj_cache = len(
        proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
    ) + len(proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict)

    return {
        "num_items_in_user_api_key_cache": num_items_in_user_api_key_cache,
        "num_items_in_llm_router_cache": num_items_in_llm_router_cache,
        "num_items_in_proxy_logging_obj_cache": num_items_in_proxy_logging_obj_cache,
    }
|
||||
|
||||
|
||||
@router.get("/memory-usage-in-mem-cache-items", include_in_schema=False)
async def memory_usage_in_mem_cache_items(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    # returns the size of all in-memory caches on the proxy server
    """
    Dump the raw contents (cache dicts and TTL dicts) of the proxy's
    in-memory caches. Heavy-weight sibling of /memory-usage-in-mem-cache.

    Caches returned:
    1. user_api_key_cache
    2. router_cache
    3. proxy_logging_cache
    4. internal_usage_cache
    """
    # Imported lazily from the running server module (presumably to avoid a
    # circular import at module load — confirm).
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Router is optional; substitute empty dicts when it isn't configured.
    if llm_router is None:
        llm_router_in_memory_cache_dict = {}
        llm_router_in_memory_ttl_dict = {}
    else:
        llm_router_in_memory_cache_dict = llm_router.cache.in_memory_cache.cache_dict
        llm_router_in_memory_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict

    # NOTE: these are references to the live cache dicts, not copies.
    return {
        "user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
        "user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
        "llm_router_cache": llm_router_in_memory_cache_dict,
        "llm_router_ttl": llm_router_in_memory_ttl_dict,
        "proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
        "proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict,
    }
|
||||
|
||||
|
||||
@router.get("/debug/memory/summary", include_in_schema=False)
async def get_memory_summary(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> Dict[str, Any]:
    """
    Get simplified memory usage summary for the proxy.

    Returns:
        - worker_pid: Process ID
        - status: Overall health based on memory usage ("healthy" /
          "warning" / "critical"; stays "healthy" if psutil is unavailable)
        - memory: Process memory usage and RAM info (or an "error" key)
        - caches: Cache item counts and descriptions
        - garbage_collector: GC status and pending object counts

    Example usage:
        curl http://localhost:4000/debug/memory/summary -H "Authorization: Bearer sk-1234"

    For detailed analysis, call GET /debug/memory/details
    For cache management, use the cache management endpoints
    """
    # Imported lazily from the running server module (presumably to avoid a
    # circular import at module load — confirm).
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )

    # Get process memory info; psutil is optional so errors are reported
    # inside the payload instead of failing the request.
    process_memory = {}
    health_status = "healthy"

    try:
        import psutil

        process = psutil.Process()
        memory_info = process.memory_info()
        # RSS in MiB and share of total system memory.
        memory_mb = memory_info.rss / (1024 * 1024)
        memory_percent = process.memory_percent()

        process_memory = {
            "summary": f"{memory_mb:.1f} MB ({memory_percent:.1f}% of system memory)",
            "ram_usage_mb": round(memory_mb, 2),
            "system_memory_percent": round(memory_percent, 2),
        }

        # Check memory health status: >80% critical, >60% warning.
        if memory_percent > 80:
            health_status = "critical"
        elif memory_percent > 60:
            health_status = "warning"
        else:
            health_status = "healthy"

    except ImportError:
        process_memory[
            "error"
        ] = "Install psutil for memory monitoring: pip install psutil"
    except Exception as e:
        process_memory["error"] = str(e)

    # Get cache information (item counts only — sizes are in /details).
    caches: Dict[str, Any] = {}
    total_cache_items = 0

    try:
        # User API key cache
        user_cache_items = len(user_api_key_cache.in_memory_cache.cache_dict)
        total_cache_items += user_cache_items
        caches["user_api_keys"] = {
            "count": user_cache_items,
            "count_readable": f"{user_cache_items:,}",
            "what_it_stores": "Validated API keys for faster authentication",
        }

        # Router cache (only present when a router is configured)
        if llm_router is not None:
            router_cache_items = len(llm_router.cache.in_memory_cache.cache_dict)
            total_cache_items += router_cache_items
            caches["llm_responses"] = {
                "count": router_cache_items,
                "count_readable": f"{router_cache_items:,}",
                "what_it_stores": "LLM responses for identical requests",
            }

        # Proxy logging cache
        logging_cache_items = len(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
        )
        total_cache_items += logging_cache_items
        caches["usage_tracking"] = {
            "count": logging_cache_items,
            "count_readable": f"{logging_cache_items:,}",
            "what_it_stores": "Usage metrics before database write",
        }

    except Exception as e:
        # Partial results may already be in `caches`; record the failure.
        caches["error"] = str(e)

    # Get garbage collector stats
    gc_enabled = gc.isenabled()
    # gen0 count only — youngest generation, most frequently collected.
    objects_pending = gc.get_count()[0]
    uncollectable = len(gc.garbage)

    gc_info = {
        "status": "enabled" if gc_enabled else "disabled",
        "objects_awaiting_collection": objects_pending,
    }

    # Add warning if garbage collection issues detected
    if uncollectable > 0:
        gc_info[
            "warning"
        ] = f"{uncollectable} uncollectable objects (possible memory leak)"

    return {
        "worker_pid": os.getpid(),
        "status": health_status,
        "memory": process_memory,
        "caches": {
            "total_items": total_cache_items,
            "breakdown": caches,
        },
        "garbage_collector": gc_info,
    }
|
||||
|
||||
|
||||
def _get_gc_statistics() -> Dict[str, Any]:
|
||||
"""Get garbage collector statistics."""
|
||||
return {
|
||||
"enabled": gc.isenabled(),
|
||||
"thresholds": {
|
||||
"generation_0": gc.get_threshold()[0],
|
||||
"generation_1": gc.get_threshold()[1],
|
||||
"generation_2": gc.get_threshold()[2],
|
||||
"explanation": "Number of allocations before automatic collection for each generation",
|
||||
},
|
||||
"current_counts": {
|
||||
"generation_0": gc.get_count()[0],
|
||||
"generation_1": gc.get_count()[1],
|
||||
"generation_2": gc.get_count()[2],
|
||||
"explanation": "Current number of allocated objects in each generation",
|
||||
},
|
||||
"collection_history": [
|
||||
{
|
||||
"generation": i,
|
||||
"total_collections": stat["collections"],
|
||||
"total_collected": stat["collected"],
|
||||
"uncollectable": stat["uncollectable"],
|
||||
}
|
||||
for i, stat in enumerate(gc.get_stats())
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _get_object_type_counts(top_n: int) -> Tuple[int, List[Dict[str, Any]]]:
|
||||
"""Count objects by type and return total count and top N types."""
|
||||
type_counts: Counter = Counter()
|
||||
total_objects = 0
|
||||
|
||||
for obj in gc.get_objects():
|
||||
total_objects += 1
|
||||
obj_type = type(obj).__name__
|
||||
type_counts[obj_type] += 1
|
||||
|
||||
top_object_types = [
|
||||
{"type": obj_type, "count": count, "count_readable": f"{count:,}"}
|
||||
for obj_type, count in type_counts.most_common(top_n)
|
||||
]
|
||||
|
||||
return total_objects, top_object_types
|
||||
|
||||
|
||||
def _get_uncollectable_objects_info() -> Dict[str, Any]:
|
||||
"""Get information about uncollectable objects (potential memory leaks)."""
|
||||
uncollectable = gc.garbage
|
||||
return {
|
||||
"count": len(uncollectable),
|
||||
"sample_types": [type(obj).__name__ for obj in uncollectable[:10]],
|
||||
"warning": "If count > 0, you may have reference cycles preventing garbage collection"
|
||||
if len(uncollectable) > 0
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
def _get_cache_memory_stats(
    user_api_key_cache, llm_router, proxy_logging_obj, redis_usage_cache
) -> Dict[str, Any]:
    """Calculate memory usage for all caches.

    NOTE: sys.getsizeof is shallow — it measures the dict object itself, not
    the keys/values it references, so the reported sizes undercount real
    usage. Failures are captured into the returned dict rather than raised.

    Args:
        user_api_key_cache: cache with an ``in_memory_cache`` (cache_dict /
            ttl_dict) — shape assumed from usage; confirm against caller.
        llm_router: router (or None) exposing ``cache.in_memory_cache``.
        proxy_logging_obj: exposes ``internal_usage_cache.dual_cache.in_memory_cache``.
        redis_usage_cache: Redis-backed cache object, or None when disabled.

    Returns:
        Per-cache stats (item counts, shallow byte sizes, MB totals) plus a
        ``redis_usage_cache`` entry; an ``error`` key on failure.
    """
    cache_stats: Dict[str, Any] = {}
    try:
        # User API key cache
        user_cache_size = sys.getsizeof(user_api_key_cache.in_memory_cache.cache_dict)
        user_ttl_size = sys.getsizeof(user_api_key_cache.in_memory_cache.ttl_dict)
        cache_stats["user_api_key_cache"] = {
            "num_items": len(user_api_key_cache.in_memory_cache.cache_dict),
            "cache_dict_size_bytes": user_cache_size,
            "ttl_dict_size_bytes": user_ttl_size,
            "total_size_mb": round(
                (user_cache_size + user_ttl_size) / (1024 * 1024), 2
            ),
        }

        # Router cache (only when a router is configured)
        if llm_router is not None:
            router_cache_size = sys.getsizeof(
                llm_router.cache.in_memory_cache.cache_dict
            )
            router_ttl_size = sys.getsizeof(llm_router.cache.in_memory_cache.ttl_dict)
            cache_stats["llm_router_cache"] = {
                "num_items": len(llm_router.cache.in_memory_cache.cache_dict),
                "cache_dict_size_bytes": router_cache_size,
                "ttl_dict_size_bytes": router_ttl_size,
                "total_size_mb": round(
                    (router_cache_size + router_ttl_size) / (1024 * 1024), 2
                ),
            }

        # Proxy logging cache
        logging_cache_size = sys.getsizeof(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
        )
        logging_ttl_size = sys.getsizeof(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict
        )
        cache_stats["proxy_logging_cache"] = {
            "num_items": len(
                proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
            ),
            "cache_dict_size_bytes": logging_cache_size,
            "ttl_dict_size_bytes": logging_ttl_size,
            "total_size_mb": round(
                (logging_cache_size + logging_ttl_size) / (1024 * 1024), 2
            ),
        }

        # Redis cache info — only presence/type, no byte sizes (data lives
        # in the Redis process, not this one).
        if redis_usage_cache is not None:
            cache_stats["redis_usage_cache"] = {
                "enabled": True,
                "cache_type": type(redis_usage_cache).__name__,
            }
            # Try to get Redis connection pool info if available; all access
            # is hasattr-guarded since the client shape varies.
            try:
                if (
                    hasattr(redis_usage_cache, "redis_client")
                    and redis_usage_cache.redis_client
                ):
                    if hasattr(redis_usage_cache.redis_client, "connection_pool"):
                        pool_info = redis_usage_cache.redis_client.connection_pool  # type: ignore
                        cache_stats["redis_usage_cache"]["connection_pool"] = {
                            "max_connections": pool_info.max_connections
                            if hasattr(pool_info, "max_connections")
                            else None,
                            "connection_class": pool_info.connection_class.__name__
                            if hasattr(pool_info, "connection_class")
                            else None,
                        }
            except Exception as e:
                # Pool introspection is best-effort only.
                verbose_proxy_logger.debug(f"Error getting Redis pool info: {e}")
        else:
            cache_stats["redis_usage_cache"] = {"enabled": False}

    except Exception as e:
        # Report the failure in-band; partial stats may already be present.
        verbose_proxy_logger.debug(f"Error calculating cache stats: {e}")
        cache_stats["error"] = str(e)

    return cache_stats
|
||||
|
||||
|
||||
def _get_router_memory_stats(llm_router) -> Dict[str, Any]:
    """Get memory usage statistics for LiteLLM router components.

    Inspects router attributes (model_list, model_names, deployment_names,
    deployment_latency_map, fallbacks) when present and non-empty.

    NOTE: sys.getsizeof is shallow — it measures each container object, not
    the elements it references, so sizes undercount real usage.

    Args:
        llm_router: The router instance, or None when not configured.

    Returns:
        Per-component stats, a {"note": ...} when there is no router, or an
        {"error": ...} on failure.
    """
    litellm_router_memory: Dict[str, Any] = {}
    try:
        if llm_router is not None:
            # Model list memory size
            if hasattr(llm_router, "model_list") and llm_router.model_list:
                model_list_size = sys.getsizeof(llm_router.model_list)
                litellm_router_memory["model_list"] = {
                    "num_models": len(llm_router.model_list),
                    "size_bytes": model_list_size,
                    "size_mb": round(model_list_size / (1024 * 1024), 4),
                }

            # Model names set
            if hasattr(llm_router, "model_names") and llm_router.model_names:
                model_names_size = sys.getsizeof(llm_router.model_names)
                litellm_router_memory["model_names_set"] = {
                    "num_model_groups": len(llm_router.model_names),
                    "size_bytes": model_names_size,
                    "size_mb": round(model_names_size / (1024 * 1024), 4),
                }

            # Deployment names list
            if hasattr(llm_router, "deployment_names") and llm_router.deployment_names:
                deployment_names_size = sys.getsizeof(llm_router.deployment_names)
                litellm_router_memory["deployment_names"] = {
                    "num_deployments": len(llm_router.deployment_names),
                    "size_bytes": deployment_names_size,
                    "size_mb": round(deployment_names_size / (1024 * 1024), 4),
                }

            # Deployment latency map
            if (
                hasattr(llm_router, "deployment_latency_map")
                and llm_router.deployment_latency_map
            ):
                latency_map_size = sys.getsizeof(llm_router.deployment_latency_map)
                litellm_router_memory["deployment_latency_map"] = {
                    "num_tracked_deployments": len(llm_router.deployment_latency_map),
                    "size_bytes": latency_map_size,
                    "size_mb": round(latency_map_size / (1024 * 1024), 4),
                }

            # Fallback configuration
            if hasattr(llm_router, "fallbacks") and llm_router.fallbacks:
                fallbacks_size = sys.getsizeof(llm_router.fallbacks)
                litellm_router_memory["fallbacks"] = {
                    "num_fallback_configs": len(llm_router.fallbacks),
                    "size_bytes": fallbacks_size,
                    "size_mb": round(fallbacks_size / (1024 * 1024), 4),
                }

            # Total router object size (shallow; excludes attribute contents)
            router_obj_size = sys.getsizeof(llm_router)
            litellm_router_memory["router_object"] = {
                "size_bytes": router_obj_size,
                "size_mb": round(router_obj_size / (1024 * 1024), 4),
            }

        else:
            litellm_router_memory = {"note": "Router not initialized"}
    except Exception as e:
        verbose_proxy_logger.debug(f"Error getting router memory info: {e}")
        litellm_router_memory = {"error": str(e)}

    return litellm_router_memory
|
||||
|
||||
|
||||
def _get_process_memory_info(
|
||||
worker_pid: int, include_process_info: bool
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get process-level memory information using psutil."""
|
||||
if not include_process_info:
|
||||
return None
|
||||
|
||||
try:
|
||||
import psutil
|
||||
|
||||
process = psutil.Process()
|
||||
memory_info = process.memory_info()
|
||||
ram_usage_mb = round(memory_info.rss / (1024 * 1024), 2)
|
||||
virtual_memory_mb = round(memory_info.vms / (1024 * 1024), 2)
|
||||
memory_percent = round(process.memory_percent(), 2)
|
||||
|
||||
return {
|
||||
"pid": worker_pid,
|
||||
"summary": f"Worker PID {worker_pid} using {ram_usage_mb:.1f} MB of RAM ({memory_percent:.1f}% of system memory)",
|
||||
"ram_usage": {
|
||||
"megabytes": ram_usage_mb,
|
||||
"description": "Actual physical RAM used by this process",
|
||||
},
|
||||
"virtual_memory": {
|
||||
"megabytes": virtual_memory_mb,
|
||||
"description": "Total virtual memory allocated (includes swapped memory)",
|
||||
},
|
||||
"system_memory_percent": {
|
||||
"percent": memory_percent,
|
||||
"description": "Percentage of total system RAM being used",
|
||||
},
|
||||
"open_file_handles": {
|
||||
"count": process.num_fds()
|
||||
if hasattr(process, "num_fds")
|
||||
else "N/A (Windows)",
|
||||
"description": "Number of open file descriptors/handles",
|
||||
},
|
||||
"threads": {
|
||||
"count": process.num_threads(),
|
||||
"description": "Number of active threads in this process",
|
||||
},
|
||||
}
|
||||
except ImportError:
|
||||
return {
|
||||
"pid": worker_pid,
|
||||
"error": "psutil not installed. Install with: pip install psutil",
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting process info: {e}")
|
||||
return {"pid": worker_pid, "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/debug/memory/details", include_in_schema=False)
async def get_memory_details(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
    top_n: int = Query(20, description="Number of top object types to return"),
    include_process_info: bool = Query(True, description="Include process memory info"),
) -> Dict[str, Any]:
    """
    Get detailed memory diagnostics for deep debugging.

    Returns:
        - worker_pid: Process ID
        - process_memory: RAM usage, virtual memory, file handles, threads
        - garbage_collector: GC thresholds, counts, collection history
        - objects: Total tracked objects and top object types
        - uncollectable: Objects that can't be garbage collected (potential leaks)
        - cache_memory: Memory usage of user_api_key, router, and logging caches
        - router_memory: Memory usage of router components (model_list, deployment_names, etc.)

    Query Parameters:
        - top_n: Number of top object types to return (default: 20)
        - include_process_info: Include process-level memory info using psutil (default: true)

    Example usage:
        curl "http://localhost:4000/debug/memory/details?top_n=30" -H "Authorization: Bearer sk-1234"

    All memory sizes are reported in both bytes and MB.
    """
    # Imported lazily from the running server module (presumably to avoid a
    # circular import at module load — confirm).
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
        redis_usage_cache,
    )

    worker_pid = os.getpid()

    # Collect all diagnostics using helper functions; each helper traps its
    # own failures and reports them in-band.
    gc_stats = _get_gc_statistics()
    total_objects, top_object_types = _get_object_type_counts(top_n)
    uncollectable_info = _get_uncollectable_objects_info()
    cache_stats = _get_cache_memory_stats(
        user_api_key_cache, llm_router, proxy_logging_obj, redis_usage_cache
    )
    litellm_router_memory = _get_router_memory_stats(llm_router)
    process_info = _get_process_memory_info(worker_pid, include_process_info)

    return {
        "worker_pid": worker_pid,
        "process_memory": process_info,
        "garbage_collector": gc_stats,
        "objects": {
            "total_tracked": total_objects,
            "total_tracked_readable": f"{total_objects:,}",
            "top_types": top_object_types,
        },
        "uncollectable": uncollectable_info,
        "cache_memory": cache_stats,
        "router_memory": litellm_router_memory,
    }
|
||||
|
||||
|
||||
@router.post("/debug/memory/gc/configure", include_in_schema=False)
async def configure_gc_thresholds_endpoint(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
    generation_0: int = Query(700, description="Generation 0 threshold (default: 700)"),
    generation_1: int = Query(10, description="Generation 1 threshold (default: 10)"),
    generation_2: int = Query(10, description="Generation 2 threshold (default: 10)"),
) -> Dict[str, Any]:
    """
    Configure Python garbage collection thresholds.

    Lower thresholds mean more frequent GC cycles (less memory, more CPU overhead).
    Higher thresholds mean less frequent GC cycles (more memory, less CPU overhead).

    Returns:
        - message: Confirmation message
        - previous_thresholds: Old threshold values
        - new_thresholds: New threshold values
        - objects_awaiting_collection: Current object count in gen-0
        - tip: Hint about when next collection will occur

    Query Parameters:
        - generation_0: Number of allocations before gen-0 collection (default: 700)
        - generation_1: Number of gen-0 collections before gen-1 collection (default: 10)
        - generation_2: Number of gen-1 collections before gen-2 collection (default: 10)

    Raises:
        HTTPException: 500 if gc.set_threshold rejects the values.

    Example for more aggressive collection:
        curl -X POST "http://localhost:4000/debug/memory/gc/configure?generation_0=500" -H "Authorization: Bearer sk-1234"

    Example for less aggressive collection:
        curl -X POST "http://localhost:4000/debug/memory/gc/configure?generation_0=1000" -H "Authorization: Bearer sk-1234"

    Monitor memory usage with GET /debug/memory/summary after changes.
    """
    # Capture current thresholds so the response/logs can show the change.
    old_thresholds = gc.get_threshold()

    # Set new thresholds with error handling
    try:
        gc.set_threshold(generation_0, generation_1, generation_2)
        verbose_proxy_logger.info(
            f"GC thresholds updated from {old_thresholds} to "
            f"({generation_0}, {generation_1}, {generation_2})"
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Failed to set GC thresholds: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to set GC thresholds: {str(e)}"
        )

    # Get current object count to show immediate impact
    current_count = gc.get_count()[0]
    # BUGFIX: gen-0 may already hold more objects than the new threshold,
    # which previously produced a misleading negative allocation count.
    allocations_until_next_gc = max(generation_0 - current_count, 0)

    return {
        "message": "GC thresholds updated",
        "previous_thresholds": f"{old_thresholds[0]}, {old_thresholds[1]}, {old_thresholds[2]}",
        "new_thresholds": f"{generation_0}, {generation_1}, {generation_2}",
        "objects_awaiting_collection": current_count,
        "tip": f"Next collection will run after {allocations_until_next_gc} more allocations",
    }
|
||||
|
||||
|
||||
@router.get("/otel-spans", include_in_schema=False)
async def get_otel_spans():
    """Return names of finished OTEL spans, grouped by parent trace id."""
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import open_telemetry_logger

    if open_telemetry_logger is None:
        # No OTEL logger configured - nothing to report.
        return {
            "otel_spans": [],
            "spans_grouped_by_parent": {},
            "most_recent_parent": None,
        }

    exporter = open_telemetry_logger.OTEL_EXPORTER
    if hasattr(exporter, "get_finished_spans"):
        finished_spans = exporter.get_finished_spans()  # type: ignore
    else:
        finished_spans = []

    print("Spans: ", finished_spans)  # noqa

    most_recent_parent = None
    most_recent_start_time = 1000000
    grouped_by_parent: dict = {}
    for current_span in finished_spans:
        if current_span.parent is None:
            continue
        parent_trace_id = current_span.parent.trace_id
        grouped_by_parent.setdefault(parent_trace_id, []).append(current_span.name)

        # track the parent of the latest-started span seen so far
        if current_span.start_time > most_recent_start_time:
            most_recent_parent = parent_trace_id
            most_recent_start_time = current_span.start_time

    # these are otel spans - report only the span names
    return {
        "otel_spans": [s.name for s in finished_spans],
        "spans_grouped_by_parent": grouped_by_parent,
        "most_recent_parent": most_recent_parent,
    }
|
||||
|
||||
|
||||
# Helper functions for debugging
|
||||
def init_verbose_loggers():
    """
    Initialize proxy/router/package log levels from the WORKER_CONFIG secret.

    WORKER_CONFIG may be a file path (ignored here - file configs are handled
    elsewhere) or a JSON object string with optional boolean keys "debug" and
    "detailed_debug". When both flags are explicitly False, falls back to the
    LITELLM_LOG env var ("INFO" or "DEBUG").

    Never raises - any failure is logged as a warning.
    """
    try:
        worker_config = get_secret_str("WORKER_CONFIG")
        if worker_config is None:
            return
        # if it's a file path, there is nothing to parse here
        if os.path.isfile(worker_config):
            return
        # if not a file, assume it's a json string
        _settings = json.loads(worker_config)
        if not isinstance(_settings, dict):
            return

        debug = _settings.get("debug", None)
        detailed_debug = _settings.get("detailed_debug", None)

        # Hoisted out of the individual branches - the original repeated this
        # import block four times.
        import logging

        from litellm._logging import (
            verbose_logger,
            verbose_proxy_logger,
            verbose_router_logger,
        )

        if debug is True:  # this needs to be first, so users can see Router init debug
            # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
            verbose_logger.setLevel(level=logging.INFO)  # sets package logs to info
            verbose_router_logger.setLevel(level=logging.INFO)  # set router logs to info
            verbose_proxy_logger.setLevel(level=logging.INFO)  # set proxy logs to info
        if detailed_debug is True:
            verbose_logger.setLevel(level=logging.DEBUG)  # set package log to debug
            verbose_router_logger.setLevel(level=logging.DEBUG)  # set router logs to debug
            verbose_proxy_logger.setLevel(level=logging.DEBUG)  # set proxy logs to debug
        elif debug is False and detailed_debug is False:
            # users can control proxy debugging using env variable = 'LITELLM_LOG'
            # NOTE: os.environ.get with a "" default never returns None, so the
            # original `is not None` guard was dead code and is removed.
            litellm_log_setting = os.environ.get("LITELLM_LOG", "")
            if litellm_log_setting.upper() == "INFO":
                # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
                verbose_router_logger.setLevel(level=logging.INFO)  # set router logs to info
                verbose_proxy_logger.setLevel(level=logging.INFO)  # set proxy logs to info
            elif litellm_log_setting.upper() == "DEBUG":
                verbose_router_logger.setLevel(level=logging.DEBUG)  # set router logs to debug
                verbose_proxy_logger.setLevel(level=logging.DEBUG)  # set proxy logs to debug
    except Exception as e:
        import logging

        logging.warning(f"Failed to init verbose loggers: {str(e)}")
|
||||
@@ -0,0 +1,122 @@
|
||||
import base64
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
def _get_salt_key():
    """Return the key used to encrypt/decrypt values stored in the DB.

    Prefers the LITELLM_SALT_KEY environment variable; when it is unset,
    falls back to the proxy master key.
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import master_key

    env_salt_key = os.getenv("LITELLM_SALT_KEY", None)
    return master_key if env_salt_key is None else env_salt_key
|
||||
|
||||
|
||||
def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None):
    """
    Encrypt a string value and return it URL-safe base64 encoded.

    Args:
        value: The value to encrypt. Non-string values are returned unchanged.
        new_encryption_key: Optional override key; defaults to the salt key
            (LITELLM_SALT_KEY env var or the proxy master key).

    Returns:
        str: URL-safe base64-encoded ciphertext, or the original value when
        it is not a string.
    """
    # IMPROVEMENT: removed the original `try/except Exception as e: raise e`
    # wrapper - it re-raised unchanged and only obscured the control flow.
    signing_key = new_encryption_key or _get_salt_key()

    if not isinstance(value, str):
        verbose_proxy_logger.debug(
            f"Invalid value type passed to encrypt_value: {type(value)} for Value: {value}\n Value must be a string"
        )
        # if it's not a string - do not encrypt it and return the value
        return value

    encrypted_value = encrypt_value(value=value, signing_key=signing_key)  # type: ignore
    # Use urlsafe_b64encode for URL-safe base64 encoding (replaces + with - and / with _)
    return base64.urlsafe_b64encode(encrypted_value).decode("utf-8")
|
||||
|
||||
|
||||
def decrypt_value_helper(
    value: str,
    key: str,  # this is just for debug purposes, showing the k,v pair that's invalid. not a signing key.
    exception_type: Literal["debug", "error"] = "error",
    return_original_value: bool = False,
):
    """
    Decrypt a base64-encoded encrypted string using the configured salt key.

    Args:
        value: The (possibly encrypted) value. Non-string values are returned
            unchanged without any decryption attempt.
        key: Name of the config key being decrypted - used only in log
            messages to identify which entry failed.
        exception_type: "debug" logs failures quietly at debug level;
            "error" logs a full exception traceback.
        return_original_value: When True, decryption failures return the
            original (still-encrypted) input instead of None.

    Returns:
        The decrypted string on success; on failure, the original value or
        None depending on return_original_value. Failures never raise.
    """
    signing_key = _get_salt_key()

    try:
        if isinstance(value, str):
            # Try URL-safe base64 decoding first (new format)
            # Fall back to standard base64 decoding for backwards compatibility (old format)
            try:
                decoded_b64 = base64.urlsafe_b64decode(value)
            except Exception:
                # If URL-safe decoding fails, try standard base64 decoding for backwards compatibility
                decoded_b64 = base64.b64decode(value)

            value = decrypt_value(value=decoded_b64, signing_key=signing_key)  # type: ignore
            return value

        # if it's not str - do not decrypt it, return the value
        return value
    except Exception as e:
        # Decryption (or decoding) failed - most commonly because the
        # salt/master key changed since the value was encrypted.
        error_message = f"Error decrypting value for key: {key}, Did your master_key/salt key change recently? \nError: {str(e)}\nSet permanent salt key - https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key"
        if exception_type == "debug":
            verbose_proxy_logger.debug(error_message)
            return value if return_original_value else None

        verbose_proxy_logger.debug(
            f"Unable to decrypt value={value} for key: {key}, returning None"
        )
        if return_original_value:
            return value
        else:
            verbose_proxy_logger.exception(error_message)
            # [Non-Blocking Exception. - this should not block decrypting other values]
            return None
|
||||
|
||||
|
||||
def encrypt_value(value: str, signing_key: str):
    """
    Encrypt a string with NaCl SecretBox (XSalsa20-Poly1305 authenticated
    encryption).

    The signing key is hashed with SHA-256 to derive the fixed 32-byte key
    SecretBox requires, so arbitrary-length keys are accepted.

    Args:
        value: Plaintext to encrypt.
        signing_key: Arbitrary-length secret used to derive the box key.

    Returns:
        bytes: nonce-prefixed ciphertext as produced by SecretBox.encrypt.
    """
    import hashlib

    import nacl.secret
    import nacl.utils

    # get 32 byte master key #
    hash_object = hashlib.sha256(signing_key.encode())
    hash_bytes = hash_object.digest()

    # initialize secret box #
    box = nacl.secret.SecretBox(hash_bytes)

    # encode message #
    value_bytes = value.encode("utf-8")

    # SecretBox.encrypt generates a random nonce and prepends it to the
    # ciphertext, so each call produces different output for the same input.
    encrypted = box.encrypt(value_bytes)

    return encrypted
|
||||
|
||||
|
||||
def decrypt_value(value: bytes, signing_key: str) -> str:
    """
    Decrypt bytes produced by encrypt_value using NaCl SecretBox.

    The signing key is hashed with SHA-256 to derive the same 32-byte box
    key used at encryption time.

    Args:
        value: Nonce-prefixed ciphertext bytes. Empty input returns "".
        signing_key: Arbitrary-length secret used to derive the box key.

    Returns:
        str: The decrypted UTF-8 plaintext.

    Raises:
        Exception: Propagates any decryption/authentication failure
        (e.g. wrong key or corrupted ciphertext) to the caller.
    """
    import hashlib

    import nacl.secret
    import nacl.utils

    # get 32 byte master key #
    hash_object = hashlib.sha256(signing_key.encode())
    hash_bytes = hash_object.digest()

    # initialize secret box #
    box = nacl.secret.SecretBox(hash_bytes)

    # Convert the bytes object to a string
    try:
        # Empty payloads are treated as the empty string rather than a
        # decryption error.
        if len(value) == 0:
            return ""

        plaintext = box.decrypt(value)
        plaintext = plaintext.decode("utf-8")  # type: ignore
        return plaintext  # type: ignore
    except Exception as e:
        raise e
|
||||
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Utility class for getting routes from a FastAPI app.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class GetRoutes:
    """Utility helpers for extracting route metadata from a FastAPI app."""

    @staticmethod
    def get_app_routes(
        route: BaseRoute,
        endpoint_route: Any,
    ) -> List[Dict[str, Any]]:
        """
        Get routes for a regular route.

        Args:
            route: The starlette route object to describe.
            endpoint_route: The callable registered for the route.

        Returns:
            Single-element list with the route's path, methods, name, and
            endpoint function name (None when the route has no endpoint).
        """
        routes: List[Dict[str, Any]] = []
        route_info = {
            "path": getattr(route, "path", None),
            "methods": getattr(route, "methods", None),
            "name": getattr(route, "name", None),
            "endpoint": (
                # CONSISTENCY FIX: use the safe name helper (as the mounted-app
                # path already does) instead of endpoint_route.__name__, which
                # raised AttributeError for class-based endpoints.
                GetRoutes._safe_get_endpoint_name(endpoint_route)
                if getattr(route, "endpoint", None)
                else None
            ),
        }
        routes.append(route_info)
        return routes

    @staticmethod
    def get_routes_for_mounted_app(
        route: BaseRoute,
    ) -> List[Dict[str, Any]]:
        """
        Get routes for a mounted sub-application.

        Walks the sub-app's routes, prefixing each path with the mount path,
        and marks every entry with "mounted_app": True.
        """
        routes: List[Dict[str, Any]] = []
        mount_path = getattr(route, "path", "")
        sub_app = getattr(route, "app", None)
        if sub_app and hasattr(sub_app, "routes"):
            for sub_route in sub_app.routes:
                # Get endpoint - either from endpoint attribute or app attribute
                endpoint_func = getattr(sub_route, "endpoint", None) or getattr(
                    sub_route, "app", None
                )

                if endpoint_func is not None:
                    sub_route_path = getattr(sub_route, "path", "")
                    # rstrip avoids a double slash when the mount path ends in "/"
                    full_path = mount_path.rstrip("/") + sub_route_path

                    route_info = {
                        "path": full_path,
                        "methods": getattr(sub_route, "methods", ["GET", "POST"]),
                        "name": getattr(sub_route, "name", None),
                        "endpoint": GetRoutes._safe_get_endpoint_name(endpoint_func),
                        "mounted_app": True,
                    }
                    routes.append(route_info)
        return routes

    @staticmethod
    def _safe_get_endpoint_name(endpoint_function: Any) -> Optional[str]:
        """
        Safely get the name of the endpoint function.

        Falls back to the class name for callable objects without __name__;
        returns None (and logs) when neither is available.
        """
        try:
            if hasattr(endpoint_function, "__name__"):
                return getattr(endpoint_function, "__name__")
            elif hasattr(endpoint_function, "__class__") and hasattr(
                endpoint_function.__class__, "__name__"
            ):
                return getattr(endpoint_function.__class__, "__name__")
            else:
                return None
        except Exception:
            verbose_logger.exception(
                f"Error getting endpoint name for route: {endpoint_function}"
            )
            return None
|
||||
@@ -0,0 +1,207 @@
|
||||
from litellm.proxy.common_utils.banner import LITELLM_BANNER
|
||||
|
||||
|
||||
def render_cli_sso_success_page() -> str:
    """
    Renders the CLI SSO authentication success page with minimal styling.

    Shown in the user's browser after a successful CLI SSO login: displays
    the LiteLLM ASCII banner, a success confirmation, next-step instructions,
    and auto-closes via a 3-second JavaScript countdown.

    Returns:
        str: HTML content for the success page
    """

    # NOTE: this is an f-string - only {LITELLM_BANNER} is interpolated.
    # All CSS/JS braces are doubled ({{ }}) so they survive f-string
    # formatting and reach the browser as literal braces.
    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>CLI Authentication Successful - LiteLLM</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
                background-color: #f8fafc;
                margin: 0;
                padding: 20px;
                display: flex;
                justify-content: center;
                align-items: center;
                min-height: 100vh;
                color: #1e293b;
            }}

            .container {{
                background-color: #fff;
                padding: 40px;
                border-radius: 8px;
                box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
                width: 450px;
                max-width: 100%;
                text-align: center;
            }}

            .logo-container {{
                margin-bottom: 20px;
            }}

            .logo {{
                font-size: 24px;
                font-weight: 600;
                color: #1e293b;
            }}

            h1 {{
                margin: 0 0 10px;
                color: #1e293b;
                font-size: 28px;
                font-weight: 600;
            }}

            .subtitle {{
                color: #64748b;
                margin: 0 0 30px;
                font-size: 16px;
            }}

            .banner {{
                background-color: #f8fafc;
                color: #334155;
                font-family: 'Courier New', Consolas, monospace;
                font-size: 10px;
                line-height: 1.1;
                white-space: pre;
                padding: 20px;
                border-radius: 6px;
                margin: 20px 0;
                text-align: center;
                border: 1px solid #e2e8f0;
                overflow-x: auto;
            }}

            .success-box {{
                background-color: #f8fafc;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 30px;
                border: 1px solid #e2e8f0;
            }}

            .success-header {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 12px;
                color: #1e293b;
                font-weight: 600;
                font-size: 16px;
            }}

            .success-header svg {{
                margin-right: 8px;
            }}

            .success-box p {{
                color: #64748b;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}

            .instructions {{
                background-color: #f8fafc;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 20px;
                border: 1px solid #e2e8f0;
            }}

            .instructions-header {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 12px;
                color: #1e293b;
                font-weight: 600;
                font-size: 16px;
            }}

            .instructions-header svg {{
                margin-right: 8px;
            }}

            .instructions p {{
                color: #64748b;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}

            .countdown {{
                color: #64748b;
                font-size: 14px;
                font-weight: 500;
                padding: 12px;
                background-color: #f8fafc;
                border-radius: 6px;
                border: 1px solid #e2e8f0;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="logo-container">
                <div class="logo">
                    🚅 LiteLLM
                </div>
            </div>

            <div class="banner">{LITELLM_BANNER}</div>

            <h1>Authentication Successful!</h1>
            <p class="subtitle">Your CLI authentication is complete.</p>

            <div class="success-box">
                <div class="success-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <path d="M9 12l2 2 4-4"></path>
                        <circle cx="12" cy="12" r="10"></circle>
                    </svg>
                    CLI Authentication Complete
                </div>
                <p>Your LiteLLM CLI has been successfully authenticated and is ready to use.</p>
            </div>

            <div class="instructions">
                <div class="instructions-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <circle cx="12" cy="12" r="10"></circle>
                        <line x1="12" y1="16" x2="12" y2="12"></line>
                        <line x1="12" y1="8" x2="12.01" y2="8"></line>
                    </svg>
                    Next Steps
                </div>
                <p>Return to your terminal - the CLI will automatically detect the successful authentication.</p>
                <p>You can now use LiteLLM CLI commands with your authenticated session.</p>
            </div>

            <div class="countdown" id="countdown">This window will close in 3 seconds...</div>
        </div>

        <script>
            let seconds = 3;
            const countdownElement = document.getElementById('countdown');

            const countdown = setInterval(function() {{
                seconds--;
                if (seconds > 0) {{
                    countdownElement.textContent = `This window will close in ${{seconds}} second${{seconds === 1 ? '' : 's'}}...`;
                }} else {{
                    countdownElement.textContent = 'Closing...';
                    clearInterval(countdown);
                    window.close();
                }}
            }}, 1000);
        </script>
    </body>
    </html>
    """
    return html_content
|
||||
@@ -0,0 +1,284 @@
|
||||
# JWT display template for SSO debug callback.
# Plain string (NOT an f-string): braces in the CSS/JS are literal, and the
# server is expected to substitute the SSO_DATA placeholder (inside the
# <script> block) with the JSON payload returned by the SSO provider before
# sending this page to the browser.
jwt_display_template = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>LiteLLM SSO Debug - JWT Information</title>
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
            background-color: #f8fafc;
            margin: 0;
            padding: 20px;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            color: #333;
        }

        .container {
            background-color: #fff;
            padding: 40px;
            border-radius: 8px;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
            width: 800px;
            max-width: 100%;
        }

        .logo-container {
            text-align: center;
            margin-bottom: 30px;
        }

        .logo {
            font-size: 24px;
            font-weight: 600;
            color: #1e293b;
        }

        h2 {
            margin: 0 0 10px;
            color: #1e293b;
            font-size: 28px;
            font-weight: 600;
            text-align: center;
        }

        .subtitle {
            color: #64748b;
            margin: 0 0 20px;
            font-size: 16px;
            text-align: center;
        }

        .info-box {
            background-color: #f1f5f9;
            border-radius: 6px;
            padding: 20px;
            margin-bottom: 30px;
            border-left: 4px solid #2563eb;
        }

        .success-box {
            background-color: #f0fdf4;
            border-radius: 6px;
            padding: 20px;
            margin-bottom: 30px;
            border-left: 4px solid #16a34a;
        }

        .info-header {
            display: flex;
            align-items: center;
            margin-bottom: 12px;
            color: #1e40af;
            font-weight: 600;
            font-size: 16px;
        }

        .success-header {
            display: flex;
            align-items: center;
            margin-bottom: 12px;
            color: #166534;
            font-weight: 600;
            font-size: 16px;
        }

        .info-header svg, .success-header svg {
            margin-right: 8px;
        }

        .data-container {
            margin-top: 20px;
        }

        .data-row {
            display: flex;
            border-bottom: 1px solid #e2e8f0;
            padding: 12px 0;
        }

        .data-row:last-child {
            border-bottom: none;
        }

        .data-label {
            font-weight: 500;
            color: #334155;
            width: 180px;
            flex-shrink: 0;
        }

        .data-value {
            color: #475569;
            word-break: break-all;
        }

        .jwt-container {
            background-color: #f8fafc;
            border-radius: 6px;
            padding: 15px;
            margin-top: 20px;
            overflow-x: auto;
            border: 1px solid #e2e8f0;
        }

        .jwt-text {
            font-family: monospace;
            white-space: pre-wrap;
            word-break: break-all;
            margin: 0;
            color: #334155;
        }

        .back-button {
            display: inline-block;
            background-color: #6466E9;
            color: #fff;
            text-decoration: none;
            padding: 10px 16px;
            border-radius: 6px;
            font-weight: 500;
            margin-top: 20px;
            text-align: center;
        }

        .back-button:hover {
            background-color: #4138C2;
            text-decoration: none;
        }

        .buttons {
            display: flex;
            gap: 10px;
            margin-top: 20px;
        }

        .copy-button {
            background-color: #e2e8f0;
            color: #334155;
            border: none;
            padding: 8px 12px;
            border-radius: 4px;
            cursor: pointer;
            font-size: 14px;
            display: flex;
            align-items: center;
        }

        .copy-button:hover {
            background-color: #cbd5e1;
        }

        .copy-button svg {
            margin-right: 6px;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="logo-container">
            <div class="logo">
                🚅 LiteLLM
            </div>
        </div>
        <h2>SSO Debug Information</h2>
        <p class="subtitle">Results from the SSO authentication process.</p>

        <div class="success-box">
            <div class="success-header">
                <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"></path>
                    <polyline points="22 4 12 14.01 9 11.01"></polyline>
                </svg>
                Authentication Successful
            </div>
            <p>The SSO authentication completed successfully. Below is the information returned by the provider.</p>
        </div>

        <div class="data-container" id="userData">
            <!-- Data will be inserted here by JavaScript -->
        </div>

        <div class="info-box">
            <div class="info-header">
                <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <circle cx="12" cy="12" r="10"></circle>
                    <line x1="12" y1="16" x2="12" y2="12"></line>
                    <line x1="12" y1="8" x2="12.01" y2="8"></line>
                </svg>
                JSON Representation
            </div>
            <div class="jwt-container">
                <pre class="jwt-text" id="jsonData">Loading...</pre>
            </div>
            <div class="buttons">
                <button class="copy-button" onclick="copyToClipboard('jsonData')">
                    <svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
                        <path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
                    </svg>
                    Copy to Clipboard
                </button>
            </div>
        </div>

        <a href="/sso/debug/login" class="back-button">
            Try Another SSO Login
        </a>
    </div>

    <script>
        // This will be populated with the actual data from the server
        const userData = SSO_DATA;

        function renderUserData() {
            const container = document.getElementById('userData');
            const jsonDisplay = document.getElementById('jsonData');

            // Format JSON with indentation for display
            jsonDisplay.textContent = JSON.stringify(userData, null, 2);

            // Clear container
            container.innerHTML = '';

            // Add each key-value pair to the UI
            for (const [key, value] of Object.entries(userData)) {
                if (typeof value !== 'object' || value === null) {
                    const row = document.createElement('div');
                    row.className = 'data-row';

                    const label = document.createElement('div');
                    label.className = 'data-label';
                    label.textContent = key;

                    const dataValue = document.createElement('div');
                    dataValue.className = 'data-value';
                    dataValue.textContent = value !== null ? value : 'null';

                    row.appendChild(label);
                    row.appendChild(dataValue);
                    container.appendChild(row);
                }
            }
        }

        function copyToClipboard(elementId) {
            const text = document.getElementById(elementId).textContent;
            navigator.clipboard.writeText(text).then(() => {
                alert('Copied to clipboard!');
            }).catch(err => {
                console.error('Could not copy text: ', err);
            });
        }

        // Render the data when the page loads
        document.addEventListener('DOMContentLoaded', renderUserData);
    </script>
</body>
</html>
"""
|
||||
@@ -0,0 +1,269 @@
|
||||
import os
|
||||
|
||||
from litellm.proxy.utils import get_custom_url
|
||||
|
||||
# Build the legacy login redirect URL: PROXY_BASE_URL + optional
# SERVER_ROOT_PATH + "/login".
# NOTE(review): evaluated once at import time - changes to these env vars
# after the module is imported have no effect; confirm that is intended.
url_to_redirect_to = os.getenv("PROXY_BASE_URL", "")
server_root_path = os.getenv("SERVER_ROOT_PATH", "")
if server_root_path != "":
    url_to_redirect_to += server_root_path
url_to_redirect_to += "/login"
# URL of the new (non-deprecated) UI login page, shown in the deprecation banner.
new_ui_login_url = get_custom_url("", "ui/login")
|
||||
|
||||
|
||||
def build_ui_login_form(show_deprecation_banner: bool = False) -> str:
    """
    Build the legacy admin-UI username/password login page as a complete,
    self-contained HTML document string.

    Args:
        show_deprecation_banner: When True, a red banner is prepended to the
            form pointing users at the new login page (``new_ui_login_url``).

    Returns:
        str: Full HTML document. The form POSTs to the module-level
        ``url_to_redirect_to`` (PROXY_BASE_URL [+ SERVER_ROOT_PATH] + "/login").
    """
    # Banner is injected only when requested; otherwise it contributes nothing
    # to the rendered page.
    banner_html = (
        f"""
        <div class="deprecation-banner">
            <strong>Deprecated:</strong> Logging in with username and password on this page is deprecated.
            Please use the <a href="{new_ui_login_url}">new login page</a> instead.
            This page will be dedicated to signing in via SSO in the future.
        </div>
        """
        if show_deprecation_banner
        else ""
    )

    # NOTE: doubled braces ({{ }}) below escape literal CSS/JS braces inside
    # this f-string; single braces interpolate Python values.
    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>LiteLLM Login</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
                background-color: #f8fafc;
                margin: 0;
                padding: 20px;
                display: flex;
                justify-content: center;
                align-items: center;
                min-height: 100vh;
                color: #333;
            }}

            form {{
                background-color: #fff;
                padding: 40px;
                border-radius: 8px;
                box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
                width: 450px;
                max-width: 100%;
            }}

            .logo-container {{
                text-align: center;
                margin-bottom: 30px;
            }}

            .logo {{
                font-size: 24px;
                font-weight: 600;
                color: #1e293b;
            }}

            h2 {{
                margin: 0 0 10px;
                color: #1e293b;
                font-size: 28px;
                font-weight: 600;
                text-align: center;
            }}

            .subtitle {{
                color: #64748b;
                margin: 0 0 20px;
                font-size: 16px;
                text-align: center;
            }}

            .info-box {{
                background-color: #f1f5f9;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 30px;
                border-left: 4px solid #2563eb;
            }}

            .info-header {{
                display: flex;
                align-items: center;
                margin-bottom: 12px;
                color: #1e40af;
                font-weight: 600;
                font-size: 16px;
            }}

            .info-header svg {{
                margin-right: 8px;
            }}

            .info-box p {{
                color: #475569;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}

            label {{
                display: block;
                margin-bottom: 8px;
                font-weight: 500;
                color: #334155;
                font-size: 14px;
            }}

            .required {{
                color: #dc2626;
                margin-left: 2px;
            }}

            input[type="text"],
            input[type="password"] {{
                width: 100%;
                padding: 10px 14px;
                margin-bottom: 20px;
                box-sizing: border-box;
                border: 1px solid #e2e8f0;
                border-radius: 6px;
                font-size: 15px;
                color: #1e293b;
                background-color: #fff;
                transition: border-color 0.2s, box-shadow 0.2s;
            }}

            input[type="text"]:focus,
            input[type="password"]:focus {{
                outline: none;
                border-color: #3b82f6;
                box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.2);
            }}

            .toggle-password {{
                display: flex;
                align-items: center;
                margin-top: -15px;
                margin-bottom: 20px;
            }}

            .toggle-password input[type="checkbox"] {{
                margin-right: 8px;
                vertical-align: middle;
                width: 16px;
                height: 16px;
            }}

            .toggle-password label {{
                margin-bottom: 0;
                font-size: 14px;
                cursor: pointer;
                line-height: 1;
            }}

            input[type="submit"] {{
                background-color: #6466E9;
                color: #fff;
                cursor: pointer;
                font-weight: 500;
                border: none;
                padding: 10px 16px;
                transition: background-color 0.2s;
                border-radius: 6px;
                margin-top: 10px;
                font-size: 14px;
                width: 100%;
            }}

            input[type="submit"]:hover {{
                background-color: #4138C2;
            }}

            a {{
                color: #3b82f6;
                text-decoration: none;
            }}

            a:hover {{
                text-decoration: underline;
            }}

            code {{
                background-color: #f1f5f9;
                padding: 2px 4px;
                border-radius: 4px;
                font-family: monospace;
                font-size: 13px;
                color: #334155;
            }}

            .help-text {{
                color: #64748b;
                font-size: 14px;
                margin-top: -12px;
                margin-bottom: 20px;
            }}

            .deprecation-banner {{
                background-color: #fee2e2;
                border: 1px solid #ef4444;
                color: #991b1b;
                padding: 14px 16px;
                border-radius: 6px;
                margin-bottom: 20px;
                font-size: 14px;
                line-height: 1.5;
            }}

            .deprecation-banner a {{
                color: #991b1b;
                font-weight: 600;
                text-decoration: underline;
            }}
        </style>
    </head>
    <body>
        <form action="{url_to_redirect_to}" method="post">
            {banner_html}
            <div class="logo-container">
                <div class="logo">
                    🚅 LiteLLM
                </div>
            </div>
            <h2>Login</h2>
            <p class="subtitle">Access your LiteLLM Admin UI.</p>
            <div class="info-box">
                <div class="info-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <circle cx="12" cy="12" r="10"></circle>
                        <line x1="12" y1="16" x2="12" y2="12"></line>
                        <line x1="12" y1="8" x2="12.01" y2="8"></line>
                    </svg>
                    Default Credentials
                </div>
                <p>By default, Username is <code>admin</code> and Password is your set LiteLLM Proxy <code>MASTER_KEY</code>.</p>
                <p>Need to set UI credentials or SSO? <a href="https://docs.litellm.ai/docs/proxy/ui" target="_blank">Check the documentation</a>.</p>
            </div>
            <label for="username">Username<span class="required">*</span></label>
            <input type="text" id="username" name="username" required placeholder="Enter your username" autocomplete="username">

            <label for="password">Password<span class="required">*</span></label>
            <input type="password" id="password" name="password" required placeholder="Enter your password" autocomplete="current-password">
            <div class="toggle-password">
                <input type="checkbox" id="show-password" onclick="togglePasswordVisibility()">
                <label for="show-password">Show password</label>
            </div>
            <input type="submit" value="Login">
        </form>
        <script>
            function togglePasswordVisibility() {{
                var passwordField = document.getElementById("password");
                passwordField.type = passwordField.type === "password" ? "text" : "password";
            }}
        </script>
    </body>
    </html>
    """
|
||||
|
||||
|
||||
# Pre-rendered legacy login page served by the proxy; the deprecation banner is
# always enabled here to steer users toward the new UI login route.
html_form = build_ui_login_form(show_deprecation_banner=True)
|
||||
@@ -0,0 +1,522 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Collection, Dict, List, Optional
|
||||
|
||||
import orjson
|
||||
from fastapi import Request, UploadFile, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ProxyException
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
get_metadata_variable_name_from_kwargs,
|
||||
)
|
||||
from litellm.types.router import Deployment
|
||||
|
||||
|
||||
async def _read_request_body(request: Optional[Request]) -> Dict:
    """
    Safely read the request body and parse it as JSON.

    Results are cached on ``request.scope`` (via ``_safe_set_request_parsed_body``)
    so repeated calls within the same request do not re-read or re-parse the body.

    Parameters:
    - request: The request object to read the body from

    Returns:
    - dict: Parsed request data as a dictionary or an empty dictionary if parsing fails

    Raises:
    - ProxyException: when the body is present but is not valid JSON even after
      the surrogate-stripping fallback below.
    """
    try:
        if request is None:
            return {}

        # Check if we already read and parsed the body
        _cached_request_body: Optional[dict] = _safe_get_request_parsed_body(
            request=request
        )
        if _cached_request_body is not None:
            return _cached_request_body

        _request_headers: dict = _safe_get_request_headers(request=request)
        content_type = _request_headers.get("content-type", "")

        # Substring match covers both multipart/form-data and
        # application/x-www-form-urlencoded content types.
        if "form" in content_type:
            parsed_body = dict(await request.form())
            # Form fields are strings; a JSON-encoded "metadata" field is
            # decoded here so downstream code always sees a dict.
            if "metadata" in parsed_body and isinstance(parsed_body["metadata"], str):
                parsed_body["metadata"] = json.loads(parsed_body["metadata"])
        else:
            # Read the request body
            body = await request.body()

            # Return empty dict if body is empty or None
            if not body:
                parsed_body = {}
            else:
                try:
                    # orjson is the fast path but is strict about invalid UTF-16
                    # surrogates.
                    parsed_body = orjson.loads(body)
                except orjson.JSONDecodeError as e:
                    # First try the standard json module which is more forgiving
                    # First decode bytes to string if needed
                    body_str = body.decode("utf-8") if isinstance(body, bytes) else body

                    # Replace invalid surrogate pairs
                    # This regex finds incomplete surrogate pairs
                    # NOTE(review): this is lossy — unpaired surrogate code points
                    # are dropped from the payload before the retry parse.
                    body_str = re.sub(
                        r"[\uD800-\uDBFF](?![\uDC00-\uDFFF])", "", body_str
                    )
                    # This regex finds low surrogates without high surrogates
                    body_str = re.sub(
                        r"(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]", "", body_str
                    )

                    try:
                        parsed_body = json.loads(body_str)
                    except json.JSONDecodeError:
                        # If both orjson and json.loads fail, throw a proper error
                        verbose_proxy_logger.error(
                            f"Invalid JSON payload received: {str(e)}"
                        )
                        raise ProxyException(
                            message=f"Invalid JSON payload: {str(e)}",
                            type="invalid_request_error",
                            param="request_body",
                            code=status.HTTP_400_BAD_REQUEST,
                        )

        # Cache the parsed result
        _safe_set_request_parsed_body(request=request, parsed_body=parsed_body)
        return parsed_body

    except (json.JSONDecodeError, orjson.JSONDecodeError, ProxyException) as e:
        # Re-raise ProxyException as-is
        verbose_proxy_logger.error(f"Invalid JSON payload received: {str(e)}")
        raise
    except Exception as e:
        # Catch unexpected errors to avoid crashes
        verbose_proxy_logger.exception(
            "Unexpected error reading request body - {}".format(e)
        )
        return {}
|
||||
|
||||
|
||||
def _safe_get_request_parsed_body(request: Optional[Request]) -> Optional[dict]:
    """
    Return the body previously cached on ``request.scope`` by
    ``_safe_set_request_parsed_body``, or None when no valid cache entry exists.

    The cache entry is an ``(accepted_keys, parsed_body)`` tuple; a fresh dict
    is rebuilt from it on every call, restricted to the keys that were present
    at cache time.
    """
    if request is None or not hasattr(request, "scope"):
        return None

    cache_entry = request.scope.get("parsed_body")
    if not isinstance(cache_entry, tuple):
        return None

    accepted_keys, cached_body = cache_entry
    return {key: cached_body[key] for key in accepted_keys}
|
||||
|
||||
|
||||
def _safe_get_request_query_params(request: Optional[Request]) -> Dict:
    """
    Best-effort extraction of the request's query params as a plain dict.

    Never raises — any failure is logged at debug level and an empty dict is
    returned instead.
    """
    if request is None:
        return {}
    try:
        if not hasattr(request, "query_params"):
            return {}
        return dict(request.query_params)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error reading request query params - {}".format(e)
        )
        return {}
|
||||
|
||||
|
||||
def _safe_set_request_parsed_body(
    request: Optional[Request],
    parsed_body: dict,
) -> None:
    """
    Cache ``parsed_body`` on ``request.scope`` for later retrieval by
    ``_safe_get_request_parsed_body``. Never raises — failures are logged at
    debug level.
    """
    if request is None:
        return
    try:
        # Store the key set alongside the body so readers only see the keys
        # that existed at cache time.
        accepted_keys = tuple(parsed_body.keys())
        request.scope["parsed_body"] = (accepted_keys, parsed_body)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error setting request parsed body - {}".format(e)
        )
|
||||
|
||||
|
||||
def _safe_get_request_headers(request: Optional[Request]) -> dict:
    """
    [Non-Blocking] Safely get the request headers.

    The first successful read is memoized on ``request.state`` so repeated
    calls within one request do not rebuild ``dict(request.headers)``.

    Warning: Callers must NOT mutate the returned dict — it is shared across
    all callers within the same request via the cache.
    """
    if request is None:
        return {}

    req_state = getattr(request, "state", None)
    memoized = getattr(req_state, "_cached_headers", None)
    if isinstance(memoized, dict):
        return memoized
    if memoized is not None:
        # Something other than a dict was cached — log and fall through to a
        # fresh read.
        verbose_proxy_logger.debug(
            "Unexpected cached request headers type - {}".format(type(memoized))
        )

    try:
        header_dict = dict(request.headers)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error reading request headers - {}".format(e)
        )
        header_dict = {}

    try:
        if req_state is not None:
            req_state._cached_headers = header_dict
    except Exception:
        # request.state may not be available / assignable in all contexts
        pass
    return header_dict
|
||||
|
||||
|
||||
def check_file_size_under_limit(
    request_data: dict,
    file: UploadFile,
    router_model_names: Collection[str],
) -> bool:
    """
    Check if any files passed in request are under max_file_size_mb

    Returns True -> when file size is under max_file_size_mb limit
    Raises ProxyException -> when file size is over max_file_size_mb limit or not a premium_user
    """
    # Imported at call time — presumably to avoid a circular import with
    # proxy_server, and to read the current values of the module-level
    # llm_router / premium_user globals. TODO confirm.
    from litellm.proxy.proxy_server import (
        CommonProxyErrors,
        ProxyException,
        llm_router,
        premium_user,
    )

    # file.size may be None; treat that as a zero-byte file.
    file_contents_size = file.size or 0
    file_content_size_in_mb = file_contents_size / (1024 * 1024)
    # Record the observed size in request metadata for downstream logging.
    if "metadata" not in request_data:
        request_data["metadata"] = {}
    request_data["metadata"]["file_size_in_mb"] = file_content_size_in_mb
    max_file_size_mb = None

    # NOTE(review): request_data["model"] raises KeyError if the caller did not
    # set a model — confirm callers always populate it before this check.
    if llm_router is not None and request_data["model"] in router_model_names:
        try:
            deployment: Optional[
                Deployment
            ] = llm_router.get_deployment_by_model_group_name(
                model_group_name=request_data["model"]
            )
            if (
                deployment
                and deployment.litellm_params is not None
                and deployment.litellm_params.max_file_size_mb is not None
            ):
                max_file_size_mb = deployment.litellm_params.max_file_size_mb
        except Exception as e:
            # Best-effort lookup: a router error means no limit is enforced.
            verbose_proxy_logger.error(
                "Got error when checking file size: %s", (str(e))
            )

    # Only enforce when a limit was configured on the matched deployment.
    if max_file_size_mb is not None:
        verbose_proxy_logger.debug(
            "Checking file size, file content size=%s, max_file_size_mb=%s",
            file_content_size_in_mb,
            max_file_size_mb,
        )
        # max_file_size_mb is an enterprise-only feature.
        if not premium_user:
            raise ProxyException(
                message=f"Tried setting max_file_size_mb for /audio/transcriptions. {CommonProxyErrors.not_premium_user.value}",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )
        if file_content_size_in_mb > max_file_size_mb:
            raise ProxyException(
                message=f"File size is too large. Please check your file size. Passed file size: {file_content_size_in_mb} MB. Max file size: {max_file_size_mb} MB",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )

    return True
|
||||
|
||||
|
||||
async def get_form_data(request: Request) -> Dict[str, Any]:
    """
    Read form data from request.

    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"`
    instead of `timestamp_granularities=["word", "sentence"]` by collecting every
    `key[]` entry into a single list under `key`.
    """
    raw_fields = dict(await request.form())
    normalized: dict[str, Any] = {}

    for field_name, field_value in raw_fields.items():
        if field_name.endswith("[]"):
            # Array-style field: strip the suffix and accumulate values.
            normalized.setdefault(field_name[:-2], []).append(field_value)
        else:
            normalized[field_name] = field_value

    return normalized
|
||||
|
||||
|
||||
async def convert_upload_files_to_file_data(
    form_data: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Convert FastAPI UploadFile objects to file data tuples for litellm.

    Converts UploadFile objects to tuples of (filename, content, content_type)
    which is the format expected by httpx and litellm's HTTP handlers. A single
    upload is wrapped in a one-element list for consistency with multi-file
    fields; all non-file values pass through unchanged.

    Args:
        form_data: Dictionary containing form data with potential UploadFile objects

    Returns:
        Dictionary with UploadFile objects converted to file data tuples

    Example:
        ```python
        form_data = await get_form_data(request)
        data = await convert_upload_files_to_file_data(form_data)
        # data["files"] is now [(filename, content, content_type), ...]
        ```
    """

    async def _to_file_tuple(upload: Any) -> tuple:
        # (filename, content, content_type) is the shape httpx expects.
        return (upload.filename, await upload.read(), upload.content_type)

    converted: Dict[str, Any] = {}
    for field, value in form_data.items():
        if isinstance(value, list) and value and hasattr(value[0], "read"):
            # List of uploads (duck-typed via .read)
            converted[field] = [await _to_file_tuple(f) for f in value]
        elif not isinstance(value, list) and hasattr(value, "read"):
            # Single upload — wrap in a list for consistency
            converted[field] = [await _to_file_tuple(value)]
        else:
            # Regular form field (or a list of plain values)
            converted[field] = value
    return converted
|
||||
|
||||
|
||||
async def get_request_body(request: Request) -> Dict[str, Any]:
    """
    Read the request body and parse it based on content type.

    - JSON bodies go through ``_read_request_body`` (with its caching).
    - multipart / urlencoded bodies go through ``get_form_data``.
    - Non-POST requests return an empty dict.

    Raises:
        ValueError: for a POST with an unsupported content type.
    """
    if request.method != "POST":
        return {}

    content_type = request.headers.get("content-type", "")
    # Match on the media type prefix only — headers commonly carry parameters
    # such as `application/json; charset=utf-8`, which the previous exact
    # equality check rejected as "unsupported".
    if content_type.startswith("application/json"):
        return await _read_request_body(request)
    if (
        "multipart/form-data" in content_type
        or "application/x-www-form-urlencoded" in content_type
    ):
        return await get_form_data(request)
    raise ValueError(
        f"Unsupported content type: {request.headers.get('content-type')}"
    )
|
||||
|
||||
|
||||
def extract_nested_form_metadata(
    form_data: Dict[str, Any], prefix: str = "litellm_metadata["
) -> Dict[str, Any]:
    """
    Extract nested metadata from form data with bracket notation.

    Handles form data that uses bracket notation to represent nested dictionaries,
    such as litellm_metadata[spend_logs_metadata][owner] = "value".

    This is commonly encountered when SDKs or clients send form data with nested
    structures using bracket notation instead of JSON.

    Args:
        form_data: Dictionary containing form data (from request.form())
        prefix: The prefix to look for in form keys (default: "litellm_metadata[")

    Returns:
        Dictionary with nested structure reconstructed from bracket notation

    Example:
        Input form_data:
        {
            "litellm_metadata[spend_logs_metadata][owner]": "john",
            "litellm_metadata[spend_logs_metadata][team]": "engineering",
            "litellm_metadata[tags]": "production",
            "other_field": "value"
        }

        Output:
        {
            "spend_logs_metadata": {
                "owner": "john",
                "team": "engineering"
            },
            "tags": "production"
        }
    """
    if not form_data:
        return {}

    metadata: Dict[str, Any] = {}

    for key, value in form_data.items():
        # Skip keys that don't start with the prefix
        if not isinstance(key, str) or not key.startswith(prefix):
            continue

        # Skip UploadFile objects - they should not be in metadata
        if isinstance(value, UploadFile):
            verbose_proxy_logger.warning(
                f"Skipping UploadFile in metadata extraction for key: {key}"
            )
            continue

        # Extract the nested path from bracket notation
        # Example: "litellm_metadata[spend_logs_metadata][owner]" -> ["spend_logs_metadata", "owner"]
        try:
            # Strip only the LEADING prefix via slicing (str.replace would also
            # remove any interior occurrence of the prefix, corrupting paths
            # whose segment names contain it), then drop trailing ']'
            path_string = key[len(prefix):].rstrip("]")

            # Split by "][" to get individual path parts
            parts = path_string.split("][")

            if not parts or not parts[0]:
                verbose_proxy_logger.warning(
                    f"Invalid metadata key format (empty path): {key}"
                )
                continue

            # Navigate/create nested dictionary structure
            current = metadata
            for part in parts[:-1]:
                if not isinstance(current, dict):
                    verbose_proxy_logger.warning(
                        f"Cannot create nested path - intermediate value is not a dict at: {part}"
                    )
                    break
                current = current.setdefault(part, {})
            else:
                # Set the final value (only if we didn't break out of the loop)
                if isinstance(current, dict):
                    current[parts[-1]] = value
                else:
                    verbose_proxy_logger.warning(
                        f"Cannot set value - parent is not a dict for key: {key}"
                    )

        except Exception as e:
            verbose_proxy_logger.error(f"Error parsing metadata key '{key}': {str(e)}")
            continue

    return metadata
|
||||
|
||||
|
||||
def get_tags_from_request_body(request_body: dict) -> List[str]:
    """
    Extract tags from request body metadata.

    Tags may appear under the request's metadata dict (key resolved via
    ``get_metadata_variable_name_from_kwargs``) and/or at the top level of the
    body; both sources are merged.

    Args:
        request_body: The request body dictionary

    Returns:
        List of tag names (strings), empty list if no valid tags found
    """
    metadata_variable_name = get_metadata_variable_name_from_kwargs(request_body)
    metadata = request_body.get(metadata_variable_name) or {}

    tag_sources = (
        metadata.get("tags", []),
        request_body.get("tags", []),
    )

    combined: List[str] = []
    for source in tag_sources:
        # Only honour list-shaped tag containers
        if isinstance(source, list):
            combined.extend(source)

    # Drop any non-string entries
    return [tag for tag in combined if isinstance(tag, str)]
|
||||
|
||||
|
||||
def populate_request_with_path_params(request_data: dict, request: Request) -> dict:
    """
    Copy FastAPI path params and query params into the request payload so downstream checks
    (e.g. vector store RBAC, organization RBAC) see them the same way as body params.

    Since path_params may not be available during dependency injection,
    we parse the URL path directly for known patterns.

    Args:
        request_data: The request data dictionary to populate
        request: The FastAPI Request object

    Returns:
        dict: Updated request_data with path parameters and query parameters added
    """
    # Merge query params without overwriting values already in the body.
    for param_name, param_value in _safe_get_request_query_params(request).items():
        request_data.setdefault(param_name, param_value)

    path_params = getattr(request, "path_params", None)
    if not (isinstance(path_params, dict) and path_params):
        # path_params unavailable — fall back to parsing the URL path directly
        # to extract vector_store_id.
        _add_vector_store_id_from_path(request_data=request_data, request=request)
        return request_data

    for param_name, param_value in path_params.items():
        if param_name != "vector_store_id":
            request_data.setdefault(param_name, param_value)
            continue
        # vector_store_id also feeds the vector_store_ids list used by RBAC.
        request_data.setdefault("vector_store_id", param_value)
        existing_ids = request_data.get("vector_store_ids")
        if isinstance(existing_ids, list):
            if param_value not in existing_ids:
                existing_ids.append(param_value)
        else:
            request_data["vector_store_ids"] = [param_value]

    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Found path_params, vector_store_ids={request_data.get('vector_store_ids')}"
    )
    return request_data
|
||||
|
||||
|
||||
def _add_vector_store_id_from_path(request_data: dict, request: Request) -> None:
    """
    Parse the request path to find /vector_stores/{vector_store_id}/... segments.

    When found, ensure both vector_store_id and vector_store_ids are populated.

    Args:
        request_data: The request data dictionary to populate
        request: The FastAPI Request object
    """
    path = request.url.path
    found = re.search(r"/vector_stores/([^/]+)/", path)
    if not found:
        verbose_proxy_logger.debug(
            f"populate_request_with_path_params: No vector_store_id present in path={path}"
        )
        return

    vector_store_id = found.group(1)
    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Extracted vector_store_id={vector_store_id} from path={path}"
    )
    request_data.setdefault("vector_store_id", vector_store_id)

    current_ids = request_data.get("vector_store_ids")
    if isinstance(current_ids, list):
        if vector_store_id not in current_ids:
            current_ids.append(vector_store_id)
    else:
        request_data["vector_store_ids"] = [vector_store_id]

    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Updated request_data with vector_store_ids={request_data.get('vector_store_ids')}"
    )
|
||||
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Key Rotation Manager - Automated key rotation based on rotation schedules
|
||||
|
||||
Handles finding keys that need rotation based on their individual schedules.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import (
|
||||
LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
|
||||
LITELLM_KEY_ROTATION_GRACE_PERIOD,
|
||||
)
|
||||
from litellm.proxy._types import (
|
||||
GenerateKeyResponse,
|
||||
LiteLLM_VerificationToken,
|
||||
RegenerateKeyRequest,
|
||||
)
|
||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_calculate_key_rotation_time,
|
||||
regenerate_key_fn,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
class KeyRotationManager:
    """
    Manages automated key rotation based on individual key rotation schedules.

    Each key carries its own ``key_rotation_at`` timestamp; this manager finds
    due keys, regenerates them via the existing key-management endpoints, and
    schedules the next rotation on the newly created token.
    """

    def __init__(self, prisma_client: PrismaClient):
        # DB client used for all token table reads/writes.
        self.prisma_client = prisma_client

    async def process_rotations(self):
        """
        Main entry point - find and rotate keys that are due for rotation.

        Errors on individual keys are logged and do not stop the remaining
        rotations; a failure of the overall process is logged and swallowed.
        """
        try:
            verbose_proxy_logger.info("Starting scheduled key rotation check...")

            # Clean up expired deprecated keys first
            await self._cleanup_expired_deprecated_keys()

            # Find keys that are due for rotation
            keys_to_rotate = await self._find_keys_needing_rotation()

            if not keys_to_rotate:
                verbose_proxy_logger.debug("No keys are due for rotation at this time")
                return

            verbose_proxy_logger.info(
                f"Found {len(keys_to_rotate)} keys due for rotation"
            )

            # Rotate each key
            for key in keys_to_rotate:
                try:
                    await self._rotate_key(key)
                    # Prefer the human-readable name; fall back to a token prefix.
                    key_identifier = key.key_name or (
                        key.token[:8] + "..." if key.token else "unknown"
                    )
                    verbose_proxy_logger.info(
                        f"Successfully rotated key: {key_identifier}"
                    )
                except Exception as e:
                    key_identifier = key.key_name or (
                        key.token[:8] + "..." if key.token else "unknown"
                    )
                    verbose_proxy_logger.error(
                        f"Failed to rotate key {key_identifier}: {e}"
                    )

        except Exception as e:
            verbose_proxy_logger.error(f"Key rotation process failed: {e}")

    async def _find_keys_needing_rotation(self) -> List[LiteLLM_VerificationToken]:
        """
        Find keys that are due for rotation based on their key_rotation_at timestamp.

        Logic:
        - Key has auto_rotate = true
        - key_rotation_at is null (needs initial setup) OR key_rotation_at <= now
        """
        now = datetime.now(timezone.utc)

        keys_with_rotation = (
            await self.prisma_client.db.litellm_verificationtoken.find_many(
                where={
                    "auto_rotate": True,  # Only keys marked for auto rotation
                    "OR": [
                        {
                            "key_rotation_at": None
                        },  # Keys that need initial rotation time setup
                        {
                            "key_rotation_at": {"lte": now}
                        },  # Keys where rotation time has passed
                    ],
                }
            )
        )

        return keys_with_rotation

    async def _cleanup_expired_deprecated_keys(self) -> None:
        """
        Remove deprecated key entries whose revoke_at has passed.

        Failures (e.g. the table not existing on older schemas) are logged at
        debug level and ignored.
        """
        try:
            now = datetime.now(timezone.utc)
            result = await self.prisma_client.db.litellm_deprecatedverificationtoken.delete_many(
                where={"revoke_at": {"lt": now}}
            )
            # NOTE(review): assumes delete_many returns an int row count —
            # confirm against the prisma client in use.
            if result > 0:
                verbose_proxy_logger.debug(
                    "Cleaned up %s expired deprecated key(s)", result
                )
        except Exception as e:
            verbose_proxy_logger.debug(
                "Deprecated key cleanup skipped (table may not exist): %s", e
            )

    def _should_rotate_key(self, key: LiteLLM_VerificationToken, now: datetime) -> bool:
        """
        Determine if a key should be rotated based on key_rotation_at timestamp.

        Returns False for keys without a rotation_interval.
        """
        if not key.rotation_interval:
            return False

        # If key_rotation_at is not set, rotate immediately (and set it)
        if key.key_rotation_at is None:
            return True

        # Check if the rotation time has passed
        # NOTE(review): assumes key.key_rotation_at is timezone-aware — a naive
        # datetime here would raise TypeError against the aware `now`. Confirm.
        return now >= key.key_rotation_at

    async def _rotate_key(self, key: LiteLLM_VerificationToken):
        """
        Rotate a single key using existing regenerate_key_fn and call the rotation hook.
        """
        # Create regenerate request with grace period for seamless cutover
        regenerate_request = RegenerateKeyRequest(
            key=key.token or "",
            key_alias=key.key_alias,  # Pass key alias to ensure correct secret is updated in AWS Secrets Manager
            grace_period=LITELLM_KEY_ROTATION_GRACE_PERIOD or None,
        )

        # Create a system user for key rotation
        from litellm.proxy._types import UserAPIKeyAuth

        system_user = UserAPIKeyAuth.get_litellm_internal_jobs_user_api_key_auth()

        # Use existing regenerate key function
        response = await regenerate_key_fn(
            data=regenerate_request,
            user_api_key_dict=system_user,
            litellm_changed_by=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
        )

        # Update the NEW key with rotation info (regenerate_key_fn creates a new token)
        if (
            isinstance(response, GenerateKeyResponse)
            and response.token_id
            and key.rotation_interval
        ):
            # Calculate next rotation time using helper function
            now = datetime.now(timezone.utc)
            next_rotation_time = _calculate_key_rotation_time(key.rotation_interval)
            await self.prisma_client.db.litellm_verificationtoken.update(
                where={"token": response.token_id},
                data={
                    "rotation_count": (key.rotation_count or 0) + 1,
                    "last_rotation_at": now,
                    "key_rotation_at": next_rotation_time,
                },
            )

        # Call the existing rotation hook for notifications, audit logs, etc.
        if isinstance(response, GenerateKeyResponse):
            await KeyManagementEventHooks.async_key_rotated_hook(
                data=regenerate_request,
                existing_key_row=key,
                response=response,
                user_api_key_dict=system_user,
                litellm_changed_by=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
            )
|
||||
@@ -0,0 +1,178 @@
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
def get_file_contents_from_s3(bucket_name, object_key):
    """
    Fetch a YAML config object from S3 and return it parsed.

    Relies on boto3's standard credential chain (resolved through the shared
    bedrock credential helper) so IAM roles / env vars / profiles all work.

    Args:
        bucket_name: S3 bucket to read from.
        object_key: Key of the YAML object inside the bucket.

    Returns:
        The parsed YAML config (typically a dict), or None on any failure
        (including boto3 not being installed).
    """
    try:
        # v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc
        import boto3
        from botocore.credentials import Credentials

        from litellm.main import bedrock_converse_chat_completion

        credentials: Credentials = bedrock_converse_chat_completion.get_credentials()
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,  # Optional, if using temporary credentials
        )
        verbose_proxy_logger.debug(
            f"Retrieving {object_key} from S3 bucket: {bucket_name}"
        )
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
        verbose_proxy_logger.debug(f"Response: {response}")

        # Read the file contents and directly parse YAML
        file_contents = response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug("File contents retrieved from S3")

        # Parse YAML directly from string
        config = yaml.safe_load(file_contents)
        return config

    except ImportError as e:
        # this is most likely if a user is not using the litellm docker container
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
        # Previously fell through with an implicit None; return it explicitly
        # to match the generic failure path below.
        return None
    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
        return None
|
||||
|
||||
async def get_config_file_contents_from_gcs(bucket_name, object_key):
    """
    Download a YAML config object from a GCS bucket and return it parsed.

    Args:
        bucket_name: GCS bucket to read from.
        object_key: Key of the YAML object inside the bucket.

    Returns:
        The parsed YAML config, or None if the download or parse fails.
    """
    try:
        from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger

        bucket_logger = GCSBucketLogger(
            bucket_name=bucket_name,
        )
        raw_contents = await bucket_logger.download_gcs_object(object_key)
        if raw_contents is None:
            raise Exception(f"File contents are None for {object_key}")
        # download_gcs_object returns bytes; decode to str before YAML parsing
        return yaml.safe_load(raw_contents.decode("utf-8"))

    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
        return None
||||
|
||||
def download_python_file_from_s3(
    bucket_name: str,
    object_key: str,
    local_file_path: str,
) -> bool:
    """
    Download a Python file from S3 and save it to local filesystem.

    Args:
        bucket_name (str): S3 bucket name
        object_key (str): S3 object key (file path in bucket)
        local_file_path (str): Local path where file should be saved

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        import boto3
        from botocore.credentials import Credentials

        from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

        base_aws_llm = BaseAWSLLM()

        credentials: Credentials = base_aws_llm.get_credentials()
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,
        )

        verbose_proxy_logger.debug(
            f"Downloading Python file {object_key} from S3 bucket: {bucket_name}"
        )
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)

        # Read the file contents
        file_contents = response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug(f"File contents: {file_contents}")

        # Ensure directory exists. Guard against a bare filename:
        # os.path.dirname("file.py") == "" and os.makedirs("") raises
        # FileNotFoundError.
        target_dir = os.path.dirname(local_file_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Write to local file. Use utf-8 explicitly: the contents were decoded
        # as utf-8 above, and the platform-default encoding may differ.
        with open(local_file_path, "w", encoding="utf-8") as f:
            f.write(file_contents)

        verbose_proxy_logger.debug(
            f"Python file downloaded successfully to {local_file_path}"
        )
        return True

    except ImportError as e:
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
        return False
    except Exception as e:
        verbose_proxy_logger.exception(f"Error downloading Python file: {str(e)}")
        return False
|
||||
|
||||
async def download_python_file_from_gcs(
    bucket_name: str,
    object_key: str,
    local_file_path: str,
) -> bool:
    """
    Download a Python file from GCS and save it to local filesystem.

    Args:
        bucket_name (str): GCS bucket name
        object_key (str): GCS object key (file path in bucket)
        local_file_path (str): Local path where file should be saved

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger

        gcs_bucket = GCSBucketLogger(
            bucket_name=bucket_name,
        )
        file_contents = await gcs_bucket.download_gcs_object(object_key)
        if file_contents is None:
            raise Exception(f"File contents are None for {object_key}")

        # file_contents is a bytes object, decode it
        file_contents = file_contents.decode("utf-8")

        # Ensure directory exists. Guard against a bare filename:
        # os.path.dirname("file.py") == "" and os.makedirs("") raises
        # FileNotFoundError.
        target_dir = os.path.dirname(local_file_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Write to local file. Use utf-8 explicitly: the contents were decoded
        # as utf-8 above, and the platform-default encoding may differ.
        with open(local_file_path, "w", encoding="utf-8") as f:
            f.write(file_contents)

        verbose_proxy_logger.debug(
            f"Python file downloaded successfully to {local_file_path}"
        )
        return True

    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error downloading Python file from GCS: {str(e)}"
        )
        return False
|
||||
|
||||
# # Example usage
|
||||
# bucket_name = 'litellm-proxy'
|
||||
# object_key = 'litellm_proxy_config.yaml'
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user