feat: profiling middleware

This commit is contained in:
Andy Li
2025-07-25 13:12:59 -07:00
committed by GitHub
parent efbe9e228f
commit 1a6dfa8668
7 changed files with 735 additions and 92 deletions

View File

@@ -12,7 +12,6 @@ from typing import Optional
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__ as letta_version
@@ -35,6 +34,7 @@ from letta.server.db import db_registry
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.middleware import CheckPasswordMiddleware, ProfilerContextMiddleware
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
@@ -42,7 +42,7 @@ from letta.server.rest_api.routers.v1.users import router as users_router # TOD
from letta.server.rest_api.static_files import mount_static_files
from letta.server.rest_api.utils import SENTRY_ENABLED
from letta.server.server import SyncServer
from letta.settings import settings
from letta.settings import settings, telemetry_settings
if SENTRY_ENABLED:
import sentry_sdk
@@ -92,24 +92,6 @@ def generate_password():
random_password = os.getenv("LETTA_SERVER_PASSWORD") or generate_password()
class CheckPasswordMiddleware(BaseHTTPMiddleware):
    """Gate every request behind the module-level ``random_password``.

    A request passes when it targets one of the health-check paths, or when
    it carries the password in either the ``X-BARE-PASSWORD`` header
    (``password <pw>``) or the ``Authorization`` header (``Bearer <pw>``).
    Anything else is answered with a 401 JSON body.
    """

    async def dispatch(self, request, call_next):
        # Health check endpoints are excluded from password protection.
        open_paths = {"/v1/health", "/v1/health/", "/latest/health/"}
        if request.url.path in open_paths:
            return await call_next(request)

        bare_header_ok = request.headers.get("X-BARE-PASSWORD") == f"password {random_password}"
        # The Authorization header is only consulted when the bare header did not match.
        authorized = bare_header_ok or request.headers.get("Authorization") == f"Bearer {random_password}"
        if not authorized:
            return JSONResponse(
                content={"detail": "Unauthorized"},
                status_code=401,
            )

        return await call_next(request)
@asynccontextmanager
async def lifespan(app_: FastAPI):
"""
@@ -117,6 +99,19 @@ async def lifespan(app_: FastAPI):
"""
worker_id = os.getpid()
if telemetry_settings.profiler:
    # Start the profiler agent only when telemetry settings enable it.
    # The import is local so the optional `google-cloud-profiler` package is
    # not required unless profiling is switched on.
    try:
        import googlecloudprofiler

        googlecloudprofiler.start(
            service="memgpt-server",
            service_version=str(letta_version),
            verbose=3,
        )
        logger.info("Profiler started.")
    except (ImportError, ValueError, NotImplementedError) as exc:
        # Profiling is best-effort: a missing optional package or a failed
        # agent start is logged and startup continues.
        # NOTE: the placeholder must be `%s`; a bare `%` makes the logging
        # call fail to format the record instead of emitting the message.
        logger.info("Profiler not enabled: %s", exc)
logger.info(f"[Worker {worker_id}] Starting lifespan initialization")
logger.info(f"[Worker {worker_id}] Initializing database connections")
db_registry.initialize_sync()
@@ -283,11 +278,14 @@ def create_application() -> "FastAPI":
if (os.getenv("LETTA_SERVER_SECURE") == "true") or "--secure" in sys.argv:
print(f"▶ Using secure mode with password: {random_password}")
app.add_middleware(CheckPasswordMiddleware)
app.add_middleware(CheckPasswordMiddleware, password=random_password)
# Add reverse proxy middleware to handle X-Forwarded-* headers
# app.add_middleware(ReverseProxyMiddleware, base_path=settings.server_base_path)
if telemetry_settings.profiler:
app.add_middleware(ProfilerContextMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,

View File

@@ -0,0 +1,4 @@
from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware
from letta.server.rest_api.middleware.profiler_context import ProfilerContextMiddleware
__all__ = ["CheckPasswordMiddleware", "ProfilerContextMiddleware"]

View File

@@ -0,0 +1,24 @@
import hmac

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
class CheckPasswordMiddleware(BaseHTTPMiddleware):
    """Reject requests that do not present the configured password.

    A request is let through when it targets one of the health-check paths,
    or when it carries the password in either the ``X-BARE-PASSWORD`` header
    (as ``password <pw>``) or the ``Authorization`` header (as
    ``Bearer <pw>``). All other requests receive a 401 JSON response.

    Args:
        app: The ASGI application being wrapped.
        password: The secret that clients must present.
    """

    # Health check endpoints are excluded from password protection.
    _UNPROTECTED_PATHS = frozenset({"/v1/health", "/v1/health/", "/latest/health/"})

    def __init__(self, app, password: str):
        super().__init__(app)
        self.password = password

    @staticmethod
    def _matches(supplied, expected: str) -> bool:
        """Return True when ``supplied`` equals ``expected``, compared in constant time.

        ``supplied`` is ``None`` when the header is absent. Both values are
        encoded to bytes because ``hmac.compare_digest`` only accepts ``str``
        arguments that are pure ASCII, which a client-controlled header value
        is not guaranteed to be.
        """
        if supplied is None:
            return False
        return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))

    async def dispatch(self, request, call_next):
        """Forward authorized requests; answer everything else with 401."""
        if request.url.path in self._UNPROTECTED_PATHS:
            return await call_next(request)

        headers = request.headers
        if self._matches(headers.get("X-BARE-PASSWORD"), f"password {self.password}") or self._matches(
            headers.get("Authorization"), f"Bearer {self.password}"
        ):
            return await call_next(request)

        return JSONResponse(
            content={"detail": "Unauthorized"},
            status_code=401,
        )

