Files
letta-server/letta/server/db.py
Andy Li 80f6e97ca9 feat: otel metrics and expanded collecting (#2647)
(passed tests in last run)
2025-06-05 17:20:14 -07:00

273 lines
10 KiB
Python

import os
import threading
import uuid
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, Generator
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import Engine, NullPool, QueuePool, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from letta.config import LettaConfig
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.settings import settings
# Module-level logger for this file.
logger = get_logger(__name__)
def print_sqlite_schema_error():
    """Render a red-bordered panel telling the user the SQLite schema is invalid.

    The message explains that SQLite databases are not migrated between Letta
    versions. It points to Docker or Postgres (via ``LETTA_PG_URI``) and
    explains how to reset the local ``sqlite.db`` file.
    """
    # (text, rich style) pairs, appended in order to build the message.
    segments = [
        ("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", "bold red"),
        ("To have migrations supported between Letta versions, please run Letta with Docker (", "white"),
        ("https://docs.letta.com/server/docker", "blue underline"),
        (") or use Postgres by setting ", "white"),
        ("LETTA_PG_URI", "yellow"),
        (".\n\n", "white"),
        ("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", "white"),
        ("rm ~/.letta/sqlite.db", "yellow"),
        (" or downgrade to your previous version of Letta.", "white"),
    ]
    message = Text()
    for chunk, chunk_style in segments:
        message.append(chunk, style=chunk_style)
    Console().print(Panel(message, border_style="red"))
@contextmanager
def db_error_handler():
    """Exit the process with a SQLite schema hint if the wrapped block raises.

    Any ``Exception`` raised inside the ``with`` body is printed. The
    SQLite schema-error panel is then shown and the process exits with
    status 1.

    Raises:
        SystemExit: with code 1 whenever the wrapped block raises an ``Exception``.
    """
    try:
        yield
    except Exception as e:
        # NOTE: this catches every Exception (not only SQLAlchemy errors) and
        # reports all of them as a SQLite schema problem.
        print(e)
        print_sqlite_schema_error()
        # The builtin ``exit()`` is injected by the ``site`` module for the
        # interactive shell and is absent under ``python -S``. Raise SystemExit
        # directly instead (what ``sys.exit(1)`` does), chaining the original
        # error as the cause.
        raise SystemExit(1) from e
class DatabaseRegistry:
    """Registry for database connections and sessions.

    This class manages both synchronous and asynchronous database connections
    and provides context managers for session handling. Engines and session
    factories are created lazily on first use and cached by name. Only the
    ``"default"`` entry is populated.

    The backend is chosen from ``settings.letta_pg_uri_no_default``:

    - When it is set, Postgres is used, through the asyncpg driver for the
      async engine.
    - Otherwise a SQLite file ``sqlite.db`` under
      ``config.recall_storage_path`` is used.
    """

    def __init__(self):
        self._engines: dict[str, Engine] = {}
        self._async_engines: dict[str, AsyncEngine] = {}
        self._session_factories: dict[str, sessionmaker] = {}
        self._async_session_factories: dict[str, async_sessionmaker] = {}
        # Whether each flavour ("sync" / "async") has already been initialized.
        self._initialized: dict[str, bool] = {"sync": False, "async": False}
        # Held during initialization so the check-and-create sequence runs atomically.
        self._lock = threading.Lock()
        self.config = LettaConfig.load()
        self.logger = get_logger(__name__)

    def initialize_sync(self, force: bool = False) -> None:
        """Initialize the synchronous database engine if not already initialized.

        Args:
            force: Rebuild the engine and session factory even if they already exist.
        """
        with self._lock:
            if self._initialized.get("sync") and not force:
                return

            if settings.letta_pg_uri_no_default:
                # Postgres engine
                self.logger.info("Creating postgres engine")
                # Mirror the Postgres URI into the config's storage fields.
                self.config.recall_storage_type = "postgres"
                self.config.recall_storage_uri = settings.letta_pg_uri_no_default
                self.config.archival_storage_type = "postgres"
                self.config.archival_storage_uri = settings.letta_pg_uri_no_default

                engine = create_engine(settings.letta_pg_uri, **self._build_sqlalchemy_engine_args(is_async=False))
            else:
                # SQLite engine
                from letta.orm import Base

                # TODO: don't rely on config storage
                engine_path = "sqlite:///" + os.path.join(self.config.recall_storage_path, "sqlite.db")
                self.logger.info("Creating sqlite engine " + engine_path)
                engine = create_engine(engine_path)
                # Wrap the engine with error handling before touching the DB, so
                # failures in create_all below are reported via db_error_handler.
                self._wrap_sqlite_engine(engine)
                # SQLite schema migrations are not supported, so create missing tables directly.
                Base.metadata.create_all(bind=engine)

            self._engines["default"] = engine
            self._session_factories["default"] = sessionmaker(autocommit=False, autoflush=False, bind=engine)
            self._initialized["sync"] = True

    def initialize_async(self, force: bool = False) -> None:
        """Initialize the asynchronous database engine if not already initialized.

        Args:
            force: Rebuild the async engine and session factory even if they already exist.
        """
        with self._lock:
            if self._initialized.get("async") and not force:
                return

            if settings.letta_pg_uri_no_default:
                self.logger.info("Creating async postgres engine")
                async_pg_uri = self._to_async_pg_uri(settings.letta_pg_uri)
                async_engine = create_async_engine(async_pg_uri, **self._build_sqlalchemy_engine_args(is_async=True))
            else:
                # create sqlite async engine
                # TODO: remove self.config
                engine_path = "sqlite+aiosqlite:///" + os.path.join(self.config.recall_storage_path, "sqlite.db")
                self.logger.info("Creating sqlite engine " + engine_path)
                # is_postgres=False keeps asyncpg-only connect_args away from
                # aiosqlite, which would forward them to sqlite3.connect and fail.
                async_engine = create_async_engine(
                    engine_path,
                    **self._build_sqlalchemy_engine_args(is_async=True, is_postgres=False),
                )

            self._async_engines["default"] = async_engine
            self._async_session_factories["default"] = async_sessionmaker(
                expire_on_commit=True,
                close_resets_only=False,
                autocommit=False,
                autoflush=False,
                bind=async_engine,
                class_=AsyncSession,
            )
            self._initialized["async"] = True

    @staticmethod
    def _to_async_pg_uri(pg_uri: str) -> str:
        """Rewrite a Postgres URI so SQLAlchemy uses the ``postgresql+asyncpg`` driver.

        Whatever scheme precedes the first ``://`` is replaced with
        ``postgresql+asyncpg://``. Examples are ``postgresql://``,
        ``postgres://`` and ``postgresql+psycopg2://``. A string without
        ``://`` is returned as is apart from the ``sslmode=`` rewrite.
        ``sslmode=`` occurrences are rewritten to ``ssl=`` for the asyncpg driver.
        """
        _scheme, separator, remainder = pg_uri.partition("://")
        async_pg_uri = f"postgresql+asyncpg://{remainder}" if separator else pg_uri
        return async_pg_uri.replace("sslmode=", "ssl=")

    def _build_sqlalchemy_engine_args(self, *, is_async: bool, is_postgres: bool = True) -> dict:
        """Prepare keyword arguments for create_engine / create_async_engine.

        Args:
            is_async: Build arguments for ``create_async_engine`` instead of ``create_engine``.
            is_postgres: Whether the target database is Postgres. The
                asyncpg-specific ``connect_args`` are only added when this is
                True. Other drivers such as aiosqlite forward unknown keywords
                to ``sqlite3.connect``, which rejects them.

        Returns:
            A dict of engine keyword arguments (echo, pre-ping, pool settings,
            and, for async Postgres without pooling, asyncpg ``connect_args``).
        """
        use_null_pool = settings.disable_sqlalchemy_pooling

        if use_null_pool:
            self.logger.info("Disabling pooling on SqlAlchemy")
            pool_cls = NullPool
        else:
            self.logger.info("Enabling pooling on SqlAlchemy")
            # For async engines no pool class is set explicitly; SQLAlchemy's default is used.
            pool_cls = QueuePool if not is_async else None

        base_args = {
            "echo": settings.pg_echo,
            "pool_pre_ping": settings.pool_pre_ping,
        }
        if pool_cls:
            base_args["poolclass"] = pool_cls

        if not use_null_pool:
            base_args.update(
                {
                    "pool_size": settings.pg_pool_size,
                    "max_overflow": settings.pg_max_overflow,
                    "pool_timeout": settings.pg_pool_timeout,
                    "pool_recycle": settings.pg_pool_recycle,
                }
            )
            if not is_async:
                base_args["pool_use_lifo"] = settings.pool_use_lifo
        elif is_async and is_postgres:
            # For asyncpg, statement_cache_size should be in connect_args.
            # Random prepared-statement names and disabled caches avoid reusing
            # server-side prepared statements across connections.
            base_args["connect_args"] = {
                "timeout": settings.pg_pool_timeout,
                "prepared_statement_name_func": lambda: f"__asyncpg_{uuid.uuid4()}__",
                "statement_cache_size": 0,
                "prepared_statement_cache_size": 0,
            }
        return base_args

    def _wrap_sqlite_engine(self, engine: Engine) -> None:
        """Wrap SQLite engine with error handling.

        Replaces ``engine.connect`` so that opening a connection, and every
        ``execute`` call on the returned connection, runs inside
        ``db_error_handler``.
        """
        original_connect = engine.connect

        def wrapped_connect(*args, **kwargs):
            with db_error_handler():
                connection = original_connect(*args, **kwargs)
                original_execute = connection.execute

                def wrapped_execute(*args, **kwargs):
                    with db_error_handler():
                        return original_execute(*args, **kwargs)

                connection.execute = wrapped_execute
                return connection

        engine.connect = wrapped_connect

    def get_engine(self, name: str = "default") -> Engine:
        """Get a database engine by name (initializing the sync engine if needed).

        Returns None when no engine is registered under ``name``.
        """
        self.initialize_sync()
        return self._engines.get(name)

    def get_session_factory(self, name: str = "default") -> sessionmaker:
        """Get a session factory by name (initializing the sync engine if needed).

        Returns None when no factory is registered under ``name``.
        """
        self.initialize_sync()
        return self._session_factories.get(name)

    def get_async_session_factory(self, name: str = "default") -> async_sessionmaker:
        """Get an async session factory by name (initializing the async engine if needed).

        Returns None when no factory is registered under ``name``.
        """
        self.initialize_async()
        return self._async_session_factories.get(name)

    @trace_method
    @contextmanager
    def session(self, name: str = "default") -> Generator[Any, None, None]:
        """Context manager for database sessions.

        Yields a new sync session and always closes it on exit.

        Raises:
            ValueError: if no session factory is registered under ``name``.
        """
        session_factory = self.get_session_factory(name)
        if not session_factory:
            raise ValueError(f"No session factory found for '{name}'")
        session = session_factory()
        try:
            yield session
        finally:
            session.close()

    @trace_method
    @asynccontextmanager
    async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]:
        """Async context manager for database sessions.

        Yields a new ``AsyncSession`` and always closes it on exit.

        Raises:
            ValueError: if no async session factory is registered under ``name``.
        """
        session_factory = self.get_async_session_factory(name)
        if not session_factory:
            raise ValueError(f"No async session factory found for '{name}' or async database is not configured")
        session = session_factory()
        try:
            yield session
        finally:
            await session.close()
# Create a singleton instance.
# This shared registry backs get_db / get_db_async below. It is constructed at
# import time, so importing this module runs DatabaseRegistry.__init__
# (including LettaConfig.load()).
db_registry = DatabaseRegistry()
def get_db():
    """Yield one synchronous session from ``db_registry``.

    The session is closed by ``db_registry.session()`` once this generator resumes or is closed.
    """
    with db_registry.session() as db_session:
        yield db_session
async def get_db_async():
    """Asynchronously yield one ``AsyncSession`` from ``db_registry``.

    The session is closed by ``db_registry.async_session()`` once this generator resumes or is closed.
    """
    async with db_registry.async_session() as db_session:
        yield db_session
# Prefer calling db_registry.session() or db_registry.async_session() directly.
# This is for backwards compatibility: db_context() is a synchronous context
# manager built from get_db that yields a session from db_registry.session().
db_context = contextmanager(get_db)