fix: Scale up database (#2263)

This commit is contained in:
Matthew Zhou
2024-12-17 15:02:28 -08:00
committed by GitHub
parent 27ea364a32
commit e09bde67ef
5 changed files with 81 additions and 29 deletions

View File

@@ -1,14 +1,16 @@
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import TYPE_CHECKING, List, Literal, Optional
from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError, IntegrityError
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.orm import Mapped, Session, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import (
DatabaseTimeoutError,
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
@@ -23,6 +25,20 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def handle_db_timeout(func):
"""Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception."""
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except TimeoutError as e:
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
return wrapper
class AccessType(str, Enum):
ORGANIZATION = "organization"
USER = "user"
@@ -36,22 +52,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
id: Mapped[str] = mapped_column(String, primary_key=True)
@classmethod
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
"""Get a record by ID.
Args:
db_session: SQLAlchemy session
id: Record ID to retrieve
Returns:
Optional[SqlalchemyBase]: The record if found, None otherwise
"""
try:
return db_session.query(cls).filter(cls.id == id).first()
except DBAPIError:
return None
@classmethod
@handle_db_timeout
def list(
cls,
*,
@@ -180,6 +181,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
return list(session.execute(query).scalars())
@classmethod
@handle_db_timeout
def read(
cls,
db_session: "Session",
@@ -231,6 +233,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
@handle_db_timeout
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -245,6 +248,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
except (DBAPIError, IntegrityError) as e:
self._handle_dbapi_error(e)
@handle_db_timeout
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -254,6 +258,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
self.is_deleted = True
return self.update(db_session)
@handle_db_timeout
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
"""Permanently removes the record from the database."""
logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -269,6 +274,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
else:
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
@handle_db_timeout
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor:
@@ -281,6 +287,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
return self
@classmethod
@handle_db_timeout
def size(
cls,
*,