feat: Add lazy initialization to db.py (#1716)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
from rich.console import Console
|
||||
@@ -10,13 +11,17 @@ from sqlalchemy.orm import sessionmaker
|
||||
from letta.config import LettaConfig
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Base
|
||||
|
||||
# NOTE: hack to see if single session management works
|
||||
from letta.settings import settings
|
||||
|
||||
config = LettaConfig.load()
|
||||
# Use globals for the lock and initialization flag
|
||||
_engine_lock = threading.Lock()
|
||||
_engine_initialized = False
|
||||
|
||||
# Create variables in global scope but don't initialize them yet
|
||||
config = LettaConfig.load()
|
||||
logger = get_logger(__name__)
|
||||
engine = None
|
||||
SessionLocal = None
|
||||
|
||||
|
||||
def print_sqlite_schema_error():
|
||||
@@ -49,59 +54,80 @@ def db_error_handler():
|
||||
exit(1)
|
||||
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
print("Creating postgres engine")
|
||||
config.recall_storage_type = "postgres"
|
||||
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
||||
config.archival_storage_type = "postgres"
|
||||
config.archival_storage_uri = settings.letta_pg_uri_no_default
|
||||
def initialize_engine():
|
||||
"""Initialize the database engine only when needed."""
|
||||
global engine, SessionLocal, _engine_initialized
|
||||
|
||||
# create engine
|
||||
engine = create_engine(
|
||||
settings.letta_pg_uri,
|
||||
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
|
||||
pool_size=settings.pg_pool_size,
|
||||
max_overflow=settings.pg_max_overflow,
|
||||
pool_timeout=settings.pg_pool_timeout,
|
||||
pool_recycle=settings.pg_pool_recycle,
|
||||
echo=settings.pg_echo,
|
||||
# connect_args={"client_encoding": "utf8"},
|
||||
)
|
||||
else:
|
||||
# TODO: don't rely on config storage
|
||||
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
|
||||
logger.info("Creating sqlite engine " + engine_path)
|
||||
with _engine_lock:
|
||||
# Check again inside the lock to prevent race conditions
|
||||
if _engine_initialized:
|
||||
return
|
||||
|
||||
engine = create_engine(engine_path)
|
||||
if settings.letta_pg_uri_no_default:
|
||||
logger.info("Creating postgres engine")
|
||||
config.recall_storage_type = "postgres"
|
||||
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
||||
config.archival_storage_type = "postgres"
|
||||
config.archival_storage_uri = settings.letta_pg_uri_no_default
|
||||
|
||||
# Store the original connect method
|
||||
original_connect = engine.connect
|
||||
# create engine
|
||||
engine = create_engine(
|
||||
settings.letta_pg_uri,
|
||||
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
|
||||
pool_size=settings.pg_pool_size,
|
||||
max_overflow=settings.pg_max_overflow,
|
||||
pool_timeout=settings.pg_pool_timeout,
|
||||
pool_recycle=settings.pg_pool_recycle,
|
||||
echo=settings.pg_echo,
|
||||
# connect_args={"client_encoding": "utf8"},
|
||||
)
|
||||
else:
|
||||
# TODO: don't rely on config storage
|
||||
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
|
||||
logger.info("Creating sqlite engine " + engine_path)
|
||||
|
||||
def wrapped_connect(*args, **kwargs):
|
||||
with db_error_handler():
|
||||
# Get the connection
|
||||
connection = original_connect(*args, **kwargs)
|
||||
engine = create_engine(engine_path)
|
||||
|
||||
# Store the original execution method
|
||||
original_execute = connection.execute
|
||||
# Store the original connect method
|
||||
original_connect = engine.connect
|
||||
|
||||
# Wrap the execute method of the connection
|
||||
def wrapped_execute(*args, **kwargs):
|
||||
def wrapped_connect(*args, **kwargs):
|
||||
with db_error_handler():
|
||||
return original_execute(*args, **kwargs)
|
||||
# Get the connection
|
||||
connection = original_connect(*args, **kwargs)
|
||||
|
||||
# Replace the connection's execute method
|
||||
connection.execute = wrapped_execute
|
||||
# Store the original execution method
|
||||
original_execute = connection.execute
|
||||
|
||||
return connection
|
||||
# Wrap the execute method of the connection
|
||||
def wrapped_execute(*args, **kwargs):
|
||||
with db_error_handler():
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the engine's connect method
|
||||
engine.connect = wrapped_connect
|
||||
# Replace the connection's execute method
|
||||
connection.execute = wrapped_execute
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return connection
|
||||
|
||||
# Replace the engine's connect method
|
||||
engine.connect = wrapped_connect
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# Create the session factory
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
_engine_initialized = True
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Get a database session, initializing the engine if needed."""
|
||||
global engine, SessionLocal
|
||||
|
||||
# Make sure engine is initialized
|
||||
if not _engine_initialized:
|
||||
initialize_engine()
|
||||
|
||||
# Now SessionLocal should be defined and callable
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
@@ -109,5 +135,5 @@ def get_db():
|
||||
db.close()
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
# Define db_context as a context manager that uses get_db
|
||||
db_context = contextmanager(get_db)
|
||||
|
||||
Reference in New Issue
Block a user