diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 2847aa3b..81602119 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -1,10 +1,8 @@ import json import logging import os -import subprocess import sys from enum import Enum -from pathlib import Path from typing import Annotated, Optional import questionary @@ -24,7 +22,6 @@ from memgpt.schemas.embedding_config import EmbeddingConfig from memgpt.schemas.enums import OptionState from memgpt.schemas.llm_config import LLMConfig from memgpt.schemas.memory import ChatMemory, Memory -from memgpt.server.constants import WS_DEFAULT_PORT from memgpt.server.server import logger as server_logger # from memgpt.interface import CLIInterface as interface # for printing to terminal @@ -304,9 +301,6 @@ def server( type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest", port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None, host: Annotated[Optional[str], typer.Option(help="Host to run the server on (default to localhost)")] = None, - use_ssl: Annotated[bool, typer.Option(help="Run the server using HTTPS?")] = False, - ssl_cert: Annotated[Optional[str], typer.Option(help="Path to SSL certificate (if use_ssl is True)")] = None, - ssl_key: Annotated[Optional[str], typer.Option(help="Path to SSL key file (if use_ssl is True)")] = None, debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False, ): """Launch a MemGPT server process""" @@ -317,22 +311,15 @@ def server( if MemGPTConfig.exists(): config = MemGPTConfig.load() MetadataStore(config) - client = create_client() # triggers user creation + _ = create_client() # triggers user creation else: typer.secho(f"No configuration exists. Run memgpt configure before starting the server.", fg=typer.colors.RED) sys.exit(1) try: - from memgpt.server.rest_api.server import start_server + from memgpt.server.rest_api.app import start_server - start_server( - port=port, - host=host, - use_ssl=use_ssl, - ssl_cert=ssl_cert, - ssl_key=ssl_key, - debug=debug, - ) + start_server(port=port, host=host, debug=debug) except KeyboardInterrupt: # Handle CTRL-C @@ -340,48 +327,7 @@ def server( sys.exit(0) elif type == ServerChoice.ws_api: - if debug: - from memgpt.server.server import logger as server_logger - - # Set the logging level - server_logger.setLevel(logging.DEBUG) - # Create a StreamHandler - stream_handler = logging.StreamHandler() - # Set the formatter (optional) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - stream_handler.setFormatter(formatter) - # Add the handler to the logger - server_logger.addHandler(stream_handler) - - if port is None: - port = WS_DEFAULT_PORT - - # Change to the desired directory - script_path = Path(__file__).resolve() - script_dir = script_path.parent - - server_directory = os.path.join(script_dir.parent, "server", "ws_api") - command = f"python server.py {port}" - - # Run the command - typer.secho(f"Running WS (websockets) server: {command} (inside {server_directory})") - - process = None - try: - # Start the subprocess in a new session - process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory) - process.wait() - except KeyboardInterrupt: - # Handle CTRL-C - if process is not None: - typer.secho("Terminating the server...") - process.terminate() - try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - process.kill() - typer.secho("Server terminated with kill()") - sys.exit(0) + raise NotImplementedError("WS suppport deprecated") def run( diff --git a/memgpt/server/rest_api/app.py b/memgpt/server/rest_api/app.py new file mode 100644 index 00000000..2f32397e --- /dev/null +++ b/memgpt/server/rest_api/app.py @@ -0,0 +1,169 @@ +import json +import logging +import secrets +from pathlib import Path +from typing import Optional + +import typer +import uvicorn +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware + +from memgpt.server.constants import REST_DEFAULT_PORT + +# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests +from memgpt.server.rest_api.auth.index import ( + setup_auth_router, # TODO: probably remove right? +) +from memgpt.server.rest_api.interface import StreamingServerInterface +from memgpt.server.rest_api.routers.openai.assistants.assistants import ( + router as openai_assistants_router, +) +from memgpt.server.rest_api.routers.openai.assistants.threads import ( + router as openai_threads_router, +) +from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import ( + router as openai_chat_completions_router, +) + +# from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM +from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes +from memgpt.server.rest_api.routers.v1.users import ( + router as users_router, # TODO: decide on admin +) +from memgpt.server.rest_api.static_files import mount_static_files +from memgpt.server.server import SyncServer +from memgpt.settings import settings + +# TODO(ethan) +# NOTE(charles): @ethan I had to add this to get the global as the bottom to work +interface: StreamingServerInterface = StreamingServerInterface +server: SyncServer = SyncServer(default_interface_factory=lambda: interface()) + +# TODO(ethan): eventuall remove +if password := settings.server_pass: + # if the pass was specified in the environment, use it + print(f"Using existing admin server password from environment.") +else: + # Autogenerate a password for this session and dump it to stdout + password = secrets.token_urlsafe(16) + typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN) + + +ADMIN_PREFIX = "/v1/admin" +API_PREFIX = "/v1" +OPENAI_API_PREFIX = "/openai" + + +def create_application() -> "FastAPI": + """the application start routine""" + + app = FastAPI( + swagger_ui_parameters={"docExpansion": "none"}, + # openapi_tags=TAGS_METADATA, + title="MemGPT", + summary="Create LLM agents with long-term memory and custom tools 📚🦙", + version="1.0.0", # TODO wire this up to the version in the package + ) + app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + for route in v1_routes: + app.include_router(route, prefix=API_PREFIX) + # this gives undocumented routes for "latest" and bare api calls. + # we should always tie this to the newest version of the api. + app.include_router(route, prefix="", include_in_schema=False) + app.include_router(route, prefix="/latest", include_in_schema=False) + + # NOTE: ethan these are the extra routes + # TODO(ethan) remove + + # admin/users + app.include_router(users_router, prefix=ADMIN_PREFIX) + + # openai + app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX) + app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX) + app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX) + + # /api/auth endpoints + app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX) + + # / static files + mount_static_files(app) + + @app.on_event("startup") + def on_startup(): + # load the default tools + # from memgpt.orm.tool import Tool + + # Tool.load_default_tools(get_db_session()) + + # Update the OpenAPI schema + if not app.openapi_schema: + app.openapi_schema = app.openapi() + + openai_docs, memgpt_docs = [app.openapi_schema.copy() for _ in range(2)] + + openai_docs["paths"] = {k: v for k, v in openai_docs["paths"].items() if k.startswith("/openai")} + openai_docs["info"]["title"] = "OpenAI Assistants API" + memgpt_docs["paths"] = {k: v for k, v in memgpt_docs["paths"].items() if not k.startswith("/openai")} + memgpt_docs["info"]["title"] = "MemGPT API" + + # Split the API docs into MemGPT API, and OpenAI Assistants compatible API + for name, docs in [ + ( + "openai", + openai_docs, + ), + ( + "memgpt", + memgpt_docs, + ), + ]: + if settings.cors_origins: + docs["servers"] = [{"url": host} for host in settings.cors_origins] + Path(f"openapi_{name}.json").write_text(json.dumps(docs, indent=2)) + + @app.on_event("shutdown") + def on_shutdown(): + global server + server.save_agents() + server = None + + return app + + +app = create_application() + + +def start_server( + port: Optional[int] = None, + host: Optional[str] = None, + debug: bool = False, +): + """Convenience method to start the server from within Python""" + if debug: + from memgpt.server.server import logger as server_logger + + # Set the logging level + server_logger.setLevel(logging.DEBUG) + # Create a StreamHandler + stream_handler = logging.StreamHandler() + # Set the formatter (optional) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + stream_handler.setFormatter(formatter) + # Add the handler to the logger + server_logger.addHandler(stream_handler) + + print(f"Running: uvicorn server:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}") + uvicorn.run( + app, + host=host or "localhost", + port=port or REST_DEFAULT_PORT, + ) diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py deleted file mode 100644 index 42ae5732..00000000 --- a/memgpt/server/rest_api/server.py +++ /dev/null @@ -1,241 +0,0 @@ -import json -import logging -import os -import secrets -import subprocess -from typing import Optional - -import typer -import uvicorn -from fastapi import Depends, FastAPI, HTTPException -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from starlette.middleware.cors import CORSMiddleware - -from memgpt.server.constants import REST_DEFAULT_PORT -from memgpt.server.rest_api.auth.index import ( - setup_auth_router, # TODO: probably remove right? -) -from memgpt.server.rest_api.interface import StreamingServerInterface -from memgpt.server.rest_api.routers.openai.assistants.assistants import ( - router as openai_assistants_router, -) -from memgpt.server.rest_api.routers.openai.assistants.threads import ( - router as openai_threads_router, -) -from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import ( - router as openai_chat_completions_router, -) -from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes -from memgpt.server.rest_api.routers.v1.users import ( - router as users_router, # TODO: decide on admin -) -from memgpt.server.rest_api.static_files import mount_static_files -from memgpt.server.server import SyncServer -from memgpt.settings import settings - -""" -Basic REST API sitting on top of the internal MemGPT python server (SyncServer) - -Start the server with: - cd memgpt/server/rest_api - poetry run uvicorn server:app --reload -""" - -interface: StreamingServerInterface = StreamingServerInterface -server: SyncServer = SyncServer(default_interface_factory=lambda: interface()) - -if password := settings.server_pass: - # if the pass was specified in the environment, use it - print(f"Using existing admin server password from environment.") -else: - # Autogenerate a password for this session and dump it to stdout - password = secrets.token_urlsafe(16) - typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN) - -security = HTTPBearer() - - -def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security)): - """REST requests going to /admin are protected with a bearer token (that must match the password)""" - if credentials.credentials != password: - raise HTTPException(status_code=401, detail="Unauthorized") - - -ADMIN_PREFIX = "/v1/admin" -API_PREFIX = "/v1" -OPENAI_API_PREFIX = "/openai" - -app = FastAPI() - -app.add_middleware( - CORSMiddleware, - allow_origins=settings.cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# v1_routes are the MemGPT API routes -for route in v1_routes: - app.include_router(route, prefix=API_PREFIX) - # this gives undocumented routes for "latest" and bare api calls. - # we should always tie this to the newest version of the api. - app.include_router(route, prefix="", include_in_schema=False) - app.include_router(route, prefix="/latest", include_in_schema=False) - -# admin/users -app.include_router(users_router, prefix=ADMIN_PREFIX) - -# openai -app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX) -app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX) -app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX) - -# /api/auth endpoints -app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX) - -# # Serve static files -# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files") -# app.mount("/assets", StaticFiles(directory=os.path.join(static_files_path, "assets")), name="static") - - -# # Serve favicon -# @app.get("/favicon.ico") -# async def favicon(): -# return FileResponse(os.path.join(static_files_path, "favicon.ico")) - - -# # Middleware to handle API routes first -# @app.middleware("http") -# async def handle_api_routes(request: Request, call_next): -# if request.url.path.startswith(("/v1/", "/openai/")): -# response = await call_next(request) -# if response.status_code != 404: -# return response -# return await serve_spa(request.url.path) - - -# # Catch-all route for SPA -# async def serve_spa(full_path: str): -# return FileResponse(os.path.join(static_files_path, "index.html")) - - -mount_static_files(app) - - -@app.on_event("startup") -def on_startup(): - # Update the OpenAPI schema - if not app.openapi_schema: - app.openapi_schema = app.openapi() - - if app.openapi_schema: - app.openapi_schema["servers"] = [{"url": host} for host in settings.cors_origins] - app.openapi_schema["info"]["title"] = "MemGPT API" - - # Split the API docs into MemGPT API, and OpenAI Assistants compatible API - memgpt_api = app.openapi_schema.copy() - memgpt_api["paths"] = {key: value for key, value in memgpt_api["paths"].items() if not key.startswith(OPENAI_API_PREFIX)} - memgpt_api["info"]["title"] = "MemGPT API" - with open("openapi_memgpt.json", "w", encoding="utf-8") as file: - print(f"Writing out openapi_memgpt.json file") - json.dump(memgpt_api, file, indent=2) - - openai_assistants_api = app.openapi_schema.copy() - openai_assistants_api["paths"] = { - key: value - for key, value in openai_assistants_api["paths"].items() - if not (key.startswith(API_PREFIX) or key.startswith(ADMIN_PREFIX)) - } - openai_assistants_api["info"]["title"] = "OpenAI Assistants API" - with open("openapi_assistants.json", "w", encoding="utf-8") as file: - print(f"Writing out openapi_assistants.json file") - json.dump(openai_assistants_api, file, indent=2) - - -@app.on_event("shutdown") -def on_shutdown(): - global server - if server: - server.save_agents() - server = None - - -def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"): - """Generate a self-signed SSL certificate. - - NOTE: intended to be used for development only. - """ - subprocess.run( - [ - "openssl", - "req", - "-x509", - "-newkey", - "rsa:4096", - "-keyout", - key_path, - "-out", - cert_path, - "-days", - "365", - "-nodes", - "-subj", - "/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost", - ], - check=True, - ) - return cert_path, key_path - - -def start_server( - port: Optional[int] = None, - host: Optional[str] = None, - use_ssl: bool = False, - ssl_cert: Optional[str] = None, - ssl_key: Optional[str] = None, - debug: bool = False, -): - print("DEBUG", debug) - if debug: - from memgpt.server.server import logger as server_logger - - # Set the logging level - server_logger.setLevel(logging.DEBUG) - # Create a StreamHandler - stream_handler = logging.StreamHandler() - # Set the formatter (optional) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - stream_handler.setFormatter(formatter) - # Add the handler to the logger - server_logger.addHandler(stream_handler) - - if use_ssl: - if ssl_cert is None: # No certificate path provided, generate a self-signed certificate - ssl_certfile, ssl_keyfile = generate_self_signed_cert() - print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}") - else: - ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both - print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}") - - # This will start the server on HTTPS - assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile - assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile - print( - f"Running: uvicorn server:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}" - ) - uvicorn.run( - app, - host=host or "localhost", - port=port or REST_DEFAULT_PORT, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - ) - else: - # Start the subprocess in a new session - print(f"Running: uvicorn server:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}") - uvicorn.run( - app, - host=host or "localhost", - port=port or REST_DEFAULT_PORT, - ) diff --git a/memgpt/server/startup.sh b/memgpt/server/startup.sh index f91c669a..ed46be79 100755 --- a/memgpt/server/startup.sh +++ b/memgpt/server/startup.sh @@ -2,7 +2,7 @@ echo "Starting MEMGPT server..." if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then echo "Starting in development mode!" - uvicorn memgpt.server.rest_api.server:app --reload --reload-dir /memgpt --host 0.0.0.0 --port 8083 + uvicorn memgpt.server.rest_api.app:app --reload --reload-dir /memgpt --host 0.0.0.0 --port 8083 else - uvicorn memgpt.server.rest_api.server:app --host 0.0.0.0 --port 8083 + uvicorn memgpt.server.rest_api.app:app --host 0.0.0.0 --port 8083 fi diff --git a/tests/test_admin_client.py b/tests/test_admin_client.py index 39669c12..8df80f61 100644 --- a/tests/test_admin_client.py +++ b/tests/test_admin_client.py @@ -12,7 +12,7 @@ test_server_token = "test_server_token" def run_server(): - from memgpt.server.rest_api.server import start_server + from memgpt.server.rest_api.app import start_server print("Starting server...") start_server(debug=True) diff --git a/tests/test_client.py b/tests/test_client.py index 6111149f..ff47d4ac 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -37,7 +37,7 @@ def run_server(): # _reset_config() - from memgpt.server.rest_api.server import start_server + from memgpt.server.rest_api.app import start_server print("Starting server...") start_server(debug=True) diff --git a/tests/test_tools.py b/tests/test_tools.py index 4363f22f..71e76f34 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -32,7 +32,7 @@ def run_server(): # _reset_config() - from memgpt.server.rest_api.server import start_server + from memgpt.server.rest_api.app import start_server print("Starting server...") start_server(debug=True)