feat: Add lazy initialization to db.py (#1716)

This commit is contained in:
Matthew Zhou
2025-04-15 12:01:46 -07:00
committed by GitHub
parent 99b8ebd79c
commit b07bda4c19

View File

@@ -1,4 +1,5 @@
import os
import threading
from contextlib import contextmanager
from rich.console import Console
@@ -10,13 +11,17 @@ from sqlalchemy.orm import sessionmaker
from letta.config import LettaConfig
from letta.log import get_logger
from letta.orm import Base
# NOTE: hack to see if single session management works
from letta.settings import settings
config = LettaConfig.load()
# Module-level lock and flag used by initialize_engine(): the function runs its
# setup inside `with _engine_lock:` and sets the flag to True as its last step.
_engine_lock = threading.Lock()
_engine_initialized = False
# Loaded at import time. initialize_engine() reads config.recall_storage_path
# and, on the postgres path, overwrites the recall/archival storage settings.
config = LettaConfig.load()
logger = get_logger(__name__)
# Placeholders only: both names stay None until initialize_engine() assigns them.
engine = None
SessionLocal = None
def print_sqlite_schema_error():
@@ -49,59 +54,80 @@ def db_error_handler():
exit(1)
if settings.letta_pg_uri_no_default:
print("Creating postgres engine")
config.recall_storage_type = "postgres"
config.recall_storage_uri = settings.letta_pg_uri_no_default
config.archival_storage_type = "postgres"
config.archival_storage_uri = settings.letta_pg_uri_no_default
def initialize_engine():
"""Initialize the database engine only when needed."""
# NOTE(review): this listing has lost its indentation and appears to interleave
# the pre-change and post-change lines of a diff (several statements below occur
# twice). Confirm the real nesting, and which copy of each statement is live,
# against the actual file before relying on these comments.
# The three module-level names below are rebound by this function.
global engine, SessionLocal, _engine_initialized
# create engine
engine = create_engine(
settings.letta_pg_uri,
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
# connect_args={"client_encoding": "utf8"},
)
else:
# TODO: don't rely on config storage
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
logger.info("Creating sqlite engine " + engine_path)
# Enter the module-level threading.Lock before touching the shared state.
with _engine_lock:
# Check again inside the lock to prevent race conditions
# (get_db() tests the same flag without holding the lock before calling here.)
if _engine_initialized:
return
engine = create_engine(engine_path)
# Postgres path: taken when settings.letta_pg_uri_no_default is truthy.
if settings.letta_pg_uri_no_default:
logger.info("Creating postgres engine")
# Point both recall and archival storage in the loaded config at postgres.
config.recall_storage_type = "postgres"
config.recall_storage_uri = settings.letta_pg_uri_no_default
config.archival_storage_type = "postgres"
config.archival_storage_uri = settings.letta_pg_uri_no_default
# Store the original connect method
original_connect = engine.connect
# create engine
# Pool sizing, timeout, recycle and echo values are all read from `settings`.
engine = create_engine(
settings.letta_pg_uri,
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
# connect_args={"client_encoding": "utf8"},
)
else:
# Sqlite path: the database file is "sqlite.db" under config.recall_storage_path.
# TODO: don't rely on config storage
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
logger.info("Creating sqlite engine " + engine_path)
def wrapped_connect(*args, **kwargs):
with db_error_handler():
# Get the connection
connection = original_connect(*args, **kwargs)
engine = create_engine(engine_path)
# Store the original execution method
original_execute = connection.execute
# Store the original connect method
# Keep a reference so wrapped_connect can delegate to the real method.
original_connect = engine.connect
# Wrap the execute method of the connection
def wrapped_execute(*args, **kwargs):
# Replacement for engine.connect: opens the connection inside db_error_handler()
# and patches that connection's execute() to run inside db_error_handler() too.
def wrapped_connect(*args, **kwargs):
with db_error_handler():
return original_execute(*args, **kwargs)
# Get the connection
connection = original_connect(*args, **kwargs)
# Replace the connection's execute method
connection.execute = wrapped_execute
# Store the original execution method
original_execute = connection.execute
return connection
# Wrap the execute method of the connection
def wrapped_execute(*args, **kwargs):
with db_error_handler():
return original_execute(*args, **kwargs)
# Replace the engine's connect method
engine.connect = wrapped_connect
# Replace the connection's execute method
connection.execute = wrapped_execute
Base.metadata.create_all(bind=engine)
return connection
# Replace the engine's connect method
engine.connect = wrapped_connect
# Create the tables described by Base.metadata on this engine.
# NOTE(review): presumably nested in the sqlite branch only - the indentation
# needed to confirm that is not visible here.
Base.metadata.create_all(bind=engine)
# Create the session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Set the flag last, after engine and SessionLocal have been assigned.
_engine_initialized = True
def get_db():
"""Get a database session, initializing the engine if needed."""
# Generator function: it yields one session object created from SessionLocal.
# NOTE(review): neither name is assigned in the visible lines of this function,
# so this global statement looks unnecessary - confirm against the full file.
global engine, SessionLocal
# Make sure engine is initialized
# The flag is read here without holding _engine_lock; initialize_engine()
# checks it again after entering the lock.
if not _engine_initialized:
initialize_engine()
# Now SessionLocal should be defined and callable
db = SessionLocal()
try:
yield db
# NOTE(review): a diff hunk boundary follows, so the line(s) between `yield db`
# and `db.close()` are not shown in this view - presumably a `finally:` clause;
# confirm against the real file.
@@ -109,5 +135,5 @@ def get_db():
db.close()
# NOTE(review): this module-level assignment appears to be the pre-change line
# of the diff. With `engine = None` at import time it would bind the factory to
# None, and initialize_engine() now creates SessionLocal itself - confirm this
# line is absent from the real file.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Define db_context as a context manager that uses get_db
# contextlib.contextmanager wraps the get_db generator function so callers can
# write `with db_context() as db:`.
db_context = contextmanager(get_db)