diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 70c0dbf1..1276684c 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -5,6 +5,7 @@ from functools import wraps from pprint import pformat from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union +from asyncpg.exceptions import QueryCanceledError from sqlalchemy import Sequence, String, and_, delete, func, or_, select from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError from sqlalchemy.ext.asyncio import AsyncSession @@ -26,7 +27,11 @@ logger = get_logger(__name__) def handle_db_timeout(func): - """Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception.""" + """Decorator to handle database timeout errors and wrap them in a custom exception. + + Catches both SQLAlchemy TimeoutError (pool/connection timeout) and asyncpg's + QueryCanceledError (PostgreSQL statement_timeout triggered). + """ if not inspect.iscoroutinefunction(func): @wraps(func) @@ -36,6 +41,9 @@ def handle_db_timeout(func): except TimeoutError as e: logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}") raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e) + except QueryCanceledError as e: + logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}") + raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e) return wrapper else: @@ -47,6 +55,9 @@ def handle_db_timeout(func): except TimeoutError as e: logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}") raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e) + except QueryCanceledError as e: + logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}") + raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e) return async_wrapper @@ -771,6 +782,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): error_message = str(orig) if orig else str(e) logger.info(f"Handling DBAPIError: {error_message}") + # Handle asyncpg QueryCanceledError (wrapped in DBAPIError) + # This occurs when PostgreSQL's statement_timeout kills a long-running query + if isinstance(orig, QueryCanceledError): + logger.error(f"Query canceled (statement timeout) for {cls.__name__}: {e}") + raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout for {cls.__name__}.", original_exception=e) from e + # Handle SQLite-specific errors if "UNIQUE constraint failed" in error_message: raise UniqueConstraintViolationError(