Revert "fix: force statement_timeout=0 on all database connections (#9184)" This reverts commit 0d3d9ea76bae586520ae8f50badb203ffd441675.
123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
import asyncio
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from typing import AsyncGenerator
|
|
|
|
from sqlalchemy import NullPool
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
|
|
from letta.database_utils import get_database_uri_for_context
|
|
from letta.settings import settings
|
|
|
|
# Derive the async-format connection URI from the configured PostgreSQL URI,
# delegating the conversion to the shared helper in letta.database_utils.
async_pg_uri = get_database_uri_for_context(
    settings.letta_pg_uri,
    "async",
)
|
|
|
|
def _uri_specifies_ssl(uri: str) -> bool:
    """Return True if the connection URI explicitly configures SSL.

    Only query-string parameter *names* are inspected. Any parameter whose name
    starts with ``ssl`` (``ssl``, ``sslmode``, ``sslrootcert``, ...) counts as
    an explicit SSL setting.

    A substring test over the whole URI is deliberately avoided. With that test,
    a host name, database name, user or password that merely contains the letters
    "ssl" (e.g. ``db-ssl.internal``) would suppress the default ``ssl="require"``.

    If the URI cannot be split, this falls back to that plain substring test.
    """
    from urllib.parse import parse_qsl, urlsplit

    try:
        query = urlsplit(uri).query
    except ValueError:
        # Unparseable URI: keep the conservative legacy heuristic.
        return "ssl" in uri
    return any(name.startswith("ssl") for name, _ in parse_qsl(query, keep_blank_values=True))


# Build engine configuration based on settings
engine_args = {
    "echo": settings.pg_echo,
    "pool_pre_ping": settings.pool_pre_ping,
}

if settings.disable_sqlalchemy_pooling:
    # Pooling disabled: use NullPool and no extra asyncpg connect args.
    engine_args["poolclass"] = NullPool
else:
    # Use default AsyncAdaptedQueuePool with configured settings
    engine_args.update(
        {
            "pool_size": settings.pg_pool_size,
            "max_overflow": settings.pg_max_overflow,
            "pool_timeout": settings.pg_pool_timeout,
            "pool_recycle": settings.pg_pool_recycle,
        }
    )

    # asyncpg-specific connection settings (only applied when pooling is enabled).
    connect_args = {
        "timeout": settings.pg_pool_timeout,
        # Unique prepared-statement names plus disabled statement caches.
        # NOTE(review): presumably for compatibility with transaction-pooling
        # proxies such as PgBouncer -- confirm.
        "prepared_statement_name_func": lambda: f"__asyncpg_{uuid.uuid4()}__",
        "statement_cache_size": 0,
        "prepared_statement_cache_size": 0,
    }

    # Only add SSL if the connection string does not already configure it.
    # Otherwise connect_args would override the URI's own choice.
    if not _uri_specifies_ssl(async_pg_uri):
        connect_args["ssl"] = "require"

    engine_args["connect_args"] = connect_args
|
|
|
|
# Single shared async engine, created once when this module is imported.
engine: AsyncEngine = create_async_engine(async_pg_uri, **engine_args)

# Single session factory bound to the shared engine; every session comes from here.
# Objects are not expired on commit, and the session does not autoflush.
async_session_factory = async_sessionmaker(
    engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
    class_=AsyncSession,
)
|
|
|
|
|
|
class DatabaseRegistry:
    """Stateless registry object kept so callers can use the existing ``db_registry`` API.

    Every session it hands out comes from the module-level ``async_session_factory``.
    """

    @asynccontextmanager
    async def async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an ``AsyncSession`` that commits on success and rolls back on failure.

        The session is committed once the caller's ``async with`` body completes
        without raising. If that body or the commit raises either an ``Exception``
        or an ``asyncio.CancelledError``, the session is rolled back and the error
        is re-raised.

        ``CancelledError`` is listed explicitly because it has derived from
        ``BaseException`` rather than ``Exception`` since Python 3.8. A cancelled
        task (client disconnect, timeout, explicit cancellation) must still roll
        back. Otherwise a connection carrying an uncommitted transaction could go
        back to the pool and show up as an "idle in transaction" leak.

        Whatever the outcome, all instances are expunged and the session is closed.
        """
        async with async_session_factory() as session:
            try:
                yield session
                await session.commit()
            except (asyncio.CancelledError, Exception):
                # Same handling for cancellation and ordinary errors: undo, then propagate.
                await session.rollback()
                raise
            finally:
                session.expunge_all()
                await session.close()
|
|
|
|
|
|
# Module-level singleton shared by every importer; preserves the existing interface.
db_registry: DatabaseRegistry = DatabaseRegistry()
|
|
|
|
|
|
# Kept for backwards compatibility with callers that fetch the registry via a function.
def get_db_registry() -> DatabaseRegistry:
    """Return the module-level ``DatabaseRegistry`` singleton (``db_registry``)."""
    registry = db_registry
    return registry
|
|
|
|
|
|
# Helper intended for use as a FastAPI dependency.
async def get_db_async() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` obtained from ``db_registry.async_session()``.

    The session's commit/rollback/close handling is whatever that context
    manager provides; this generator only re-yields the session.
    """
    session_cm = db_registry.async_session()
    async with session_cm as db_session:
        yield db_session
|
|
|
|
|
|
# Optional cleanup hook, intended to be awaited during graceful shutdown.
async def close_db() -> None:
    """Dispose of the module-level async engine."""
    shared_engine = engine
    await shared_engine.dispose()
|
|
|
|
|
|
# Example usage:
# async with db_registry.async_session() as session:
#     result = await session.execute(select(User))
#     users = result.scalars().all()
|