fix: Scale up database (#2263)
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from sqlalchemy import String, desc, func, or_, select
|
||||
from sqlalchemy.exc import DBAPIError, IntegrityError
|
||||
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import (
|
||||
DatabaseTimeoutError,
|
||||
ForeignKeyConstraintViolationError,
|
||||
NoResultFound,
|
||||
UniqueConstraintViolationError,
|
||||
@@ -23,6 +25,20 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def handle_db_timeout(func):
|
||||
"""Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except TimeoutError as e:
|
||||
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
|
||||
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AccessType(str, Enum):
|
||||
ORGANIZATION = "organization"
|
||||
USER = "user"
|
||||
@@ -36,22 +52,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
@classmethod
|
||||
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
|
||||
"""Get a record by ID.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
id: Record ID to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[SqlalchemyBase]: The record if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
return db_session.query(cls).filter(cls.id == id).first()
|
||||
except DBAPIError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def list(
|
||||
cls,
|
||||
*,
|
||||
@@ -180,6 +181,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def read(
|
||||
cls,
|
||||
db_session: "Session",
|
||||
@@ -231,6 +233,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
||||
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
|
||||
|
||||
@handle_db_timeout
|
||||
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
|
||||
@@ -245,6 +248,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
except (DBAPIError, IntegrityError) as e:
|
||||
self._handle_dbapi_error(e)
|
||||
|
||||
@handle_db_timeout
|
||||
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
|
||||
@@ -254,6 +258,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
self.is_deleted = True
|
||||
return self.update(db_session)
|
||||
|
||||
@handle_db_timeout
|
||||
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
|
||||
"""Permanently removes the record from the database."""
|
||||
logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
@@ -269,6 +274,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
else:
|
||||
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
||||
|
||||
@handle_db_timeout
|
||||
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
if actor:
|
||||
@@ -281,6 +287,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def size(
|
||||
cls,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user