feat: support a new secure flag (#2030)

Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
Shubham Naik
2024-11-12 21:34:36 -08:00
committed by GitHub
parent c9c10e945e
commit b4a2a227e2
2 changed files with 28 additions and 0 deletions

View File

@@ -47,6 +47,7 @@ def server(
host: Annotated[Optional[str], typer.Option(help="Host to run the server on (default to localhost)")] = None,
debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False,
ade: Annotated[bool, typer.Option(help="Allows remote access")] = False,
secure: Annotated[bool, typer.Option(help="Adds simple security access")] = False,
):
"""Launch a Letta server process"""
if type == ServerChoice.rest_api:

View File

@@ -6,6 +6,8 @@ from typing import Optional
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__
@@ -94,6 +96,27 @@ def generate_openapi_schema(app: FastAPI):
Path(f"openapi_{name}.json").write_text(json.dumps(docs, indent=2))
# middleware that only allows requests to pass through if user provides a password thats randomly generated and stored in memory
def generate_password():
import secrets
return secrets.token_urlsafe(16)
random_password = generate_password()
class CheckPasswordMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
if request.headers.get("X-BARE-PASSWORD") == f"password {random_password}":
return await call_next(request)
return JSONResponse(
content={"detail": "Unauthorized"},
status_code=401,
)
def create_application() -> "FastAPI":
"""the application start routine"""
# global server
@@ -113,6 +136,10 @@ def create_application() -> "FastAPI":
settings.cors_origins.append("https://app.letta.com")
print(f"▶ View using ADE at: https://app.letta.com/local-project/agents")
if "--secure" in sys.argv:
print(f"▶ Using secure mode with password: {random_password}")
app.add_middleware(CheckPasswordMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,