fix: fix the static file mounting handler breaking the API (#1743)
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from memgpt.server.constants import REST_DEFAULT_PORT
|
||||
@@ -31,7 +35,6 @@ from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
from memgpt.server.rest_api.routers.v1.users import (
|
||||
router as users_router, # TODO: decide on admin
|
||||
)
|
||||
from memgpt.server.rest_api.static_files import mount_static_files
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.settings import settings
|
||||
|
||||
@@ -55,6 +58,29 @@ API_PREFIX = "/v1"
|
||||
OPENAI_API_PREFIX = "/openai"
|
||||
|
||||
|
||||
class SmartStaticFilesMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# List of API prefixes that should bypass static handling
|
||||
api_prefixes = [API_PREFIX, OPENAI_API_PREFIX, ADMIN_PREFIX]
|
||||
path = request.url.path
|
||||
|
||||
# Check if the request path starts with any API prefix
|
||||
if any(path.startswith(prefix) for prefix in api_prefixes):
|
||||
# If it's an API call, process normally
|
||||
print(f"API request detected: {path}")
|
||||
response = await call_next(request)
|
||||
else:
|
||||
print(f"Static request detected: {path}")
|
||||
# Try to serve static files, catch any errors like 404, etc.
|
||||
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
|
||||
filepath = os.path.join(static_files_path, path.lstrip("/"))
|
||||
if os.path.isfile(filepath):
|
||||
return FileResponse(filepath)
|
||||
else:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
def create_application() -> "FastAPI":
|
||||
"""the application start routine"""
|
||||
|
||||
@@ -72,6 +98,8 @@ def create_application() -> "FastAPI":
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Usage in your create_application function:
|
||||
app.add_middleware(SmartStaticFilesMiddleware)
|
||||
|
||||
for route in v1_routes:
|
||||
app.include_router(route, prefix=API_PREFIX)
|
||||
@@ -95,7 +123,7 @@ def create_application() -> "FastAPI":
|
||||
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
# / static files
|
||||
mount_static_files(app)
|
||||
# mount_static_files(app)
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
|
||||
@@ -21,11 +21,25 @@ def mount_static_files(app: FastAPI):
|
||||
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
|
||||
if os.path.exists(static_files_path):
|
||||
app.mount(
|
||||
# "/",
|
||||
"/app",
|
||||
"/",
|
||||
# "/app",
|
||||
SPAStaticFiles(
|
||||
directory=static_files_path,
|
||||
html=True,
|
||||
),
|
||||
name="spa-static-files",
|
||||
)
|
||||
|
||||
|
||||
# def mount_static_files(app: FastAPI):
|
||||
# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
|
||||
# if os.path.exists(static_files_path):
|
||||
|
||||
# @app.get("/{full_path:path}")
|
||||
# async def serve_spa(full_path: str):
|
||||
# if full_path.startswith("v1"):
|
||||
# raise HTTPException(status_code=404, detail="Not found")
|
||||
# file_path = os.path.join(static_files_path, full_path)
|
||||
# if os.path.isfile(file_path):
|
||||
# return FileResponse(file_path)
|
||||
# return FileResponse(os.path.join(static_files_path, "index.html"))
|
||||
|
||||
Reference in New Issue
Block a user