From a61a67bc47e2ab100c6c18e3cb3fdb132e37f851 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Fri, 26 Apr 2024 16:00:36 -0700 Subject: [PATCH] chore: better database errors (#1299) --- memgpt/cli/cli.py | 42 ------------------------- memgpt/metadata.py | 53 ++++++++++++++++++++++---------- memgpt/server/rest_api/server.py | 6 ---- 3 files changed, 37 insertions(+), 64 deletions(-) diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index c39c47cd..a9ff31d4 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -310,19 +310,6 @@ def server( ): """Launch a MemGPT server process""" - # if debug: - # from memgpt.server.server import logger as server_logger - - # # Set the logging level - # server_logger.setLevel(logging.DEBUG) - # # Create a StreamHandler - # stream_handler = logging.StreamHandler() - # # Set the formatter (optional) - # formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - # stream_handler.setFormatter(formatter) - # # Add the handler to the logger - # server_logger.addHandler(stream_handler) - if type == ServerChoice.rest_api: pass @@ -345,35 +332,6 @@ def server( ssl_key=ssl_key, debug=debug, ) - # if use_ssl: - # if ssl_cert is None: # No certificate path provided, generate a self-signed certificate - # ssl_certfile, ssl_keyfile = generate_self_signed_cert() - # print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}") - # else: - # ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both - # print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}") - - # # This will start the server on HTTPS - # assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile - # assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile - # print( - # f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}" - # ) - # uvicorn.run( - # app, - # host=host or "localhost", - # port=port or REST_DEFAULT_PORT, - # ssl_keyfile=ssl_keyfile, - # ssl_certfile=ssl_certfile, - # ) - # else: - # # Start the subprocess in a new session - # print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}") - # uvicorn.run( - # app, - # host=host or "localhost", - # port=port or REST_DEFAULT_PORT, - # ) except KeyboardInterrupt: # Handle CTRL-C diff --git a/memgpt/metadata.py b/memgpt/metadata.py index f02f953f..bd5cd97a 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -5,11 +5,13 @@ import os import secrets import uuid from typing import List, Optional +import traceback from sqlalchemy import BIGINT, CHAR, JSON, Boolean, Column, DateTime, String, TypeDecorator, create_engine, func, inspect from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.sql import func +from sqlalchemy.exc import InterfaceError from memgpt.config import MemGPTConfig from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, Preset, Source, Token, User @@ -312,22 +314,41 @@ class MetadataStore: # Check if tables need to be created self.engine = create_engine(self.uri) - Base.metadata.create_all( - self.engine, - tables=[ - UserModel.__table__, - AgentModel.__table__, - SourceModel.__table__, - AgentSourceMappingModel.__table__, - TokenModel.__table__, - PresetModel.__table__, - PresetSourceMapping.__table__, - HumanModel.__table__, - PersonaModel.__table__, - ToolModel.__table__, - JobModel.__table__, - ], - ) + try: + Base.metadata.create_all( + self.engine, + tables=[ + UserModel.__table__, + AgentModel.__table__, + SourceModel.__table__, + AgentSourceMappingModel.__table__, + TokenModel.__table__, + PresetModel.__table__, + PresetSourceMapping.__table__, + HumanModel.__table__, + PersonaModel.__table__, + ToolModel.__table__, + JobModel.__table__, + ], + ) + except InterfaceError as e: + traceback.print_exc() + if config.metadata_storage_type == "postgres": + raise ValueError( + f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. " + + "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). " + + "\npostgres detected: Make sure the postgres database is running (https://memgpt.readme.io/docs/storage#postgres)." + ) + elif config.metadata_storage_type == "sqlite": + raise ValueError( + f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. " + + "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). " + + "\nsqlite detected: Make sure that the sqlite.db file exists at the URI." + ) + else: + raise e + except: + raise self.session_maker = sessionmaker(bind=self.engine) @enforce_types diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index 0d24bd23..4a2ea3af 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -10,7 +10,6 @@ from fastapi import Depends, FastAPI, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from starlette.middleware.cors import CORSMiddleware -from memgpt.config import MemGPTConfig from memgpt.server.constants import REST_DEFAULT_PORT from memgpt.server.rest_api.admin.users import setup_admin_router from memgpt.server.rest_api.agents.command import setup_agents_command_router @@ -115,11 +114,6 @@ def on_startup(): app.openapi_schema["servers"] = [{"url": host} for host in settings.cors_origins] app.openapi_schema["info"]["title"] = "MemGPT API" - # Write out the OpenAPI schema to a file - # with open("openapi.json", "w") as file: - # print(f"Writing out openapi.json file") - # json.dump(app.openapi_schema, file, indent=2) - # Split the API docs into MemGPT API, and OpenAI Assistants compatible API memgpt_api = app.openapi_schema.copy() memgpt_api["paths"] = {key: value for key, value in memgpt_api["paths"].items() if not key.startswith(OPENAI_API_PREFIX)}