From 89eb7828b65a853e4ad4f7fc847b918d18ac1710 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Tue, 10 Sep 2024 14:32:23 -0700 Subject: [PATCH] fix: fix the static file mounting handler breaking the API (#1743) --- memgpt/server/rest_api/app.py | 34 +++++++++++++++++++++++--- memgpt/server/rest_api/static_files.py | 18 ++++++++++++-- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/memgpt/server/rest_api/app.py b/memgpt/server/rest_api/app.py index 2f32397e..973a362d 100644 --- a/memgpt/server/rest_api/app.py +++ b/memgpt/server/rest_api/app.py @@ -1,12 +1,16 @@ +import importlib.util import json import logging +import os import secrets from pathlib import Path from typing import Optional import typer import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, Request +from fastapi.responses import FileResponse +from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware from memgpt.server.constants import REST_DEFAULT_PORT @@ -31,7 +35,6 @@ from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes from memgpt.server.rest_api.routers.v1.users import ( router as users_router, # TODO: decide on admin ) -from memgpt.server.rest_api.static_files import mount_static_files from memgpt.server.server import SyncServer from memgpt.settings import settings @@ -55,6 +58,29 @@ API_PREFIX = "/v1" OPENAI_API_PREFIX = "/openai" +class SmartStaticFilesMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # List of API prefixes that should bypass static handling + api_prefixes = [API_PREFIX, OPENAI_API_PREFIX, ADMIN_PREFIX] + path = request.url.path + + # Check if the request path starts with any API prefix + if any(path.startswith(prefix) for prefix in api_prefixes): + # If it's an API call, process normally + print(f"API request detected: {path}") + response = await call_next(request) + else: + print(f"Static request detected: {path}") + # Try to serve static files, catch any errors like 404, etc. + static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files") + filepath = os.path.join(static_files_path, path.lstrip("/")) + if os.path.isfile(filepath): + return FileResponse(filepath) + else: + response = await call_next(request) + return response + + def create_application() -> "FastAPI": """the application start routine""" @@ -72,6 +98,8 @@ def create_application() -> "FastAPI": allow_methods=["*"], allow_headers=["*"], ) + # Usage in your create_application function: + app.add_middleware(SmartStaticFilesMiddleware) for route in v1_routes: app.include_router(route, prefix=API_PREFIX) @@ -95,7 +123,7 @@ def create_application() -> "FastAPI": app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX) # / static files - mount_static_files(app) + # mount_static_files(app) @app.on_event("startup") def on_startup(): diff --git a/memgpt/server/rest_api/static_files.py b/memgpt/server/rest_api/static_files.py index b6899fca..d9b0b39a 100644 --- a/memgpt/server/rest_api/static_files.py +++ b/memgpt/server/rest_api/static_files.py @@ -21,11 +21,25 @@ def mount_static_files(app: FastAPI): static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files") if os.path.exists(static_files_path): app.mount( - # "/", - "/app", + "/", + # "/app", SPAStaticFiles( directory=static_files_path, html=True, ), name="spa-static-files", ) + + +# def mount_static_files(app: FastAPI): +# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files") +# if os.path.exists(static_files_path): + +# @app.get("/{full_path:path}") +# async def serve_spa(full_path: str): +# if full_path.startswith("v1"): +# raise HTTPException(status_code=404, detail="Not found") +# file_path = os.path.join(static_files_path, full_path) +# if os.path.isfile(file_path): +# return FileResponse(file_path) +# return FileResponse(os.path.join(static_files_path, "index.html"))