diff --git a/letta/orm/errors.py b/letta/orm/errors.py
index 28e5807f..a574e74c 100644
--- a/letta/orm/errors.py
+++ b/letta/orm/errors.py
@@ -12,3 +12,16 @@ class UniqueConstraintViolationError(ValueError):
 
 class ForeignKeyConstraintViolationError(ValueError):
     """Custom exception for foreign key constraint violations."""
+
+
+class DatabaseTimeoutError(Exception):
+    """Custom exception for database timeout issues.
+
+    Attributes:
+        original_exception: The underlying exception that triggered the
+            timeout, if one was supplied; otherwise None.
+    """
+
+    def __init__(self, message="Database operation timed out", original_exception=None):
+        super().__init__(message)
+        self.original_exception = original_exception
diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py
index 48b8c44a..6879c74b 100644
--- a/letta/orm/sqlalchemy_base.py
+++ b/letta/orm/sqlalchemy_base.py
@@ -1,14 +1,17 @@
 from datetime import datetime
 from enum import Enum
+from functools import wraps
 from typing import TYPE_CHECKING, List, Literal, Optional
 
 from sqlalchemy import String, desc, func, or_, select
 from sqlalchemy.exc import DBAPIError, IntegrityError
+from sqlalchemy.exc import TimeoutError as SQLAlchemyTimeoutError
 from sqlalchemy.orm import Mapped, Session, mapped_column
 
 from letta.log import get_logger
 from letta.orm.base import Base, CommonSqlalchemyMetaMixins
 from letta.orm.errors import (
+    DatabaseTimeoutError,
     ForeignKeyConstraintViolationError,
     NoResultFound,
     UniqueConstraintViolationError,
@@ -23,6 +26,35 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def handle_db_timeout(func):
+    """Decorator that converts SQLAlchemy's TimeoutError into DatabaseTimeoutError.
+
+    Args:
+        func: The callable to wrap.
+
+    Returns:
+        A wrapper that calls ``func`` unchanged and re-raises any
+        ``sqlalchemy.exc.TimeoutError`` as ``DatabaseTimeoutError``.
+
+    Raises:
+        DatabaseTimeoutError: When the wrapped call raises
+            ``sqlalchemy.exc.TimeoutError``; the original error is kept both
+            as ``original_exception`` and as ``__cause__``.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except SQLAlchemyTimeoutError as e:
+            # Deliberately omit args/kwargs from the log line: they hold the session,
+            # the actor and ORM instances, whose reprs may be large or sensitive.
+            logger.error(f"Timeout while executing {func.__name__}: {e}")
+            raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e) from e
+
+    return wrapper
+
+
 class AccessType(str, Enum):
     ORGANIZATION = "organization"
     USER = "user"
@@ -36,22 +68,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
     id: Mapped[str] = mapped_column(String, primary_key=True)
 
     @classmethod
-    def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
-        """Get a record by ID.
-
-        Args:
-            db_session: SQLAlchemy session
-            id: Record ID to retrieve
-
-        Returns:
-            Optional[SqlalchemyBase]: The record if found, None otherwise
-        """
-        try:
-            return db_session.query(cls).filter(cls.id == id).first()
-        except DBAPIError:
-            return None
-
-    @classmethod
+    @handle_db_timeout
     def list(
         cls,
         *,
@@ -180,6 +197,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
             return list(session.execute(query).scalars())
 
     @classmethod
+    @handle_db_timeout
     def read(
         cls,
         db_session: "Session",
@@ -231,6 +249,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
             conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
             raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
 
+    @handle_db_timeout
     def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
         logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
 
@@ -245,6 +264,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         except (DBAPIError, IntegrityError) as e:
             self._handle_dbapi_error(e)
 
+    @handle_db_timeout
     def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
         logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
 
@@ -254,6 +274,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         self.is_deleted = True
         return self.update(db_session)
 
+    @handle_db_timeout
     def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
         """Permanently removes the record from the database."""
         logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -269,6 +290,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
             else:
                 logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
 
+    @handle_db_timeout
     def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
         logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
         if actor:
@@ -281,6 +303,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
             return self
 
     @classmethod
+    @handle_db_timeout
     def size(
         cls,
         *,
diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py
index b5117408..8cb9b27e 100644
--- a/letta/server/rest_api/app.py
+++ b/letta/server/rest_api/app.py
@@ -15,7 +15,12 @@ from letta.__init__ import __version__
 from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
 from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
 from letta.log import get_logger
-from letta.orm.errors import NoResultFound
+from letta.orm.errors import (
+    DatabaseTimeoutError,
+    ForeignKeyConstraintViolationError,
+    NoResultFound,
+    UniqueConstraintViolationError,
+)
 from letta.schemas.letta_response import LettaResponse
 from letta.server.constants import REST_DEFAULT_PORT
 
@@ -175,7 +180,6 @@ def create_application() -> "FastAPI":
 
     @app.exception_handler(NoResultFound)
     async def no_result_found_handler(request: Request, exc: NoResultFound):
-        logger.error(f"NoResultFound request: {request}")
         logger.error(f"NoResultFound: {exc}")
 
         return JSONResponse(
@@ -183,6 +187,37 @@ def create_application() -> "FastAPI":
             content={"detail": str(exc)},
         )
 
+    @app.exception_handler(ForeignKeyConstraintViolationError)
+    async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
+        """Report a foreign key constraint violation as HTTP 409."""
+        logger.error(f"ForeignKeyConstraintViolationError: {exc}")
+
+        return JSONResponse(
+            status_code=409,
+            content={"detail": str(exc)},
+        )
+
+    @app.exception_handler(UniqueConstraintViolationError)
+    async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
+        """Report a unique constraint violation as HTTP 409."""
+        logger.error(f"UniqueConstraintViolationError: {exc}")
+
+        return JSONResponse(
+            status_code=409,
+            content={"detail": str(exc)},
+        )
+
+    @app.exception_handler(DatabaseTimeoutError)
+    async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
+        """Report a database timeout as HTTP 503 without exposing internal details."""
+        # The underlying error is logged server-side only; clients get a generic message.
+        logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
+
+        return JSONResponse(
+            status_code=503,
+            content={"detail": "The database is temporarily unavailable. Please try again later."},
+        )
+
     @app.exception_handler(ValueError)
     async def value_error_handler(request: Request, exc: ValueError):
         return JSONResponse(status_code=400, content={"detail": str(exc)})
@@ -235,11 +270,6 @@ def create_application() -> "FastAPI":
 
     @app.on_event("startup")
     def on_startup():
-        # load the default tools
-        # from letta.orm.tool import Tool
-
-        # Tool.load_default_tools(get_db_session())
-
         generate_openapi_schema(app)
 
     @app.on_event("shutdown")
diff --git a/letta/server/server.py b/letta/server/server.py
index 1d9321e2..71b0ac78 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -190,7 +190,14 @@ if settings.letta_pg_uri_no_default:
     config.archival_storage_uri = settings.letta_pg_uri_no_default
 
     # create engine
-    engine = create_engine(settings.letta_pg_uri)
+    engine = create_engine(
+        settings.letta_pg_uri,
+        pool_size=settings.pg_pool_size,
+        max_overflow=settings.pg_max_overflow,
+        pool_timeout=settings.pg_pool_timeout,
+        pool_recycle=settings.pg_pool_recycle,
+        echo=settings.pg_echo,
+    )
 else:
     # TODO: don't rely on config storage
     engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))
diff --git a/letta/settings.py b/letta/settings.py
index 20a0c1c5..d6907b11 100644
--- a/letta/settings.py
+++ b/letta/settings.py
@@ -17,7 +17,7 @@ class ToolSettings(BaseSettings):
 
 
 class ModelSettings(BaseSettings):
-    model_config = SettingsConfigDict(env_file='.env', extra='ignore')
+    model_config = SettingsConfigDict(env_file=".env", extra="ignore")
 
     # env_prefix='my_prefix_'
 
@@ -64,7 +64,7 @@ cors_origins = ["http://letta.localhost", "http://localhost:8283", "http://local
 
 
 class Settings(BaseSettings):
-    model_config = SettingsConfigDict(env_prefix="letta_", extra='ignore')
+    model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore")
 
     letta_dir: Optional[Path] = Field(Path.home() / ".letta", env="LETTA_DIR")
     debug: Optional[bool] = False
@@ -76,7 +76,12 @@ class Settings(BaseSettings):
     pg_password: Optional[str] = None
     pg_host: Optional[str] = None
     pg_port: Optional[int] = None
-    pg_uri: Optional[str] = None  # option to specifiy full uri
+    pg_uri: Optional[str] = None  # option to specify full uri
+    pg_pool_size: int = 20  # Connections kept open in the pool
+    pg_max_overflow: int = 10  # Extra connections allowed beyond pg_pool_size
+    pg_pool_timeout: int = 30  # Seconds to wait for a pooled connection before timing out
+    pg_pool_recycle: int = 1800  # Seconds after which a connection is recycled
+    pg_echo: bool = False  # Whether the engine logs emitted SQL statements
 
     # tools configuration
     load_default_external_tools: Optional[bool] = None
@@ -103,7 +108,7 @@ class Settings(BaseSettings):
 
 
 class TestSettings(Settings):
-    model_config = SettingsConfigDict(env_prefix="letta_test_", extra='ignore')
+    model_config = SettingsConfigDict(env_prefix="letta_test_", extra="ignore")
 
     letta_dir: Optional[Path] = Field(Path.home() / ".letta/test", env="LETTA_TEST_DIR")
 