feat: profiling middleware
This commit is contained in:
@@ -12,7 +12,6 @@ from typing import Optional
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from letta.__init__ import __version__ as letta_version
|
||||
@@ -35,6 +34,7 @@ from letta.server.db import db_registry
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.middleware import CheckPasswordMiddleware, ProfilerContextMiddleware
|
||||
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
|
||||
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
|
||||
@@ -42,7 +42,7 @@ from letta.server.rest_api.routers.v1.users import router as users_router # TOD
|
||||
from letta.server.rest_api.static_files import mount_static_files
|
||||
from letta.server.rest_api.utils import SENTRY_ENABLED
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
from letta.settings import settings, telemetry_settings
|
||||
|
||||
if SENTRY_ENABLED:
|
||||
import sentry_sdk
|
||||
@@ -92,24 +92,6 @@ def generate_password():
|
||||
random_password = os.getenv("LETTA_SERVER_PASSWORD") or generate_password()
|
||||
|
||||
|
||||
class CheckPasswordMiddleware(BaseHTTPMiddleware):
    """Gate every request behind the module-level ``random_password``."""

    async def dispatch(self, request, call_next):
        """Forward the request when it is a health check or carries the password.

        Any other request is answered with a 401 JSON response and is not
        passed to ``call_next``.
        """
        # Health check endpoints are exempt from password protection.
        if request.url.path in {"/v1/health", "/v1/health/", "/latest/health/"}:
            return await call_next(request)

        headers = request.headers
        # Either header form is accepted; the second is only inspected when
        # the first does not match.
        authorized = (
            headers.get("X-BARE-PASSWORD") == f"password {random_password}"
            or headers.get("Authorization") == f"Bearer {random_password}"
        )
        if not authorized:
            return JSONResponse(
                status_code=401,
                content={"detail": "Unauthorized"},
            )

        return await call_next(request)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app_: FastAPI):
|
||||
"""
|
||||
@@ -117,6 +99,19 @@ async def lifespan(app_: FastAPI):
|
||||
"""
|
||||
worker_id = os.getpid()
|
||||
|
||||
if telemetry_settings.profiler:
    # Profiling is opt-in via the telemetry settings flag.
    try:
        # Imported lazily so the package is only required when the flag is set.
        import googlecloudprofiler

        googlecloudprofiler.start(
            service="memgpt-server",
            service_version=str(letta_version),
            verbose=3,
        )
        logger.info("Profiler started.")
    except (ImportError, ValueError, NotImplementedError) as exc:
        # A missing package or a failed start is treated as non-fatal so the
        # rest of the lifespan initialization still runs.
        # The placeholder must be "%s": a bare "%" is an incomplete format
        # specifier, so logging cannot render the record and the reason is lost.
        logger.info("Profiler not enabled: %s", exc)
|
||||
|
||||
logger.info(f"[Worker {worker_id}] Starting lifespan initialization")
|
||||
logger.info(f"[Worker {worker_id}] Initializing database connections")
|
||||
db_registry.initialize_sync()
|
||||
@@ -283,11 +278,14 @@ def create_application() -> "FastAPI":
|
||||
|
||||
if (os.getenv("LETTA_SERVER_SECURE") == "true") or "--secure" in sys.argv:
|
||||
print(f"▶ Using secure mode with password: {random_password}")
|
||||
app.add_middleware(CheckPasswordMiddleware)
|
||||
app.add_middleware(CheckPasswordMiddleware, password=random_password)
|
||||
|
||||
# Add reverse proxy middleware to handle X-Forwarded-* headers
|
||||
# app.add_middleware(ReverseProxyMiddleware, base_path=settings.server_base_path)
|
||||
|
||||
if telemetry_settings.profiler:
|
||||
app.add_middleware(ProfilerContextMiddleware)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
|
||||
4
letta/server/rest_api/middleware/__init__.py
Normal file
4
letta/server/rest_api/middleware/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from letta.server.rest_api.middleware.check_password import CheckPasswordMiddleware
|
||||
from letta.server.rest_api.middleware.profiler_context import ProfilerContextMiddleware
|
||||
|
||||
__all__ = ["CheckPasswordMiddleware", "ProfilerContextMiddleware"]
|
||||
24
letta/server/rest_api/middleware/check_password.py
Normal file
24
letta/server/rest_api/middleware/check_password.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class CheckPasswordMiddleware(BaseHTTPMiddleware):
    """Reject requests that do not present the configured password.

    A request is let through when its path is one of the health check
    endpoints, or when either the ``X-BARE-PASSWORD`` header equals
    ``"password <password>"`` or the ``Authorization`` header equals
    ``"Bearer <password>"``. Every other request receives a 401 JSON response.
    """

    def __init__(self, app, password: str):
        """Wrap ``app`` and remember the password that requests must present."""
        super().__init__(app)
        self.password = password

    @staticmethod
    def _matches(presented, expected: str) -> bool:
        """Return True when ``presented`` equals ``expected``, compared in constant time.

        ``presented`` is a header value and may be ``None`` when the header is
        absent, which never matches.
        """
        # Function-scope import keeps the class self-contained.
        import hmac

        if presented is None:
            return False
        # compare_digest only accepts non-ASCII text as bytes, so both sides
        # are encoded first. "surrogatepass" keeps the encoding from raising on
        # unusual code points; equal strings still encode to equal bytes.
        return hmac.compare_digest(
            presented.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )

    async def dispatch(self, request, call_next):
        """Forward authorized requests to ``call_next``; answer the rest with 401."""
        # Exclude health check endpoint from password protection
        if request.url.path in {"/v1/health", "/v1/health/", "/latest/health/"}:
            return await call_next(request)

        headers = request.headers
        # Header values are untrusted input, so avoid ``==``, whose running
        # time depends on how many leading characters match.
        if self._matches(headers.get("X-BARE-PASSWORD"), f"password {self.password}") or self._matches(
            headers.get("Authorization"), f"Bearer {self.password}"
        ):
            return await call_next(request)

        return JSONResponse(
            content={"detail": "Unauthorized"},
            status_code=401,
        )
