Files
letta-server/letta/server/db.py

106 lines
3.0 KiB
Python

import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import NullPool
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from letta.database_utils import get_database_uri_for_context
from letta.settings import settings
# Convert PostgreSQL URI to async format using common utility
async_pg_uri = get_database_uri_for_context(settings.letta_pg_uri, "async")
# Build engine configuration based on settings
engine_args = {
"echo": settings.pg_echo,
"pool_pre_ping": settings.pool_pre_ping,
}
# Configure pooling
if settings.disable_sqlalchemy_pooling:
engine_args["poolclass"] = NullPool
else:
# Use default AsyncAdaptedQueuePool with configured settings
engine_args.update(
{
"pool_size": settings.pg_pool_size,
"max_overflow": settings.pg_max_overflow,
"pool_timeout": settings.pg_pool_timeout,
"pool_recycle": settings.pg_pool_recycle,
}
)
# Add asyncpg-specific settings for connection
if not settings.disable_sqlalchemy_pooling:
engine_args["connect_args"] = {
"timeout": settings.pg_pool_timeout,
"prepared_statement_name_func": lambda: f"__asyncpg_{uuid.uuid4()}__",
"statement_cache_size": 0,
"prepared_statement_cache_size": 0,
}
# Create the engine once at module level
engine: AsyncEngine = create_async_engine(async_pg_uri, **engine_args)
# Create session factory once at module level
async_session_factory = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
class DatabaseRegistry:
"""Dummy registry to maintain the existing interface."""
@asynccontextmanager
async def async_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Get an async database session."""
async with async_session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
session.expunge_all()
await session.close()
# Create singleton instance to match existing interface
db_registry = DatabaseRegistry()
# Backwards compatibility function
def get_db_registry() -> DatabaseRegistry:
"""Get the global database registry instance."""
return db_registry
# FastAPI dependency helper
async def get_db_async() -> AsyncGenerator[AsyncSession, None]:
"""Get an async database session."""
async with db_registry.async_session() as session:
yield session
# Optional: cleanup function for graceful shutdown
async def close_db() -> None:
"""Close the database engine."""
await engine.dispose()
# Usage remains the same:
# async with db_registry.async_session() as session:
# result = await session.execute(select(User))
# users = result.scalars().all()