import asyncio
import inspect
from datetime import datetime
from enum import Enum
from functools import wraps
from pprint import pformat
from typing import TYPE_CHECKING, List, Literal, Optional, Sequence, Tuple, Union

from asyncpg.exceptions import DeadlockDetectedError, LockNotAvailableError as AsyncpgLockNotAvailableError, QueryCanceledError
from sqlalchemy import String, and_, delete, func, or_, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.orm.interfaces import ORMOption

from letta.errors import ConcurrentUpdateError
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import (
    DatabaseDeadlockError,
    DatabaseLockNotAvailableError,
    DatabaseTimeoutError,
    ForeignKeyConstraintViolationError,
    NoResultFound,
    UniqueConstraintViolationError,
)
from letta.settings import DatabaseChoice

if TYPE_CHECKING:
    from pydantic import BaseModel


logger = get_logger(__name__)

_DEADLOCK_MAX_RETRIES = 3
_DEADLOCK_BASE_DELAY = 0.1
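# NOTE: deadlocked operations retry with exponential backoff. With the values
# above, an operation is attempted 3 times in total, sleeping 0.1s and then
# 0.2s between attempts before the error is surfaced to the caller.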


def _is_deadlock_error(exc: Exception) -> bool:
    """Check if an exception is a database deadlock error (PostgreSQL error code 40P01)."""
    orig = getattr(exc, "orig", exc)
    if isinstance(orig, DeadlockDetectedError):
        return True
    if getattr(orig, "pgcode", None) == "40P01":
        return True
    if hasattr(orig, "args") and orig.args and isinstance(orig.args[0], dict):
        if orig.args[0].get("C") == "40P01":
            return True
    return False


def handle_db_timeout(func):
    """Decorator to handle database timeout errors and wrap them in a custom exception.

    Catches both SQLAlchemy TimeoutError (pool/connection timeout) and asyncpg's
    QueryCanceledError (PostgreSQL statement_timeout triggered).
    """
    if not inspect.iscoroutinefunction(func):

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except TimeoutError as e:
                logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
                raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
            except QueryCanceledError as e:
                logger.error(
                    f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
                )
                raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)

        return wrapper
    else:

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except TimeoutError as e:
                logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
                raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
            except QueryCanceledError as e:
                logger.error(
                    f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
                )
                raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)

        return async_wrapper
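

# A minimal usage sketch (illustrative only; `fetch_agent` and `Agent` are
# hypothetical names, not part of this module): any session-touching function
# or coroutine can be wrapped so pool timeouts and statement timeouts surface
# as DatabaseTimeoutError rather than raw driver errors.
#
#     @handle_db_timeout
#     async def fetch_agent(db_session: AsyncSession, agent_id: str) -> "Agent":
#         return await db_session.get(Agent, agent_id)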


def is_postgresql_session(session: Session) -> bool:
    """Check if the database session is PostgreSQL instead of SQLite, for setting query options."""
    return session.bind.dialect.name == "postgresql"


class AccessType(str, Enum):
    ORGANIZATION = "organization"
    USER = "user"


class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
    __abstract__ = True

    __order_by_default__ = "created_at"

    id: Mapped[str] = mapped_column(String, primary_key=True)

    @classmethod
    @handle_db_timeout
    async def list_async(
        cls,
        *,
        db_session: "AsyncSession",
        before: Optional[str] = None,
        after: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: Optional[int] = 50,
        query_text: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        ascending: bool = True,
        actor: Optional["User"] = None,  # noqa: F821
        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
        access_type: AccessType = AccessType.ORGANIZATION,
        join_model: Optional[Base] = None,
        join_conditions: Optional[Union[Tuple, List]] = None,
        identifier_keys: Optional[List[str]] = None,
        identity_id: Optional[str] = None,
        query_options: Sequence[ORMOption] | None = None,
        has_feedback: Optional[bool] = None,
        **kwargs,
    ) -> List["SqlalchemyBase"]:
        """
        Async version of the list method; keep the two in sync.

        List records with before/after keyset pagination, ordered by created_at.
        Both before and after may be given to fetch a window of records.

        Args:
            db_session: SQLAlchemy session
            before: ID of item to paginate before (upper bound)
            after: ID of item to paginate after (lower bound)
            start_date: Filter items created after this date
            end_date: Filter items created before this date
            limit: Maximum number of items to return
            query_text: Text to search for
            query_embedding: Vector to search for similar embeddings
            ascending: Sort direction
            **kwargs: Additional filters to apply
        """
        if start_date and end_date and start_date > end_date:
            raise ValueError("start_date must be earlier than or equal to end_date")

        logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")

        # Get the reference objects for pagination
        before_obj = None
        after_obj = None

        if before:
            before_obj = await db_session.get(cls, before)
            if not before_obj:
                raise NoResultFound(f"No {cls.__name__} found with id {before}")

        if after:
            after_obj = await db_session.get(cls, after)
            if not after_obj:
                raise NoResultFound(f"No {cls.__name__} found with id {after}")

        # Validate that 'before' comes after the 'after' object if both are provided
        if before_obj and after_obj and before_obj.created_at < after_obj.created_at:
            raise ValueError("'before' reference must be later than 'after' reference")

        query = cls._list_preprocess(
            before_obj=before_obj,
            after_obj=after_obj,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            query_text=query_text,
            query_embedding=query_embedding,
            ascending=ascending,
            actor=actor,
            access=access,
            access_type=access_type,
            join_model=join_model,
            join_conditions=join_conditions,
            identifier_keys=identifier_keys,
            identity_id=identity_id,
            has_feedback=has_feedback,
            **kwargs,
        )
        if query_options:
            for opt in query_options:
                query = query.options(opt)

        # Execute the query
        results = await db_session.execute(query)

        results = list(results.scalars())
        results = cls._list_postprocess(
            before=before,
            after=after,
            limit=limit,
            results=results,
        )

        return results
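
    # A minimal usage sketch (illustrative; `Agent`, `session`, `current_user`,
    # and the ID value are hypothetical): fetch the next page of 20 records
    # created after a known row.
    #
    #     page = await Agent.list_async(
    #         db_session=session,
    #         actor=current_user,
    #         after="agent-123",
    #         limit=20,
    #     )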

    @classmethod
    def _list_preprocess(
        cls,
        *,
        before_obj,
        after_obj,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: Optional[int] = 50,
        query_text: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        ascending: bool = True,
        actor: Optional["User"] = None,  # noqa: F821
        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
        access_type: AccessType = AccessType.ORGANIZATION,
        join_model: Optional[Base] = None,
        join_conditions: Optional[Union[Tuple, List]] = None,
        identifier_keys: Optional[List[str]] = None,
        identity_id: Optional[str] = None,
        check_is_deleted: bool = False,
        has_feedback: Optional[bool] = None,
        **kwargs,
    ):
        """
        Constructs the query for listing records.
        """
        # Security check: if the model has an organization_id column, an actor should be provided
        if actor is None and hasattr(cls, "organization_id"):
            logger.warning(f"SECURITY: Listing org-scoped model {cls.__name__} without actor. This bypasses organization filtering.")

        query = select(cls)

        if join_model and join_conditions:
            query = query.join(join_model, and_(*join_conditions))

        # Apply access predicate if actor is provided
        if actor:
            query = cls.apply_access_predicate(query, actor, access, access_type)

        if identifier_keys and hasattr(cls, "identities"):
            query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))

        # Given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
        if identity_id and hasattr(cls, "identities"):
            query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)

        # Apply filtering logic from kwargs
        # 1 part: <column> // 2 parts: <table>.<column> OR <column>.<json_key> // 3 parts: <table>.<column>.<json_key>
        # TODO (cliandy): can make this more robust down the line
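        # For example (hypothetical filters): kwargs={"name": "x"} filters a plain
        # column; kwargs={"metadata_.tag": "y"} filters a JSON key on this table;
        # kwargs={"tools.metadata_.tag": "z"} filters a JSON key on a joined table
        # whose model is visible in the enclosing scope.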
        for key, value in kwargs.items():
            parts = key.split(".")
            if len(parts) == 1:
                column = getattr(cls, key)
            elif len(parts) == 2:
                if locals().get(parts[0]) or globals().get(parts[0]):
                    # It's a joined table column
                    joined_table = locals().get(parts[0]) or globals().get(parts[0])
                    column = getattr(joined_table, parts[1])
                else:
                    # It's a JSON field on the main table
                    column = getattr(cls, parts[0])
                    column = column.op("->>")(parts[1])
            elif len(parts) == 3:
                table_name, column_name, json_key = parts
                joined_table = locals().get(table_name) or globals().get(table_name)
                column = getattr(joined_table, column_name)
                column = column.op("->>")(json_key)
            else:
                raise ValueError(f"Unhandled column name {key}")

            if isinstance(value, (list, tuple, set)):
                query = query.where(column.in_(value))
            else:
                query = query.where(column == value)

        # Date range filtering
        if start_date:
            query = query.filter(cls.created_at > start_date)
        if end_date:
            query = query.filter(cls.created_at < end_date)

        # Feedback filtering
        if has_feedback is not None and hasattr(cls, "feedback"):
            if has_feedback:
                query = query.filter(cls.feedback.isnot(None))
            else:
                query = query.filter(cls.feedback.is_(None))

        # Handle pagination based on before/after
        if before_obj or after_obj:
            conditions = []

            if before_obj and after_obj:
                # Window-based query - get records between before and after
                # Skip pagination if either object has a null created_at
                if before_obj.created_at is not None and after_obj.created_at is not None:
                    conditions.append(
                        or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id))
                    )
                    conditions.append(
                        or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id))
                    )
                else:
                    logger.warning(
                        f"Skipping pagination: before_obj.created_at={before_obj.created_at}, after_obj.created_at={after_obj.created_at}"
                    )
            else:
                # Pure pagination query
                if before_obj:
                    if before_obj.created_at is not None:
                        conditions.append(
                            or_(
                                cls.created_at < before_obj.created_at if ascending else cls.created_at > before_obj.created_at,
                                and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
                            )
                        )
                    else:
                        logger.warning(f"Skipping 'before' pagination: before_obj.created_at is None (id={before_obj.id})")
                if after_obj:
                    if after_obj.created_at is not None:
                        conditions.append(
                            or_(
                                cls.created_at > after_obj.created_at if ascending else cls.created_at < after_obj.created_at,
                                and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
                            )
                        )
                    else:
                        logger.warning(f"Skipping 'after' pagination: after_obj.created_at is None (id={after_obj.id})")

            if conditions:
                query = query.where(and_(*conditions))

        # Text search
        if query_text:
            if hasattr(cls, "text"):
                query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
            elif hasattr(cls, "name"):
                # Special case for the Agent model - search across name
                query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))

        # Embedding search (for Passages)
        is_ordered = False
        if query_embedding:
            if not hasattr(cls, "embedding"):
                raise ValueError(f"Class {cls.__name__} does not have an embedding column")

            from letta.settings import settings

            if settings.database_engine is DatabaseChoice.POSTGRES:
                # PostgreSQL with pgvector
                query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
            else:
                # SQLite with custom vector type
                from letta.orm.sqlite_functions import adapt_array

                query_embedding_binary = adapt_array(query_embedding)
                query = query.order_by(
                    func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
                    cls.created_at.asc() if ascending else cls.created_at.desc(),
                    cls.id.asc(),
                )
            is_ordered = True

        # Handle soft deletes
        if check_is_deleted and hasattr(cls, "is_deleted"):
            query = query.where(cls.is_deleted == False)

        # Apply ordering
        if not is_ordered:
            if ascending:
                query = query.order_by(cls.created_at.asc(), cls.id.asc())
            else:
                query = query.order_by(cls.created_at.desc(), cls.id.desc())

        # Apply limit, adjusting for both bounds if necessary
        if before_obj and after_obj:
            # When both bounds are provided, we need to fetch enough records to satisfy
            # the limit while respecting both bounds. We'll fetch more and then trim.
            query = query.limit(limit * 2)
        else:
            query = query.limit(limit)
        return query

    @classmethod
    def _list_postprocess(
        cls,
        before: str | None,
        after: str | None,
        limit: int | None,
        results: list,
    ):
        """Trim an over-fetched window query down to the requested limit."""
        # If we have both bounds, take the middle portion
        if before and after and len(results) > limit:
            middle = len(results) // 2
            start = max(0, middle - limit // 2)
            end = min(len(results), start + limit)
            results = results[start:end]
        return results
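
    # Worked example of the window trim above: with limit=10 and 16 rows fetched
    # (the limit * 2 over-fetch), middle=8, start=max(0, 8 - 5)=3, end=min(16, 13)=13,
    # so rows 3..12 - the 10 records centered in the window - are returned.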

    @classmethod
    @handle_db_timeout
    async def read_async(
        cls,
        db_session: "AsyncSession",
        identifier: Optional[str] = None,
        actor: Optional["User"] = None,  # noqa: F821
        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
        access_type: AccessType = AccessType.ORGANIZATION,
        check_is_deleted: bool = False,
        **kwargs,
    ) -> "SqlalchemyBase":
        """The primary accessor for an ORM record. Async version of the read method.

        Args:
            db_session: the database session to use when retrieving the record
            identifier: the identifier of the record to read; can be the id string or a UUID object for backwards compatibility
            actor: if specified, results will be scoped only to records the user is able to access
            access: if actor is specified, records will be filtered to the minimum permission level for the actor
            kwargs: additional arguments to pass to the read, used for more complex objects

        Returns:
            The matching object

        Raises:
            NoResultFound: if the object is not found
        """
        identifiers = [] if identifier is None else [identifier]
        query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, check_is_deleted, **kwargs)
        if query is None:
            raise NoResultFound(f"{cls.__name__} not found with identifier {identifier}")

        result = await db_session.execute(query)
        item = result.scalar_one_or_none()

        if item is None:
            raise NoResultFound(f"{cls.__name__} not found with {', '.join(query_conditions if query_conditions else ['no conditions'])}")
        return item

    @classmethod
    @handle_db_timeout
    async def read_multiple_async(
        cls,
        db_session: "AsyncSession",
        identifiers: List[str] = [],
        actor: Optional["User"] = None,  # noqa: F821
        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
        access_type: AccessType = AccessType.ORGANIZATION,
        check_is_deleted: bool = False,
        **kwargs,
    ) -> List["SqlalchemyBase"]:
        """
        Async version of read_multiple(...)
        The primary accessor for ORM record(s).
        """
        query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, check_is_deleted, **kwargs)
        if query is None:
            return []
        results = await db_session.execute(query)
        return cls._read_multiple_postprocess(results.scalars().all(), identifiers, query_conditions)

    @classmethod
    def _read_multiple_preprocess(
        cls,
        identifiers: List[str],
        actor: Optional["User"],  # noqa: F821
        access: Optional[List[Literal["read", "write", "admin"]]],
        access_type: AccessType,
        check_is_deleted: bool,
        **kwargs,
    ):
        logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")

        # Security check: if the model has an organization_id column, an actor should be
        # provided to ensure proper org-scoping. Log a warning if actor is None.
        if actor is None and hasattr(cls, "organization_id"):
            logger.warning(
                f"SECURITY: Reading org-scoped model {cls.__name__} without actor. "
                f"IDs: {identifiers}. This bypasses organization filtering."
            )

        # Start the query
        query = select(cls)
        # Collect query conditions for better error reporting
        query_conditions = []

        # If identifiers are provided, add them to the query conditions
        if identifiers:
            if len(identifiers) == 1:
                query = query.where(cls.id == identifiers[0])
            else:
                query = query.where(cls.id.in_(identifiers))
            query_conditions.append(f"id='{identifiers}'")
        elif not kwargs:
            logger.debug(f"No identifiers provided for {cls.__name__}, returning empty list")
            return None, query_conditions

        if kwargs:
            query = query.filter_by(**kwargs)
            query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items()))

        if actor:
            query = cls.apply_access_predicate(query, actor, access, access_type)
            query_conditions.append(f"access level in {access} for actor='{actor}'")

        if check_is_deleted and hasattr(cls, "is_deleted"):
            query = query.where(cls.is_deleted == False)
            query_conditions.append("is_deleted=False")

        return query, query_conditions

    @classmethod
    def _read_multiple_postprocess(cls, results, identifiers: List[str], query_conditions) -> List["SqlalchemyBase"]:
        if results:
            if len(identifiers) > 0:
                # Find which queried identifiers were not found. This only applies when
                # identifiers were actually used in the query.
                identifier_set = set(identifiers)
                results_set = set(map(lambda obj: obj.id, results))

                # Log a message if any of the queried IDs were not found.
                # TODO: should we error out instead?
                if identifier_set != results_set:
                    # Construct a detailed message based on query conditions
                    conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
                    logger.debug(f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}")
            return results

        # Construct a detailed message based on query conditions
        conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
        logger.debug(f"{cls.__name__} not found with {conditions_str}")
        return []

    @handle_db_timeout
    async def create_async(
        self,
        db_session: "AsyncSession",
        actor: Optional["User"] = None,  # noqa: F821
        no_commit: bool = False,
        no_refresh: bool = False,
        ignore_conflicts: bool = False,
    ) -> Optional["SqlalchemyBase"]:
        """Async version of the create function.

        Args:
            ignore_conflicts: If True, uses INSERT ... ON CONFLICT DO NOTHING and returns
                None if a conflict occurred (no exception raised).
        """
        logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")

        if actor:
            self._set_created_and_updated_by_fields(actor.id)

        if ignore_conflicts:
            values = {
                col.name: getattr(self, col.key)
                for col in self.__table__.columns
                if not (getattr(self, col.key) is None and col.server_default is not None)
            }
            stmt = pg_insert(self.__table__).values(**values).on_conflict_do_nothing()
            result = await db_session.execute(stmt)
            if not no_commit:
                await db_session.commit()
            return self if result.rowcount > 0 else None

        for attempt in range(_DEADLOCK_MAX_RETRIES):
            try:
                db_session.add(self)
                if no_commit:
                    await db_session.flush()
                else:
                    await db_session.commit()

                if not no_refresh:
                    await db_session.refresh(self)
                return self
            except (DBAPIError, IntegrityError) as e:
                if _is_deadlock_error(e) and attempt < _DEADLOCK_MAX_RETRIES - 1:
                    logger.warning(
                        f"Deadlock detected in {self.__class__.__name__}.create_async "
                        f"(attempt {attempt + 1}/{_DEADLOCK_MAX_RETRIES}), retrying..."
                    )
                    await db_session.rollback()
                    await asyncio.sleep(_DEADLOCK_BASE_DELAY * (2**attempt))
                    continue
                self._handle_dbapi_error(e)
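
    # A minimal usage sketch for the conflict-tolerant path (`block`, `session`,
    # and `user` are illustrative names). Note that ignore_conflicts builds a
    # PostgreSQL-dialect INSERT ... ON CONFLICT statement, so it assumes a
    # Postgres backend.
    #
    #     created = await block.create_async(db_session=session, actor=user, ignore_conflicts=True)
    #     if created is None:
    #         ...  # a row with the same primary/unique key already existed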

    @classmethod
    @handle_db_timeout
    async def batch_create_async(
        cls,
        items: List["SqlalchemyBase"],
        db_session: "AsyncSession",
        actor: Optional["User"] = None,  # noqa: F821
        no_commit: bool = False,
        no_refresh: bool = False,
    ) -> List["SqlalchemyBase"]:
        """
        Async version of the batch_create method.
        Create multiple records in a single transaction for better performance.

        Args:
            items: List of model instances to create
            db_session: AsyncSession session
            actor: Optional user performing the action
            no_commit: Whether to skip committing the transaction (flush only)
            no_refresh: Whether to skip re-fetching the created objects

        Returns:
            List of created model instances
        """
        logger.debug(f"Async batch creating {len(items)} {cls.__name__} items with actor={actor}")

        if not items:
            return []

        if actor:
            for item in items:
                item._set_created_and_updated_by_fields(actor.id)

        for attempt in range(_DEADLOCK_MAX_RETRIES):
            try:
                db_session.add_all(items)
                if no_commit:
                    await db_session.flush()
                else:
                    await db_session.commit()

                if no_refresh:
                    return items
                else:
                    item_ids = [item.id for item in items]
                    query = select(cls).where(cls.id.in_(item_ids))
                    if hasattr(cls, "created_at"):
                        query = query.order_by(cls.created_at)

                    result = await db_session.execute(query)
                    return list(result.scalars())
            except (DBAPIError, IntegrityError) as e:
                if _is_deadlock_error(e) and attempt < _DEADLOCK_MAX_RETRIES - 1:
                    logger.warning(
                        f"Deadlock detected in {cls.__name__}.batch_create_async "
                        f"(attempt {attempt + 1}/{_DEADLOCK_MAX_RETRIES}), retrying..."
                    )
                    await db_session.rollback()
                    await asyncio.sleep(_DEADLOCK_BASE_DELAY * (2**attempt))
                    continue
                cls._handle_dbapi_error(e)

    @handle_db_timeout
    async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase":  # noqa: F821
        """Soft delete a record asynchronously (mark as deleted)."""
        logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")

        if actor:
            self._set_created_and_updated_by_fields(actor.id)

        self.is_deleted = True
        return await self.update_async(db_session)

    @handle_db_timeout
    async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None:  # noqa: F821
        """Permanently removes the record from the database asynchronously."""
        obj_id = self.id
        obj_class = self.__class__.__name__
        logger.debug(f"Hard deleting {obj_class} with ID: {obj_id} with actor={actor} (async)")

        for attempt in range(_DEADLOCK_MAX_RETRIES):
            try:
                await db_session.delete(self)
                await db_session.commit()
                return
            except Exception as e:
                if _is_deadlock_error(e) and attempt < _DEADLOCK_MAX_RETRIES - 1:
                    logger.warning(
                        f"Deadlock detected in {obj_class}.hard_delete_async (attempt {attempt + 1}/{_DEADLOCK_MAX_RETRIES}), retrying..."
                    )
                    await db_session.rollback()
                    await asyncio.sleep(_DEADLOCK_BASE_DELAY * (2**attempt))
                    continue
                await db_session.rollback()
                logger.exception(f"Failed to hard delete {obj_class} with ID {obj_id}")
                raise ValueError(f"Failed to hard delete {obj_class} with ID {obj_id}: {e}")

    @classmethod
    @handle_db_timeout
    async def bulk_hard_delete_async(
        cls,
        db_session: "AsyncSession",
        identifiers: List[str],
        actor: Optional["User"],  # noqa: F821
        access: Optional[List[Literal["read", "write", "admin"]]] = ["write"],
        access_type: AccessType = AccessType.ORGANIZATION,
    ) -> None:
        """Permanently removes the records from the database asynchronously."""
        logger.debug(f"Hard deleting {cls.__name__} with IDs: {identifiers} with actor={actor} (async)")

        if len(identifiers) == 0:
            logger.debug(f"No identifiers provided for {cls.__name__}, nothing to delete")
            return

        for attempt in range(_DEADLOCK_MAX_RETRIES):
            query = delete(cls)
            query = query.where(cls.id.in_(identifiers))
            query = cls.apply_access_predicate(query, actor, access, access_type)
            try:
                result = await db_session.execute(query)
                await db_session.commit()
                logger.debug(f"Successfully deleted {result.rowcount} {cls.__name__} records")
                return
            except Exception as e:
                if _is_deadlock_error(e) and attempt < _DEADLOCK_MAX_RETRIES - 1:
                    logger.warning(
                        f"Deadlock detected in {cls.__name__}.bulk_hard_delete_async "
                        f"(attempt {attempt + 1}/{_DEADLOCK_MAX_RETRIES}), retrying..."
                    )
                    await db_session.rollback()
                    await asyncio.sleep(_DEADLOCK_BASE_DELAY * (2**attempt))
                    continue
                await db_session.rollback()
                logger.exception(f"Failed to hard delete {cls.__name__} with identifiers {identifiers}")
                raise ValueError(f"Failed to hard delete {cls.__name__} with identifiers {identifiers}: {e}")

    @handle_db_timeout
    async def update_async(
        self,
        db_session: "AsyncSession",
        actor: Optional["User"] = None,  # noqa: F821
        no_commit: bool = False,
        no_refresh: bool = False,
    ) -> "SqlalchemyBase":
        """Async version of the update function."""
        logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")

        if actor:
            self._set_created_and_updated_by_fields(actor.id)
        self.set_updated_at()

        object_id = self.id
        class_name = self.__class__.__name__

        # Snapshot column values before commit so they survive rollback's expire-on-rollback behavior
        _col_snapshot = {c.key: self.__dict__[c.key] for c in self.__class__.__table__.columns if c.key in self.__dict__}

        for attempt in range(_DEADLOCK_MAX_RETRIES):
            try:
                db_session.add(self)
                if no_commit:
                    await db_session.flush()
                else:
                    await db_session.commit()

                if not no_refresh:
                    await db_session.refresh(self)
                return self
            except StaleDataError as e:
                raise ConcurrentUpdateError(resource_type=class_name, resource_id=object_id) from e
            except (DBAPIError, IntegrityError) as e:
                if _is_deadlock_error(e) and attempt < _DEADLOCK_MAX_RETRIES - 1:
                    logger.warning(
                        f"Deadlock detected in {class_name}.update_async (attempt {attempt + 1}/{_DEADLOCK_MAX_RETRIES}), retrying..."
                    )
                    await db_session.rollback()
                    # Restore the snapshotted values onto the (expired) instance before retrying
                    for key, value in _col_snapshot.items():
                        setattr(self, key, value)
                    await asyncio.sleep(_DEADLOCK_BASE_DELAY * (2**attempt))
                    continue
                self._handle_dbapi_error(e)

    @classmethod
    def _size_preprocess(
        cls,
        *,
        db_session: "Session",
        actor: Optional["User"] = None,  # noqa: F821
        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
        access_type: AccessType = AccessType.ORGANIZATION,
        check_is_deleted: bool = False,
        **kwargs,
    ):
        logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")

        # Security check: if the model has an organization_id column, an actor should be provided
        if actor is None and hasattr(cls, "organization_id"):
            logger.warning(
                f"SECURITY: Calculating size for org-scoped model {cls.__name__} without actor. This bypasses organization filtering."
            )
        query = select(func.count(1)).select_from(cls)

        if actor:
            query = cls.apply_access_predicate(query, actor, access, access_type)

        # Apply filtering logic based on kwargs
        for key, value in kwargs.items():
            if value:
                column = getattr(cls, key, None)
                if not column:
                    raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
                if isinstance(value, (list, tuple, set)):  # Check for iterables
                    query = query.where(column.in_(value))
                else:  # Single value for equality filtering
                    query = query.where(column == value)

        if check_is_deleted and hasattr(cls, "is_deleted"):
            query = query.where(cls.is_deleted == False)

        return query

    @classmethod
    @handle_db_timeout
    async def size_async(
        cls,
        *,
        db_session: "AsyncSession",
        actor: Optional["User"] = None,  # noqa: F821
        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
        access_type: AccessType = AccessType.ORGANIZATION,
        check_is_deleted: bool = False,
        **kwargs,
    ) -> int:
        """
        Get the count of rows that match the provided filters.

        Args:
            db_session: SQLAlchemy session
            **kwargs: Filters to apply to the query (e.g., column_name=value)

        Returns:
            int: The count of rows that match the filters

        Raises:
            DBAPIError: If a database error occurs
        """
        query = cls._size_preprocess(
            db_session=db_session,
            actor=actor,
            access=access,
            access_type=access_type,
            check_is_deleted=check_is_deleted,
            **kwargs,
        )

        try:
            result = await db_session.execute(query)
            count = result.scalar()
            return count if count else 0
        except DBAPIError as e:
            logger.exception(f"Failed to calculate size for {cls.__name__}")
            raise e

    @classmethod
    def apply_access_predicate(
        cls,
        query: "Select",  # noqa: F821
        actor: "User",  # noqa: F821
        access: List[Literal["read", "write", "admin"]],
        access_type: AccessType = AccessType.ORGANIZATION,
    ) -> "Select":  # noqa: F821
        """Applies a WHERE clause restricting results to the given actor and access level.

        Args:
            query: The initial sqlalchemy select statement
            actor: The user acting on the query. **Note**: this is called 'actor' to identify the
                person or system acting; users can act on users, which would otherwise make the naming sticky.
            access: The mode of access the query should be restricted to. This will be used with granular
                permissions, but because it impacts every query we want to be explicit about access ahead of time.

        Returns:
            The sqlalchemy select statement restricted to the given access.
        """
        del access  # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
        if access_type == AccessType.ORGANIZATION:
            org_id = getattr(actor, "organization_id", None)
            if not org_id:
                raise ValueError(f"object {actor} has no organization accessor")
            return query.where(cls.organization_id == org_id)
        elif access_type == AccessType.USER:
            user_id = getattr(actor, "id", None)
            if not user_id:
                raise ValueError(f"object {actor} has no user accessor")
            return query.where(cls.user_id == user_id)
        else:
            raise ValueError(f"unknown access_type: {access_type}")
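
    # A minimal sketch of the scoping behavior (`Tool` and `user` are
    # illustrative): with the default AccessType.ORGANIZATION, the predicate
    # reduces to a plain organization filter, e.g.
    #
    #     query = Tool.apply_access_predicate(select(Tool), actor=user, access=["read"])
    #     # roughly: SELECT ... FROM tools WHERE tools.organization_id = :user_org_id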

    @classmethod
    def _handle_dbapi_error(cls, e: DBAPIError):
        """Handle database errors and raise appropriate custom exceptions."""
        orig = e.orig  # Extract the original error from the DBAPIError
        error_code = None
        error_message = str(orig) if orig else str(e)
        logger.info(f"Handling DBAPIError: {error_message}")

        # Handle asyncpg QueryCanceledError (wrapped in DBAPIError).
        # This occurs when PostgreSQL's statement_timeout kills a long-running query.
        if isinstance(orig, QueryCanceledError):
            logger.error(f"Query canceled (statement timeout) for {cls.__name__}: {e}")
            raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout for {cls.__name__}.", original_exception=e) from e

        if isinstance(orig, DeadlockDetectedError):
            logger.error(f"Deadlock detected for {cls.__name__}: {e}")
            raise DatabaseDeadlockError(message=f"A database deadlock was detected for {cls.__name__}.", original_exception=e) from e

        # Handle asyncpg LockNotAvailableError (wrapped in DBAPIError).
        # This occurs when a SELECT ... FOR UPDATE NOWAIT or similar fails to acquire a lock.
        if isinstance(orig, AsyncpgLockNotAvailableError):
            logger.warning(f"Lock not available for {cls.__name__}: {e}")
            raise DatabaseLockNotAvailableError(
                message=f"Could not acquire lock for {cls.__name__}. Another operation is in progress.", original_exception=e
            ) from e

        # Handle SQLite-specific errors
        if "UNIQUE constraint failed" in error_message:
            raise UniqueConstraintViolationError(
                f"A unique constraint was violated for {cls.__name__}. Check your input for duplicates: {e}"
            ) from e

        if "FOREIGN KEY constraint failed" in error_message:
            raise ForeignKeyConstraintViolationError(
                f"A foreign key constraint was violated for {cls.__name__}. Check your input for missing or invalid references: {e}"
            ) from e

        # For psycopg2
        if hasattr(orig, "pgcode"):
            error_code = orig.pgcode
        # For pg8000
        elif hasattr(orig, "args") and len(orig.args) > 0:
            # The first argument contains the error details as a dictionary
            err_dict = orig.args[0]
            if isinstance(err_dict, dict):
                error_code = err_dict.get("C")  # 'C' is the error code field
        logger.info(f"Extracted error_code: {error_code}")

        # Handle unique constraint violations
        if error_code == "23505":
            raise UniqueConstraintViolationError(
                f"A unique constraint was violated for {cls.__name__}. Check your input for duplicates: {e}"
            ) from e

        # Handle foreign key violations
        if error_code == "23503":
            raise ForeignKeyConstraintViolationError(
                f"A foreign key constraint was violated for {cls.__name__}. Check your input for missing or invalid references: {e}"
            ) from e

        # Handle deadlock detected
        if error_code == "40P01":
            logger.error(f"Deadlock detected for {cls.__name__}: {e}")
            raise DatabaseDeadlockError(message=f"A database deadlock was detected for {cls.__name__}.", original_exception=e) from e

        # Handle lock not available (e.g. NOWAIT or lock_timeout exceeded)
        if error_code == "55P03":
            logger.warning(f"Lock not available for {cls.__name__}: {e}")
            raise DatabaseLockNotAvailableError(
                message=f"Could not acquire lock for {cls.__name__}. Another operation is in progress.", original_exception=e
            ) from e

        # Re-raise any other unhandled DBAPI errors
        raise
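
    # Summary of the SQLSTATE mapping above:
    #   23505 -> UniqueConstraintViolationError
    #   23503 -> ForeignKeyConstraintViolationError
    #   40P01 -> DatabaseDeadlockError
    #   55P03 -> DatabaseLockNotAvailableError
    # Anything else is re-raised unchanged.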

    @property
    def __pydantic_model__(self) -> "BaseModel":
        raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertible.")

    def to_pydantic(self) -> "BaseModel":
        """Converts the SQLAlchemy model to its corresponding Pydantic model."""
        model = self.__pydantic_model__.model_validate(self, from_attributes=True)

        # Explicitly map metadata_ to metadata in the Pydantic model
        if hasattr(self, "metadata_") and hasattr(model, "metadata_"):
            setattr(model, "metadata_", self.metadata_)  # Ensures correct assignment

        return model

    def pretty_print_columns(self) -> str:
        """
        Pretty prints all columns of the current SQLAlchemy object along with their values.
        """
        if not hasattr(self, "__table__") or not hasattr(self.__table__, "columns"):
            raise NotImplementedError("This object does not have a '__table__.columns' attribute.")

        # Iterate over the columns and collect their current values
        column_data = {column.name: getattr(self, column.name, None) for column in self.__table__.columns}

        return pformat(column_data, indent=4, sort_dicts=True)