feat: move from server.py to app.py (#1740)

This commit is contained in:
Charles Packer
2024-09-10 13:31:03 -07:00
committed by GitHub
parent 60d5899a8c
commit a3112662e2
7 changed files with 178 additions and 304 deletions

View File

@@ -1,10 +1,8 @@
import json
import logging
import os
import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional
import questionary
@@ -24,7 +22,6 @@ from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.enums import OptionState
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import ChatMemory, Memory
from memgpt.server.constants import WS_DEFAULT_PORT
from memgpt.server.server import logger as server_logger
# from memgpt.interface import CLIInterface as interface # for printing to terminal
@@ -304,9 +301,6 @@ def server(
type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest",
port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None,
host: Annotated[Optional[str], typer.Option(help="Host to run the server on (default to localhost)")] = None,
use_ssl: Annotated[bool, typer.Option(help="Run the server using HTTPS?")] = False,
ssl_cert: Annotated[Optional[str], typer.Option(help="Path to SSL certificate (if use_ssl is True)")] = None,
ssl_key: Annotated[Optional[str], typer.Option(help="Path to SSL key file (if use_ssl is True)")] = None,
debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False,
):
"""Launch a MemGPT server process"""
@@ -317,22 +311,15 @@ def server(
if MemGPTConfig.exists():
config = MemGPTConfig.load()
MetadataStore(config)
client = create_client() # triggers user creation
_ = create_client() # triggers user creation
else:
typer.secho(f"No configuration exists. Run memgpt configure before starting the server.", fg=typer.colors.RED)
sys.exit(1)
try:
from memgpt.server.rest_api.server import start_server
from memgpt.server.rest_api.app import start_server
start_server(
port=port,
host=host,
use_ssl=use_ssl,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
debug=debug,
)
start_server(port=port, host=host, debug=debug)
except KeyboardInterrupt:
# Handle CTRL-C
@@ -340,48 +327,7 @@ def server(
sys.exit(0)
elif type == ServerChoice.ws_api:
if debug:
from memgpt.server.server import logger as server_logger
# Set the logging level
server_logger.setLevel(logging.DEBUG)
# Create a StreamHandler
stream_handler = logging.StreamHandler()
# Set the formatter (optional)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
# Add the handler to the logger
server_logger.addHandler(stream_handler)
if port is None:
port = WS_DEFAULT_PORT
# Change to the desired directory
script_path = Path(__file__).resolve()
script_dir = script_path.parent
server_directory = os.path.join(script_dir.parent, "server", "ws_api")
command = f"python server.py {port}"
# Run the command
typer.secho(f"Running WS (websockets) server: {command} (inside {server_directory})")
process = None
try:
# Start the subprocess in a new session
process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory)
process.wait()
except KeyboardInterrupt:
# Handle CTRL-C
if process is not None:
typer.secho("Terminating the server...")
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
typer.secho("Server terminated with kill()")
sys.exit(0)
raise NotImplementedError("WS suppport deprecated")
def run(

View File

@@ -0,0 +1,169 @@
import json
import logging
import secrets
from pathlib import Path
from typing import Optional
import typer
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from memgpt.server.constants import REST_DEFAULT_PORT
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from memgpt.server.rest_api.auth.index import (
setup_auth_router, # TODO: probably remove right?
)
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.routers.openai.assistants.assistants import (
router as openai_assistants_router,
)
from memgpt.server.rest_api.routers.openai.assistants.threads import (
router as openai_threads_router,
)
from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import (
router as openai_chat_completions_router,
)
# from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.server import SyncServer
from memgpt.settings import settings
# TODO(ethan)
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
interface: StreamingServerInterface = StreamingServerInterface
server: SyncServer = SyncServer(default_interface_factory=lambda: interface())
# TODO(ethan): eventuall remove
if password := settings.server_pass:
# if the pass was specified in the environment, use it
print(f"Using existing admin server password from environment.")
else:
# Autogenerate a password for this session and dump it to stdout
password = secrets.token_urlsafe(16)
typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN)
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"
def create_application() -> "FastAPI":
"""the application start routine"""
app = FastAPI(
swagger_ui_parameters={"docExpansion": "none"},
# openapi_tags=TAGS_METADATA,
title="MemGPT",
summary="Create LLM agents with long-term memory and custom tools 📚🦙",
version="1.0.0", # TODO wire this up to the version in the package
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
# this gives undocumented routes for "latest" and bare api calls.
# we should always tie this to the newest version of the api.
app.include_router(route, prefix="", include_in_schema=False)
app.include_router(route, prefix="/latest", include_in_schema=False)
# NOTE: ethan these are the extra routes
# TODO(ethan) remove
# admin/users
app.include_router(users_router, prefix=ADMIN_PREFIX)
# openai
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
# /api/auth endpoints
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)
# / static files
mount_static_files(app)
@app.on_event("startup")
def on_startup():
# load the default tools
# from memgpt.orm.tool import Tool
# Tool.load_default_tools(get_db_session())
# Update the OpenAPI schema
if not app.openapi_schema:
app.openapi_schema = app.openapi()
openai_docs, memgpt_docs = [app.openapi_schema.copy() for _ in range(2)]
openai_docs["paths"] = {k: v for k, v in openai_docs["paths"].items() if k.startswith("/openai")}
openai_docs["info"]["title"] = "OpenAI Assistants API"
memgpt_docs["paths"] = {k: v for k, v in memgpt_docs["paths"].items() if not k.startswith("/openai")}
memgpt_docs["info"]["title"] = "MemGPT API"
# Split the API docs into MemGPT API, and OpenAI Assistants compatible API
for name, docs in [
(
"openai",
openai_docs,
),
(
"memgpt",
memgpt_docs,
),
]:
if settings.cors_origins:
docs["servers"] = [{"url": host} for host in settings.cors_origins]
Path(f"openapi_{name}.json").write_text(json.dumps(docs, indent=2))
@app.on_event("shutdown")
def on_shutdown():
global server
server.save_agents()
server = None
return app
app = create_application()
def start_server(
port: Optional[int] = None,
host: Optional[str] = None,
debug: bool = False,
):
"""Convenience method to start the server from within Python"""
if debug:
from memgpt.server.server import logger as server_logger
# Set the logging level
server_logger.setLevel(logging.DEBUG)
# Create a StreamHandler
stream_handler = logging.StreamHandler()
# Set the formatter (optional)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
# Add the handler to the logger
server_logger.addHandler(stream_handler)
print(f"Running: uvicorn server:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}")
uvicorn.run(
app,
host=host or "localhost",
port=port or REST_DEFAULT_PORT,
)

View File

@@ -1,241 +0,0 @@
import json
import logging
import os
import secrets
import subprocess
from typing import Optional
import typer
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.middleware.cors import CORSMiddleware
from memgpt.server.constants import REST_DEFAULT_PORT
from memgpt.server.rest_api.auth.index import (
setup_auth_router, # TODO: probably remove right?
)
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.routers.openai.assistants.assistants import (
router as openai_assistants_router,
)
from memgpt.server.rest_api.routers.openai.assistants.threads import (
router as openai_threads_router,
)
from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import (
router as openai_chat_completions_router,
)
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.server import SyncServer
from memgpt.settings import settings
"""
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
Start the server with:
cd memgpt/server/rest_api
poetry run uvicorn server:app --reload
"""
interface: StreamingServerInterface = StreamingServerInterface
server: SyncServer = SyncServer(default_interface_factory=lambda: interface())
if password := settings.server_pass:
# if the pass was specified in the environment, use it
print(f"Using existing admin server password from environment.")
else:
# Autogenerate a password for this session and dump it to stdout
password = secrets.token_urlsafe(16)
typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN)
security = HTTPBearer()
def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""REST requests going to /admin are protected with a bearer token (that must match the password)"""
if credentials.credentials != password:
raise HTTPException(status_code=401, detail="Unauthorized")
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# v1_routes are the MemGPT API routes
for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
# this gives undocumented routes for "latest" and bare api calls.
# we should always tie this to the newest version of the api.
app.include_router(route, prefix="", include_in_schema=False)
app.include_router(route, prefix="/latest", include_in_schema=False)
# admin/users
app.include_router(users_router, prefix=ADMIN_PREFIX)
# openai
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
# /api/auth endpoints
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)
# # Serve static files
# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
# app.mount("/assets", StaticFiles(directory=os.path.join(static_files_path, "assets")), name="static")
# # Serve favicon
# @app.get("/favicon.ico")
# async def favicon():
# return FileResponse(os.path.join(static_files_path, "favicon.ico"))
# # Middleware to handle API routes first
# @app.middleware("http")
# async def handle_api_routes(request: Request, call_next):
# if request.url.path.startswith(("/v1/", "/openai/")):
# response = await call_next(request)
# if response.status_code != 404:
# return response
# return await serve_spa(request.url.path)
# # Catch-all route for SPA
# async def serve_spa(full_path: str):
# return FileResponse(os.path.join(static_files_path, "index.html"))
mount_static_files(app)
@app.on_event("startup")
def on_startup():
# Update the OpenAPI schema
if not app.openapi_schema:
app.openapi_schema = app.openapi()
if app.openapi_schema:
app.openapi_schema["servers"] = [{"url": host} for host in settings.cors_origins]
app.openapi_schema["info"]["title"] = "MemGPT API"
# Split the API docs into MemGPT API, and OpenAI Assistants compatible API
memgpt_api = app.openapi_schema.copy()
memgpt_api["paths"] = {key: value for key, value in memgpt_api["paths"].items() if not key.startswith(OPENAI_API_PREFIX)}
memgpt_api["info"]["title"] = "MemGPT API"
with open("openapi_memgpt.json", "w", encoding="utf-8") as file:
print(f"Writing out openapi_memgpt.json file")
json.dump(memgpt_api, file, indent=2)
openai_assistants_api = app.openapi_schema.copy()
openai_assistants_api["paths"] = {
key: value
for key, value in openai_assistants_api["paths"].items()
if not (key.startswith(API_PREFIX) or key.startswith(ADMIN_PREFIX))
}
openai_assistants_api["info"]["title"] = "OpenAI Assistants API"
with open("openapi_assistants.json", "w", encoding="utf-8") as file:
print(f"Writing out openapi_assistants.json file")
json.dump(openai_assistants_api, file, indent=2)
@app.on_event("shutdown")
def on_shutdown():
global server
if server:
server.save_agents()
server = None
def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"):
"""Generate a self-signed SSL certificate.
NOTE: intended to be used for development only.
"""
subprocess.run(
[
"openssl",
"req",
"-x509",
"-newkey",
"rsa:4096",
"-keyout",
key_path,
"-out",
cert_path,
"-days",
"365",
"-nodes",
"-subj",
"/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost",
],
check=True,
)
return cert_path, key_path
def start_server(
port: Optional[int] = None,
host: Optional[str] = None,
use_ssl: bool = False,
ssl_cert: Optional[str] = None,
ssl_key: Optional[str] = None,
debug: bool = False,
):
print("DEBUG", debug)
if debug:
from memgpt.server.server import logger as server_logger
# Set the logging level
server_logger.setLevel(logging.DEBUG)
# Create a StreamHandler
stream_handler = logging.StreamHandler()
# Set the formatter (optional)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
# Add the handler to the logger
server_logger.addHandler(stream_handler)
if use_ssl:
if ssl_cert is None: # No certificate path provided, generate a self-signed certificate
ssl_certfile, ssl_keyfile = generate_self_signed_cert()
print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}")
else:
ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both
print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}")
# This will start the server on HTTPS
assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile
assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile
print(
f"Running: uvicorn server:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}"
)
uvicorn.run(
app,
host=host or "localhost",
port=port or REST_DEFAULT_PORT,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
)
else:
# Start the subprocess in a new session
print(f"Running: uvicorn server:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}")
uvicorn.run(
app,
host=host or "localhost",
port=port or REST_DEFAULT_PORT,
)

View File

@@ -2,7 +2,7 @@
echo "Starting MEMGPT server..."
if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then
echo "Starting in development mode!"
uvicorn memgpt.server.rest_api.server:app --reload --reload-dir /memgpt --host 0.0.0.0 --port 8083
uvicorn memgpt.server.rest_api.app:app --reload --reload-dir /memgpt --host 0.0.0.0 --port 8083
else
uvicorn memgpt.server.rest_api.server:app --host 0.0.0.0 --port 8083
uvicorn memgpt.server.rest_api.app:app --host 0.0.0.0 --port 8083
fi

View File

@@ -12,7 +12,7 @@ test_server_token = "test_server_token"
def run_server():
from memgpt.server.rest_api.server import start_server
from memgpt.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)

View File

@@ -37,7 +37,7 @@ def run_server():
# _reset_config()
from memgpt.server.rest_api.server import start_server
from memgpt.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)

View File

@@ -32,7 +32,7 @@ def run_server():
# _reset_config()
from memgpt.server.rest_api.server import start_server
from memgpt.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)