"""Async SQLAlchemy engine, session factory, and session helpers for Postgres.

The engine and session factory are created once at import time from
``settings``. Callers obtain sessions through ``db_registry.async_session()``
or the ``get_db_async`` FastAPI dependency.
"""

import asyncio
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from urllib.parse import parse_qs, urlsplit

from sqlalchemy import NullPool
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from letta.database_utils import get_database_uri_for_context
from letta.log import get_logger
from letta.settings import settings

logger = get_logger(__name__)

# Retry policy for establishing a session's connection (see _open_connected_session).
_CONNECT_MAX_RETRIES = 3
_CONNECT_INITIAL_RETRY_DELAY = 0.1  # seconds; doubled after each failed attempt


def _uri_specifies_ssl(uri: str) -> bool:
    """Return True if the URI's query string already configures SSL.

    Only query-parameter *names* are inspected (``ssl``, ``sslmode``,
    ``sslrootcert``, ...). A host, database name, or password that merely
    contains the letters "ssl" therefore does not suppress the default
    ``ssl="require"``.
    """
    query = urlsplit(uri).query
    # keep_blank_values so a bare "?ssl" flag is still recognised.
    return any(name.startswith("ssl") for name in parse_qs(query, keep_blank_values=True))


# Convert PostgreSQL URI to async format using common utility
async_pg_uri = get_database_uri_for_context(settings.letta_pg_uri, "async")

# Build engine configuration based on settings
engine_args = {
    "echo": settings.pg_echo,
    "pool_pre_ping": settings.pool_pre_ping,
}

# Configure pooling
if settings.disable_sqlalchemy_pooling:
    engine_args["poolclass"] = NullPool
else:
    # Use default AsyncAdaptedQueuePool with configured settings
    engine_args.update(
        {
            "pool_size": settings.pg_pool_size,
            "max_overflow": settings.pg_max_overflow,
            "pool_timeout": settings.pg_pool_timeout,
            "pool_recycle": settings.pg_pool_recycle,
        }
    )

    # asyncpg-specific connection settings.
    # NOTE(review): these, including the SSL default below, are applied only
    # when pooling is enabled. Confirm that NullPool deployments are meant to
    # skip them.
    connect_args = {
        # asyncpg connect timeout; reuses the pool checkout timeout setting.
        "timeout": settings.pg_pool_timeout,
        # Unique prepared-statement names plus disabled statement caches avoid
        # name collisions when server connections are shared (presumably for a
        # transaction-mode proxy such as pgbouncer; TODO confirm).
        "prepared_statement_name_func": lambda: f"__asyncpg_{uuid.uuid4()}__",
        "statement_cache_size": 0,
        "prepared_statement_cache_size": 0,
    }
    # Only add SSL if the connection string does not already configure it
    if not _uri_specifies_ssl(async_pg_uri):
        connect_args["ssl"] = "require"
    engine_args["connect_args"] = connect_args

# Create the engine once at module level
engine: AsyncEngine = create_async_engine(async_pg_uri, **engine_args)

# Create session factory once at module level
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)


async def _open_connected_session() -> AsyncSession:
    """Create a session and eagerly check out its connection, retrying transient failures.

    Connection establishment is the only phase that can be retried safely. It
    runs before the session is handed to the caller, so none of the caller's
    code is re-executed.

    Each attempt uses a fresh session. Failed sessions are closed before the
    next attempt. The delay starts at ``_CONNECT_INITIAL_RETRY_DELAY`` and
    doubles after each failure.

    Returns:
        An ``AsyncSession`` that already holds a connection.

    Raises:
        LettaServiceUnavailableError: if all ``_CONNECT_MAX_RETRIES`` attempts
            fail with ``ConnectionError``.
        Any other exception raised while connecting is propagated unchanged
        after the session is closed.
    """
    retry_delay = _CONNECT_INITIAL_RETRY_DELAY
    attempt = 1
    while True:
        session = async_session_factory()
        try:
            await session.connection()
            return session
        except ConnectionError as e:
            await session.close()
            if attempt >= _CONNECT_MAX_RETRIES:
                logger.error(f"Database connection failed after {_CONNECT_MAX_RETRIES} attempts: {e}")
                # Local import kept from the original (likely avoids an import cycle; TODO confirm).
                from letta.errors import LettaServiceUnavailableError

                raise LettaServiceUnavailableError(
                    "Database connection temporarily unavailable. Please retry your request.",
                    service_name="database",
                ) from e
            logger.warning(
                f"Database connection error (attempt {attempt}/{_CONNECT_MAX_RETRIES}): {e}. Retrying in {retry_delay}s..."
            )
            await asyncio.sleep(retry_delay)
            retry_delay *= 2
            attempt += 1
        except BaseException:
            # Any other failure (including cancellation): don't leak the session.
            await session.close()
            raise


class DatabaseRegistry:
    """Dummy registry to maintain the existing interface."""

    @asynccontextmanager
    async def async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Get an async database session.

        The session's connection is established, with retries, before the
        session is yielded. On normal exit the session is committed. On any
        exception, including ``asyncio.CancelledError``, it is rolled back and
        the exception is re-raised. The session is always expunged and closed.

        Note: We explicitly handle asyncio.CancelledError separately because it's a
        BaseException (not Exception) in Python 3.8+. Without this, cancelled tasks
        would skip rollback() and return connections to the pool with uncommitted
        transactions, causing "idle in transaction" connection leaks.

        Retries cover only connection establishment, which happens before the
        ``yield``. An async context manager cannot re-run the caller's
        ``async with`` body. Errors raised after the yield (including
        ``ConnectionError``) are therefore never retried. Yielding a second time
        would make ``contextlib`` raise
        ``RuntimeError("generator didn't stop after athrow()")``.

        Raises:
            LettaServiceUnavailableError: if the connection cannot be established
                after the configured number of attempts.
        """
        session = await _open_connected_session()
        async with session:
            try:
                yield session
                await session.commit()
            except asyncio.CancelledError:
                # Task was cancelled (client disconnect, timeout, explicit cancellation)
                # Must rollback to avoid returning connection with open transaction
                await session.rollback()
                raise
            except Exception:
                await session.rollback()
                raise
            finally:
                session.expunge_all()
                await session.close()


# Create singleton instance to match existing interface
db_registry = DatabaseRegistry()


# Backwards compatibility function
def get_db_registry() -> DatabaseRegistry:
    """Get the global database registry instance."""
    return db_registry


# FastAPI dependency helper
async def get_db_async() -> AsyncGenerator[AsyncSession, None]:
    """Get an async database session."""
    async with db_registry.async_session() as session:
        yield session


# Optional: cleanup function for graceful shutdown
async def close_db() -> None:
    """Close the database engine."""
    await engine.dispose()


# Usage remains the same:
# async with db_registry.async_session() as session:
#     result = await session.execute(select(User))
#     users = result.scalars().all()