"""Async SQLAlchemy engine and session management for Letta's PostgreSQL backend.

Builds a module-level async engine and session factory, and exposes a
``DatabaseRegistry`` whose ``async_session`` context manager retries transient
connection errors (e.g. asyncpg SSL handshake failures) with exponential
backoff (3 attempts), raising a service-unavailable error on exhaustion.
"""
import asyncio
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from typing import AsyncGenerator
|
|
|
|
from sqlalchemy import NullPool
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
|
|
from letta.database_utils import get_database_uri_for_context
|
|
from letta.log import get_logger
|
|
from letta.settings import settings
|
|
|
|
logger = get_logger(__name__)

# Convert the configured PostgreSQL URI to its async (asyncpg) form.
async_pg_uri = get_database_uri_for_context(settings.letta_pg_uri, "async")

# Base engine configuration shared by both pooling modes.
engine_args = {
    "echo": settings.pg_echo,
    "pool_pre_ping": settings.pool_pre_ping,
}

# Configure pooling.  (Single if/else — the original checked
# `settings.disable_sqlalchemy_pooling` twice in separate statements.)
if settings.disable_sqlalchemy_pooling:
    # No pooling: every checkout opens a fresh connection.
    engine_args["poolclass"] = NullPool
else:
    # Default AsyncAdaptedQueuePool, tuned from settings.
    engine_args.update(
        {
            "pool_size": settings.pg_pool_size,
            "max_overflow": settings.pg_max_overflow,
            "pool_timeout": settings.pg_pool_timeout,
            "pool_recycle": settings.pg_pool_recycle,
        }
    )

    # asyncpg-specific connection arguments.  Statement/prepared-statement
    # caching is disabled and statement names are randomized — presumably to
    # tolerate connection-sharing proxies; confirm against deployment docs.
    # NOTE(review): these connect_args (including the SSL default below) are
    # applied only in pooled mode and skipped entirely under NullPool —
    # confirm that asymmetry is intentional.
    connect_args = {
        "timeout": settings.pg_pool_timeout,
        "prepared_statement_name_func": lambda: f"__asyncpg_{uuid.uuid4()}__",
        "statement_cache_size": 0,
        "prepared_statement_cache_size": 0,
    }
    # Only force SSL if the connection string does not already specify it.
    if "sslmode" not in async_pg_uri and "ssl" not in async_pg_uri:
        connect_args["ssl"] = "require"

    engine_args["connect_args"] = connect_args

# Engine and session factory are created once at import time and shared.
engine: AsyncEngine = create_async_engine(async_pg_uri, **engine_args)

async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
|
|
|
|
|
|
class DatabaseRegistry:
    """Thin registry preserving the pre-existing session-provider interface."""

    @asynccontextmanager
    async def async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an async database session, retrying transient connection errors.

        The connection is established eagerly (``await session.connection()``)
        inside the retry loop so transient failures such as asyncpg SSL
        handshake errors are caught *before* the session is handed to the
        caller.  The previous implementation wrapped the ``yield`` itself in
        the retry loop, which cannot work: an ``@asynccontextmanager``
        generator may only yield once, so a ``ConnectionError`` raised after
        the first yield (where lazy connection establishment actually happens)
        would trigger a second yield and contextlib would raise
        ``RuntimeError`` instead of retrying.

        ``asyncio.CancelledError`` is handled explicitly because it is a
        BaseException (not Exception) in Python 3.8+.  Without the explicit
        rollback, cancelled tasks would skip cleanup and return connections to
        the pool with uncommitted transactions, causing "idle in transaction"
        connection leaks.

        Raises:
            LettaServiceUnavailableError: if the connection cannot be
                established after all retry attempts.
        """
        max_retries = 3
        retry_delay = 0.1

        session = None
        for attempt in range(max_retries):
            session = async_session_factory()
            try:
                # Force the underlying connection (including the SSL
                # handshake) now so transient failures are retryable.  This
                # also begins the transaction that commit()/rollback() below
                # will end.
                await session.connection()
                break
            except ConnectionError as e:
                await session.close()
                session = None
                if attempt < max_retries - 1:
                    logger.warning(f"Database connection error (attempt {attempt + 1}/{max_retries}): {e}. Retrying in {retry_delay}s...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
                else:
                    logger.error(f"Database connection failed after {max_retries} attempts: {e}")
                    from letta.errors import LettaServiceUnavailableError

                    raise LettaServiceUnavailableError(
                        "Database connection temporarily unavailable. Please retry your request.", service_name="database"
                    ) from e

        try:
            yield session
            await session.commit()
        except asyncio.CancelledError:
            # Task was cancelled (client disconnect, timeout, explicit
            # cancellation).  Must rollback to avoid returning a connection
            # with an open transaction to the pool.
            await session.rollback()
            raise
        except Exception:
            await session.rollback()
            raise
        finally:
            session.expunge_all()
            await session.close()
|
|
|
|
|
|
# Create singleton instance to match existing interface
# (importers reference `db_registry` directly or via get_db_registry()).
db_registry = DatabaseRegistry()
|
|
|
|
|
|
# Backwards compatibility function
|
|
def get_db_registry() -> DatabaseRegistry:
    """Return the module-level :class:`DatabaseRegistry` singleton.

    Kept for backwards compatibility with callers that resolve the registry
    through a function rather than importing ``db_registry`` directly.
    """
    return db_registry
|
|
|
|
|
|
# FastAPI dependency helper
|
|
async def get_db_async() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields an async database session.

    Delegates to ``db_registry.async_session()`` so commit/rollback/close
    semantics live in one place.
    """
    async with db_registry.async_session() as db_session:
        yield db_session
|
|
|
|
|
|
# Optional: cleanup function for graceful shutdown
|
|
async def close_db() -> None:
    """Dispose of the module-level engine (graceful-shutdown hook)."""
    await engine.dispose()
|
|
|
|
|
|
# Usage remains the same:
|
|
# async with db_registry.async_session() as session:
|
|
# result = await session.execute(select(User))
|
|
# users = result.scalars().all()
|