242 lines
8.2 KiB
Python
242 lines
8.2 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import secrets
|
|
import subprocess
|
|
from typing import Optional
|
|
|
|
import typer
|
|
import uvicorn
|
|
from fastapi import Depends, FastAPI, HTTPException
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
|
|
from memgpt.server.constants import REST_DEFAULT_PORT
|
|
from memgpt.server.rest_api.auth.index import (
|
|
setup_auth_router, # TODO: probably remove right?
|
|
)
|
|
from memgpt.server.rest_api.interface import StreamingServerInterface
|
|
from memgpt.server.rest_api.routers.openai.assistants.assistants import (
|
|
router as openai_assistants_router,
|
|
)
|
|
from memgpt.server.rest_api.routers.openai.assistants.threads import (
|
|
router as openai_threads_router,
|
|
)
|
|
from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import (
|
|
router as openai_chat_completions_router,
|
|
)
|
|
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
|
from memgpt.server.rest_api.routers.v1.users import (
|
|
router as users_router, # TODO: decide on admin
|
|
)
|
|
from memgpt.server.rest_api.static_files import mount_static_files
|
|
from memgpt.server.server import SyncServer
|
|
from memgpt.settings import settings
|
|
|
|
"""
|
|
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
|
|
|
|
Start the server with:
|
|
cd memgpt/server/rest_api
|
|
poetry run uvicorn server:app --reload
|
|
"""
|
|
|
|
# `interface` is bound to the StreamingServerInterface *class* itself, not an
# instance; the server calls the factory below whenever it needs a new one.
# The lambda looks `interface` up at call time, so rebinding the module-level
# name later changes what the factory produces.
interface: StreamingServerInterface = StreamingServerInterface
server: SyncServer = SyncServer(default_interface_factory=lambda: interface())

# Admin password: take it from settings when one is configured, otherwise
# generate a random one for this process and show it on stdout.
password = settings.server_pass
if not password:
    # token_urlsafe(16) draws 16 random bytes and returns them as URL-safe text.
    password = secrets.token_urlsafe(16)
    typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN)
else:
    print("Using existing admin server password from environment.")