View File

@@ -0,0 +1,25 @@
from starlette.middleware.base import BaseHTTPMiddleware
class ProfilerContextMiddleware(BaseHTTPMiddleware):
    """Middleware to set context if using profiler. Currently just uses google-cloud-profiler."""

    async def dispatch(self, request, call_next):
        """Run the request inside a profiler label context when one can be created.

        Health-check requests are passed straight through. For every other
        request the HTTP method and URL path are offered to the profiler as
        labels; if that fails for any reason the request is served without
        labels rather than failing.
        """
        if request.url.path in {"/v1/health", "/v1/health/"}:
            return await call_next(request)

        try:
            # Imported lazily: the profiler package is an optional dependency.
            import googlecloudprofiler

            ctx = googlecloudprofiler.context.set_labels(
                method=request.method,
                path=request.url.path,
                endpoint=request.url.path,
            )
        except Exception:
            # Labelling is best-effort. Catch `Exception` rather than using a
            # bare `except:` so KeyboardInterrupt/SystemExit still propagate.
            ctx = None

        if ctx:
            with ctx:
                return await call_next(request)
        return await call_next(request)

View File

@@ -331,6 +331,11 @@ class LogSettings(BaseSettings):
verbose_telemetry_logging: bool = Field(False)
class TelemetrySettings(BaseSettings):
    # Telemetry-related settings; fields use the `letta_telemetry_` prefix
    # and unknown extra values are ignored.
    model_config = SettingsConfigDict(
        env_prefix="letta_telemetry_",
        extra="ignore",
    )

    # Profiler toggle; defaults to off.
    profiler: bool | None = Field(
        default=False,
        description="Enable use of the profiler.",
    )
# singleton
settings = Settings(_env_parse_none_str="None")
test_settings = TestSettings()
@@ -338,3 +343,4 @@ model_settings = ModelSettings()
tool_settings = ToolSettings()
summarizer_settings = SummarizerSettings()
log_settings = LogSettings()
telemetry_settings = TelemetrySettings()

723
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -101,6 +101,7 @@ certifi = "^2025.6.15"
aioboto3 = {version = "^14.3.0", optional = true}
pinecone = {extras = ["asyncio"], version = "^7.3.0", optional = true}
markitdown = {extras = ["docx", "pdf", "pptx"], version = "^0.1.2"}
google-cloud-profiler = {version = "^4.1.0", optional = true}
[tool.poetry.extras]
@@ -108,7 +109,7 @@ postgres = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "asyncpg"]
redis = ["redis"]
pinecone = ["pinecone"]
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "locust"]
experimental = ["uvloop", "granian"]
experimental = ["uvloop", "granian", "google-cloud-profiler"]
server = ["websockets", "fastapi", "uvicorn"]
cloud-tool-sandbox = ["e2b-code-interpreter"]
external-tools = ["docker", "langchain", "wikipedia", "langchain-community", "firecrawl-py"]
@@ -116,7 +117,7 @@ tests = ["wikipedia"]
bedrock = ["boto3", "aioboto3"]
google = ["google-genai"]
desktop = ["pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone"]
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone", "google-cloud-profiler"]
[tool.poetry.group.dev.dependencies]
black = "^24.4.2"