|
||||
25
letta/server/rest_api/middleware/profiler_context.py
Normal file
25
letta/server/rest_api/middleware/profiler_context.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
|
||||
class ProfilerContextMiddleware(BaseHTTPMiddleware):
    """Middleware to set context if using profiler. Currently just uses google-cloud-profiler."""

    async def dispatch(self, request, call_next):
        """Run the request inside a profiler label context when one can be created.

        Health check paths are forwarded without labels. For every other
        request, labels are built from the HTTP method and URL path. If the
        label context cannot be created for any reason, the request is still
        forwarded, just without labels.
        """
        if request.url.path in {"/v1/health", "/v1/health/"}:
            return await call_next(request)

        ctx = None
        try:
            # Imported lazily so the package is only needed when this
            # middleware actually handles a request.
            import googlecloudprofiler

            # NOTE(review): confirm that `googlecloudprofiler.context.set_labels`
            # exists in the pinned package version. If it does not, the
            # AttributeError is swallowed below and labels are never applied.
            ctx = googlecloudprofiler.context.set_labels(
                method=request.method,
                path=request.url.path,
                endpoint=request.url.path,
            )
        except Exception:
            # Labelling is best-effort and must never fail the request.
            # `Exception` rather than a bare `except:` so KeyboardInterrupt and
            # SystemExit are not swallowed.
            ctx = None

        if not ctx:
            return await call_next(request)

        with ctx:
            return await call_next(request)
|
||||
@@ -331,6 +331,11 @@ class LogSettings(BaseSettings):
|
||||
verbose_telemetry_logging: bool = Field(False)
|
||||
|
||||
|
||||
class TelemetrySettings(BaseSettings):
    # Telemetry-related switches; configured with the "letta_telemetry_" env
    # prefix, and unknown extra values are ignored.
    model_config = SettingsConfigDict(
        extra="ignore",
        env_prefix="letta_telemetry_",
    )

    # Off by default; must be turned on explicitly.
    profiler: bool | None = Field(
        default=False,
        description="Enable use of the profiler.",
    )
|
||||
|
||||
|
||||
# singleton
|
||||
settings = Settings(_env_parse_none_str="None")
|
||||
test_settings = TestSettings()
|
||||
@@ -338,3 +343,4 @@ model_settings = ModelSettings()
|
||||
tool_settings = ToolSettings()
|
||||
summarizer_settings = SummarizerSettings()
|
||||
log_settings = LogSettings()
|
||||
telemetry_settings = TelemetrySettings()
|
||||
|
||||
723
poetry.lock
generated
723
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -101,6 +101,7 @@ certifi = "^2025.6.15"
|
||||
aioboto3 = {version = "^14.3.0", optional = true}
|
||||
pinecone = {extras = ["asyncio"], version = "^7.3.0", optional = true}
|
||||
markitdown = {extras = ["docx", "pdf", "pptx"], version = "^0.1.2"}
|
||||
google-cloud-profiler = {version = "^4.1.0", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
@@ -108,7 +109,7 @@ postgres = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "asyncpg"]
|
||||
redis = ["redis"]
|
||||
pinecone = ["pinecone"]
|
||||
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "locust"]
|
||||
experimental = ["uvloop", "granian"]
|
||||
experimental = ["uvloop", "granian", "google-cloud-profiler"]
|
||||
server = ["websockets", "fastapi", "uvicorn"]
|
||||
cloud-tool-sandbox = ["e2b-code-interpreter"]
|
||||
external-tools = ["docker", "langchain", "wikipedia", "langchain-community", "firecrawl-py"]
|
||||
@@ -116,7 +117,7 @@ tests = ["wikipedia"]
|
||||
bedrock = ["boto3", "aioboto3"]
|
||||
google = ["google-genai"]
|
||||
desktop = ["pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone", "google-cloud-profiler"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.4.2"
|
||||
|
||||
Reference in New Issue
Block a user