fix: fix the static file mounting handler breaking the API (#1743)

This commit is contained in:
Charles Packer
2024-09-10 14:32:23 -07:00
committed by GitHub
parent a3112662e2
commit 89eb7828b6
2 changed files with 47 additions and 5 deletions

View File

@@ -1,12 +1,16 @@
import importlib.util
import json
import logging
import os
import secrets
from pathlib import Path
from typing import Optional
import typer
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from memgpt.server.constants import REST_DEFAULT_PORT
@@ -31,7 +35,6 @@ from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.server import SyncServer
from memgpt.settings import settings
@@ -55,6 +58,29 @@ API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"
class SmartStaticFilesMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# List of API prefixes that should bypass static handling
api_prefixes = [API_PREFIX, OPENAI_API_PREFIX, ADMIN_PREFIX]
path = request.url.path
# Check if the request path starts with any API prefix
if any(path.startswith(prefix) for prefix in api_prefixes):
# If it's an API call, process normally
print(f"API request detected: {path}")
response = await call_next(request)
else:
print(f"Static request detected: {path}")
# Try to serve static files, catch any errors like 404, etc.
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
filepath = os.path.join(static_files_path, path.lstrip("/"))
if os.path.isfile(filepath):
return FileResponse(filepath)
else:
response = await call_next(request)
return response
def create_application() -> "FastAPI":
"""the application start routine"""
@@ -72,6 +98,8 @@ def create_application() -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
# Usage in your create_application function:
app.add_middleware(SmartStaticFilesMiddleware)
for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
@@ -95,7 +123,7 @@ def create_application() -> "FastAPI":
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)
# / static files
mount_static_files(app)
# mount_static_files(app)
@app.on_event("startup")
def on_startup():

View File

@@ -21,11 +21,25 @@ def mount_static_files(app: FastAPI):
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
if os.path.exists(static_files_path):
app.mount(
# "/",
"/app",
"/",
# "/app",
SPAStaticFiles(
directory=static_files_path,
html=True,
),
name="spa-static-files",
)
# def mount_static_files(app: FastAPI):
# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
# if os.path.exists(static_files_path):
# @app.get("/{full_path:path}")
# async def serve_spa(full_path: str):
# if full_path.startswith("v1"):
# raise HTTPException(status_code=404, detail="Not found")
# file_path = os.path.join(static_files_path, full_path)
# if os.path.isfile(file_path):
# return FileResponse(file_path)
# return FileResponse(os.path.join(static_files_path, "index.html"))