diff --git a/letta/server/db.py b/letta/server/db.py index 357de19c..87d87300 100644 --- a/letta/server/db.py +++ b/letta/server/db.py @@ -1,4 +1,5 @@ import os +import threading from contextlib import contextmanager from rich.console import Console @@ -10,13 +11,17 @@ from sqlalchemy.orm import sessionmaker from letta.config import LettaConfig from letta.log import get_logger from letta.orm import Base - -# NOTE: hack to see if single session management works from letta.settings import settings -config = LettaConfig.load() +# Use globals for the lock and initialization flag +_engine_lock = threading.Lock() +_engine_initialized = False +# Create variables in global scope but don't initialize them yet +config = LettaConfig.load() logger = get_logger(__name__) +engine = None +SessionLocal = None def print_sqlite_schema_error(): @@ -49,59 +54,80 @@ def db_error_handler(): exit(1) -if settings.letta_pg_uri_no_default: - print("Creating postgres engine") - config.recall_storage_type = "postgres" - config.recall_storage_uri = settings.letta_pg_uri_no_default - config.archival_storage_type = "postgres" - config.archival_storage_uri = settings.letta_pg_uri_no_default +def initialize_engine(): + """Initialize the database engine only when needed.""" + global engine, SessionLocal, _engine_initialized - # create engine - engine = create_engine( - settings.letta_pg_uri, - # f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8", - pool_size=settings.pg_pool_size, - max_overflow=settings.pg_max_overflow, - pool_timeout=settings.pg_pool_timeout, - pool_recycle=settings.pg_pool_recycle, - echo=settings.pg_echo, - # connect_args={"client_encoding": "utf8"}, - ) -else: - # TODO: don't rely on config storage - engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db") - logger.info("Creating sqlite engine " + engine_path) + with _engine_lock: + # Check again inside the lock to prevent race conditions + if _engine_initialized: + return - engine = create_engine(engine_path) + if settings.letta_pg_uri_no_default: + logger.info("Creating postgres engine") + config.recall_storage_type = "postgres" + config.recall_storage_uri = settings.letta_pg_uri_no_default + config.archival_storage_type = "postgres" + config.archival_storage_uri = settings.letta_pg_uri_no_default - # Store the original connect method - original_connect = engine.connect + # create engine + engine = create_engine( + settings.letta_pg_uri, + # f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8", + pool_size=settings.pg_pool_size, + max_overflow=settings.pg_max_overflow, + pool_timeout=settings.pg_pool_timeout, + pool_recycle=settings.pg_pool_recycle, + echo=settings.pg_echo, + # connect_args={"client_encoding": "utf8"}, + ) + else: + # TODO: don't rely on config storage + engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db") + logger.info("Creating sqlite engine " + engine_path) - def wrapped_connect(*args, **kwargs): - with db_error_handler(): - # Get the connection - connection = original_connect(*args, **kwargs) + engine = create_engine(engine_path) - # Store the original execution method - original_execute = connection.execute + # Store the original connect method + original_connect = engine.connect - # Wrap the execute method of the connection - def wrapped_execute(*args, **kwargs): + def wrapped_connect(*args, **kwargs): with db_error_handler(): - return original_execute(*args, **kwargs) + # Get the connection + connection = original_connect(*args, **kwargs) - # Replace the connection's execute method - connection.execute = wrapped_execute + # Store the original execution method + original_execute = connection.execute - return connection + # Wrap the execute method of the connection + def wrapped_execute(*args, **kwargs): + with db_error_handler(): + return original_execute(*args, **kwargs) - # Replace the engine's connect method - engine.connect = wrapped_connect + # Replace the connection's execute method + connection.execute = wrapped_execute - Base.metadata.create_all(bind=engine) + return connection + + # Replace the engine's connect method + engine.connect = wrapped_connect + + Base.metadata.create_all(bind=engine) + + # Create the session factory + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + _engine_initialized = True def get_db(): + """Get a database session, initializing the engine if needed.""" + global engine, SessionLocal + + # Make sure engine is initialized + if not _engine_initialized: + initialize_engine() + + # Now SessionLocal should be defined and callable db = SessionLocal() try: yield db @@ -109,5 +135,5 @@ def get_db(): db.close() -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +# Define db_context as a context manager that uses get_db db_context = contextmanager(get_db)