Files
letta-server/letta/server/db.py

146 lines
4.9 KiB
Python

import asyncio
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import NullPool
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from letta.database_utils import get_database_uri_for_context
from letta.log import get_logger
from letta.settings import settings
logger = get_logger(__name__)

# Async-driver form of the configured PostgreSQL URI, produced by the shared helper.
async_pg_uri = get_database_uri_for_context(settings.letta_pg_uri, "async")

# Keyword arguments handed to create_async_engine, assembled from settings.
engine_args = dict(
    echo=settings.pg_echo,
    pool_pre_ping=settings.pool_pre_ping,
)

if settings.disable_sqlalchemy_pooling:
    # Pooling disabled: use NullPool and skip the pool/asyncpg tuning below.
    engine_args["poolclass"] = NullPool
else:
    # Keep the default AsyncAdaptedQueuePool and size it from settings.
    engine_args["pool_size"] = settings.pg_pool_size
    engine_args["max_overflow"] = settings.pg_max_overflow
    engine_args["pool_timeout"] = settings.pg_pool_timeout
    engine_args["pool_recycle"] = settings.pg_pool_recycle

    # asyncpg-specific connection options: unique prepared-statement names and
    # both statement caches disabled.
    # NOTE(review): presumably for compatibility with transaction-level
    # connection poolers (e.g. PgBouncer) -- confirm.
    connect_args = {
        "timeout": settings.pg_pool_timeout,
        "prepared_statement_name_func": lambda: f"__asyncpg_{uuid.uuid4()}__",
        "statement_cache_size": 0,
        "prepared_statement_cache_size": 0,
    }

    # Require SSL unless the URI already mentions it. This is a plain substring
    # test against the whole URI string, not a parse of its query parameters.
    if not ("sslmode" in async_pg_uri or "ssl" in async_pg_uri):
        connect_args["ssl"] = "require"

    engine_args["connect_args"] = connect_args

# Single engine shared by the whole process, created at import time.
engine: AsyncEngine = create_async_engine(async_pg_uri, **engine_args)

# Single session factory bound to that engine, also created at import time.
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
class DatabaseRegistry:
    """Stateless registry kept so callers can keep using ``db_registry.async_session()``.

    It delegates to the module-level ``async_session_factory``.
    """

    @asynccontextmanager
    async def async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an async database session that commits on success and rolls back on failure.

        The session is committed when the ``async with`` body finishes normally and
        rolled back when it raises. ``asyncio.CancelledError`` is handled explicitly
        because it is a ``BaseException`` (not ``Exception``) in Python 3.8+; without
        that, cancelled tasks would skip ``rollback()`` and return connections to the
        pool with an open transaction ("idle in transaction" leaks).

        Retry behaviour: a ``ConnectionError`` raised *before* the session has been
        handed to the caller is retried up to three times with exponential backoff
        (0.1s, 0.2s). A ``ConnectionError`` raised *after* the session was handed out
        (inside the caller's body, or during commit/rollback) is never retried: an
        ``@asynccontextmanager`` generator may only yield once, and the caller's body
        cannot be re-executed anyway.

        NOTE(review): sessions created by the factory appear to acquire their
        connection lazily, so pre-yield connection errors may be rare -- confirm
        whether the retry path is reachable in practice.

        Yields:
            The ``AsyncSession`` to use inside the ``async with`` body.

        Raises:
            LettaServiceUnavailableError: on a ``ConnectionError`` that cannot be
                retried (retries exhausted, or the session was already in use).
        """
        max_retries = 3
        retry_delay = 0.1

        for attempt in range(max_retries):
            # Flipped right before control is handed to the caller. Once set, this
            # generator must not yield again, so retrying is off the table.
            handed_to_caller = False
            try:
                async with async_session_factory() as session:
                    try:
                        handed_to_caller = True
                        yield session
                        await session.commit()
                    except asyncio.CancelledError:
                        # Task was cancelled (client disconnect, timeout, explicit cancellation).
                        # Must rollback to avoid returning a connection with an open transaction.
                        await session.rollback()
                        raise
                    except Exception:
                        await session.rollback()
                        raise
                    finally:
                        session.expunge_all()
                        await session.close()
                return
            except ConnectionError as e:
                if not handed_to_caller and attempt < max_retries - 1:
                    logger.warning(f"Database connection error (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {retry_delay}s...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2
                    continue

                if handed_to_caller:
                    # Looping again here would yield a second time, which contextlib
                    # reports as "generator didn't stop after athrow()" and which
                    # would mask the real connection error.
                    logger.error(f"Database connection lost while session was in use (not retryable): {e}")
                else:
                    logger.error(f"Database connection failed after {max_retries} attempts: {e}")

                # Imported lazily, as in the original code.
                from letta.errors import LettaServiceUnavailableError

                raise LettaServiceUnavailableError(
                    "Database connection temporarily unavailable. Please retry your request.", service_name="database"
                ) from e
# Module-level singleton; existing callers use `db_registry.async_session()`,
# so the instance is kept to preserve that interface.
db_registry = DatabaseRegistry()
# Accessor retained for callers that obtain the registry through a function call.
def get_db_registry() -> DatabaseRegistry:
    """Return the module-level ``db_registry`` singleton."""
    return db_registry
# Async-generator wrapper around the registry's session context manager,
# intended for use as a FastAPI dependency.
async def get_db_async() -> AsyncGenerator[AsyncSession, None]:
    """Yield one session obtained from ``db_registry.async_session()``."""
    async with db_registry.async_session() as db_session:
        yield db_session
# Optional hook for graceful shutdown.
async def close_db() -> None:
    """Dispose of the module-level engine."""
    await engine.dispose()
# Usage remains the same:
# async with db_registry.async_session() as session:
#     result = await session.execute(select(User))
#     users = result.scalars().all()