# Bearer-token security scheme; used as the dependency of `verify_password`.
security = HTTPBearer()
|
|
|
|
|
|
def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check that the request's bearer token matches the admin password.

    Intended as a FastAPI dependency: `security` (an HTTPBearer scheme)
    supplies `credentials`, and the token is compared with the module-level
    `password`.

    NOTE(review): nothing in the visible part of this module attaches this
    dependency to a router -- confirm the admin routes are protected elsewhere.

    Args:
        credentials: Bearer credentials extracted from the request.

    Raises:
        HTTPException: 401 "Unauthorized" when the token does not match.
    """
    # Compare as bytes: compare_digest only accepts ASCII-only str, and a
    # non-ASCII token would otherwise raise TypeError instead of a clean 401.
    supplied_token = credentials.credentials.encode("utf-8")
    expected_token = password.encode("utf-8")

    # Constant-time comparison so the response time does not reveal how many
    # leading characters of the password were guessed correctly.
    if not secrets.compare_digest(supplied_token, expected_token):
        raise HTTPException(status_code=401, detail="Unauthorized")
|
|
|
|
|
|
# URL prefixes under which the router groups are mounted.
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"

app = FastAPI()

# CORS: only the origins come from settings; all methods and headers are
# allowed and credentials are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# MemGPT API routes. Each router is mounted three times: once under the
# versioned prefix (documented in the schema), and twice more -- bare and under
# "/latest" -- as undocumented aliases that should always track the newest
# API version.
for route in v1_routes:
    app.include_router(route, prefix=API_PREFIX)
    app.include_router(route, prefix="", include_in_schema=False)
    app.include_router(route, prefix="/latest", include_in_schema=False)

# Admin user-management routes.
# NOTE(review): this router is mounted without the `verify_password`
# dependency defined in this module -- confirm the users router enforces
# authentication on its own.
app.include_router(users_router, prefix=ADMIN_PREFIX)

# OpenAI-compatible routes (assistants, threads, chat completions).
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)

# Auth router, built from the server, the interface class and the admin
# password, and mounted under the versioned API prefix.
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)
|
|
|
|
# # Serve static files
|
|
# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
|
|
# app.mount("/assets", StaticFiles(directory=os.path.join(static_files_path, "assets")), name="static")
|
|
|
|
|
|
# # Serve favicon
|
|
# @app.get("/favicon.ico")
|
|
# async def favicon():
|
|
# return FileResponse(os.path.join(static_files_path, "favicon.ico"))
|
|
|
|
|
|
# # Middleware to handle API routes first
|
|
# @app.middleware("http")
|
|
# async def handle_api_routes(request: Request, call_next):
|
|
# if request.url.path.startswith(("/v1/", "/openai/")):
|
|
# response = await call_next(request)
|
|
# if response.status_code != 404:
|
|
# return response
|
|
# return await serve_spa(request.url.path)
|
|
|
|
|
|
# # Catch-all route for SPA
|
|
# async def serve_spa(full_path: str):
|
|
# return FileResponse(os.path.join(static_files_path, "index.html"))
|
|
|
|
|
|
# Hand the app to the static-files helper imported from
# memgpt.server.rest_api.static_files. This runs after all API routers above
# have been registered -- presumably so API routes take precedence over any
# catch-all static route; verify against the helper.
mount_static_files(app)
|
|
|
|
|
|
@app.on_event("startup")
def on_startup():
    """Finalize the OpenAPI schema and write the two split schema files.

    On startup this:
      1. generates the app's OpenAPI schema if it has not been built yet,
         sets its "servers" list from `settings.cors_origins` and its title
         to "MemGPT API";
      2. writes `openapi_memgpt.json` (every path that is not under
         OPENAI_API_PREFIX) and `openapi_assistants.json` (every path that is
         not under API_PREFIX / ADMIN_PREFIX) to the current working
         directory.

    Each exported schema gets its own "info" dict, so setting the
    "OpenAI Assistants API" title on the second export no longer leaks into
    the schema the app itself serves (the previous shallow `dict.copy()`
    shared one "info" dict between all three).
    """
    # Generate the schema once; it is stored on the app for later reuse.
    if not app.openapi_schema:
        app.openapi_schema = app.openapi()

    schema = app.openapi_schema
    schema["servers"] = [{"url": host} for host in settings.cors_origins]
    schema["info"]["title"] = "MemGPT API"

    def _schema_subset(title, keep_path):
        """Return a copy of the schema with its own info/title and only the kept paths."""
        # Re-assigning existing keys keeps their original position, so the
        # key order of the dumped JSON is unchanged.
        return {
            **schema,
            "info": {**schema["info"], "title": title},
            "paths": {path: item for path, item in schema["paths"].items() if keep_path(path)},
        }

    # Split the API docs into MemGPT API, and OpenAI Assistants compatible API
    memgpt_api = _schema_subset(
        "MemGPT API",
        lambda path: not path.startswith(OPENAI_API_PREFIX),
    )
    with open("openapi_memgpt.json", "w", encoding="utf-8") as file:
        print("Writing out openapi_memgpt.json file")
        json.dump(memgpt_api, file, indent=2)

    openai_assistants_api = _schema_subset(
        "OpenAI Assistants API",
        lambda path: not path.startswith((API_PREFIX, ADMIN_PREFIX)),
    )
    with open("openapi_assistants.json", "w", encoding="utf-8") as file:
        print("Writing out openapi_assistants.json file")
        json.dump(openai_assistants_api, file, indent=2)
|
|
|
|
|
|
@app.on_event("shutdown")
def on_shutdown():
    """On shutdown, call `save_agents()` on the server (if any) and clear the module-level reference."""
    global server
    if (active_server := server):
        active_server.save_agents()
    # Drop the reference so the server object is not reused after shutdown.
    server = None
|
|
|
|
|
|
def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"):
    """Create a self-signed certificate and private key by invoking `openssl`.

    NOTE: intended to be used for development only.

    The certificate is issued for CN=localhost with a new 4096-bit RSA key and
    a 365-day validity. `-nodes` leaves the private key unencrypted on disk.

    Args:
        cert_path: Where openssl writes the certificate.
        key_path: Where openssl writes the private key.

    Returns:
        The `(cert_path, key_path)` pair, unchanged.

    Raises:
        subprocess.CalledProcessError: If openssl exits with a non-zero status.
        FileNotFoundError: If the `openssl` executable cannot be found.
    """
    openssl_command = [
        "openssl",
        "req",
        "-x509",
        "-newkey",
        "rsa:4096",
        "-keyout",
        key_path,
        "-out",
        cert_path,
        "-days",
        "365",
        "-nodes",
        "-subj",
        "/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost",
    ]
    # Argument list (no shell) so the paths are passed through verbatim.
    subprocess.run(openssl_command, check=True)
    return cert_path, key_path
|
|
|
|
|
|
def start_server(
    port: Optional[int] = None,
    host: Optional[str] = None,
    use_ssl: bool = False,
    ssl_cert: Optional[str] = None,
    ssl_key: Optional[str] = None,
    debug: bool = False,
):
    """Run the REST API with uvicorn, optionally over HTTPS.

    Args:
        port: Port to bind; falls back to REST_DEFAULT_PORT when not given.
        host: Host to bind; falls back to "localhost" when not given.
        use_ssl: Serve over HTTPS when True.
        ssl_cert: Path to the certificate file. When `use_ssl` is set and this
            is None, a self-signed certificate/key pair is generated (and
            `ssl_key` is ignored).
        ssl_key: Path to the private key file; required when `ssl_cert` is given.
        debug: When True, set the MemGPT server logger to DEBUG and attach a
            stream handler to it.

    Raises:
        ValueError: If SSL is requested with a certificate but no key path.
        FileNotFoundError: If the certificate or key file does not exist.
    """
    if debug:
        from memgpt.server.server import logger as server_logger

        # Emit DEBUG-level server logs through a formatted stream handler.
        server_logger.setLevel(logging.DEBUG)
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
        server_logger.addHandler(stream_handler)

    bind_host = host or "localhost"
    bind_port = port or REST_DEFAULT_PORT

    if not use_ssl:
        # Plain HTTP: uvicorn runs in this process and blocks until it exits.
        print(f"Running: uvicorn server:app --host {bind_host} --port {bind_port}")
        uvicorn.run(app, host=bind_host, port=bind_port)
        return

    if ssl_cert is None:  # No certificate path provided, generate a self-signed certificate
        ssl_certfile, ssl_keyfile = generate_self_signed_cert()
        print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}")
    else:
        # The caller must supply the key path alongside the certificate path.
        ssl_certfile, ssl_keyfile = ssl_cert, ssl_key
        print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}")

    # Validate with real exceptions: `assert` is stripped under `python -O`,
    # which would let a missing path reach uvicorn unchecked.
    for description, path in (("certificate", ssl_certfile), ("key", ssl_keyfile)):
        if not isinstance(path, str):
            raise ValueError(f"SSL {description} path must be a string, got {path!r}")
        if not os.path.exists(path):
            raise FileNotFoundError(f"SSL {description} file not found: {path}")

    # This will start the server on HTTPS
    print(
        f"Running: uvicorn server:app --host {bind_host} --port {bind_port} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}"
    )
    uvicorn.run(
        app,
        host=bind_host,
        port=bind_port,
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
    )
|