diff --git a/letta/cli/cli.py b/letta/cli/cli.py index 1afdba1b..56b79fb1 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -47,6 +47,7 @@ def server( host: Annotated[Optional[str], typer.Option(help="Host to run the server on (default to localhost)")] = None, debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False, ade: Annotated[bool, typer.Option(help="Allows remote access")] = False, + secure: Annotated[bool, typer.Option(help="Adds simple security access")] = False, ): """Launch a Letta server process""" if type == ServerChoice.rest_api: diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index fb4c4049..a2659cf6 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -6,6 +6,8 @@ from typing import Optional import uvicorn from fastapi import FastAPI +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware from letta.__init__ import __version__ @@ -94,6 +96,27 @@ def generate_openapi_schema(app: FastAPI): Path(f"openapi_{name}.json").write_text(json.dumps(docs, indent=2)) +# middleware that only allows requests to pass through if user provides a password thats randomly generated and stored in memory +def generate_password(): + import secrets + + return secrets.token_urlsafe(16) + + +random_password = generate_password() + + +class CheckPasswordMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + if request.headers.get("X-BARE-PASSWORD") == f"password {random_password}": + return await call_next(request) + + return JSONResponse( + content={"detail": "Unauthorized"}, + status_code=401, + ) + + def create_application() -> "FastAPI": """the application start routine""" # global server @@ -113,6 +136,10 @@ def create_application() -> "FastAPI": settings.cors_origins.append("https://app.letta.com") print(f"▶ View using ADE at: https://app.letta.com/local-project/agents") + if "--secure" in sys.argv: + print(f"▶ Using secure mode with password: {random_password}") + app.add_middleware(CheckPasswordMiddleware) + app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins,