From f670bd7619090259f4155c9edba3f540f7c9b35f Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 17 Jul 2025 14:20:50 -0700 Subject: [PATCH] feat: add no_refresh flag to sqlalchemy helpers (#3377) --- letta/orm/sqlalchemy_base.py | 74 ++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 25 deletions(-) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 368d5e30..6555365f 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -656,7 +656,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): self._handle_dbapi_error(e) @handle_db_timeout - async def create_async(self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase": + async def create_async( + self, + db_session: "AsyncSession", + actor: Optional["User"] = None, + no_commit: bool = False, + no_refresh: bool = False, + ) -> "SqlalchemyBase": """Async version of create function""" logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}") @@ -668,7 +674,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): await db_session.flush() # no commit, just flush to get PK else: await db_session.commit() - await db_session.refresh(self) + + if not no_refresh: + await db_session.refresh(self) return self except (DBAPIError, IntegrityError) as e: self._handle_dbapi_error(e) @@ -717,7 +725,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): @classmethod @handle_db_timeout async def batch_create_async( - cls, items: List["SqlalchemyBase"], db_session: "AsyncSession", actor: Optional["User"] = None + cls, + items: List["SqlalchemyBase"], + db_session: "AsyncSession", + actor: Optional["User"] = None, + no_commit: bool = False, + no_refresh: bool = False, ) -> List["SqlalchemyBase"]: """ Async version of batch_create method. @@ -726,10 +739,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): items: List of model instances to create db_session: AsyncSession session actor: Optional user performing the action + no_commit: Whether to commit the transaction + no_refresh: Whether to refresh the created objects Returns: List of created model instances """ logger.debug(f"Async batch creating {len(items)} {cls.__name__} items with actor={actor}") + if not items: return [] @@ -740,21 +756,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): try: db_session.add_all(items) - await db_session.flush() # Flush to generate IDs but don't commit yet + if no_commit: + await db_session.flush() + else: + await db_session.commit() - # Collect IDs to fetch the complete objects after commit - item_ids = [item.id for item in items] - - await db_session.commit() - - # Re-query the objects to get them with relationships loaded - query = select(cls).where(cls.id.in_(item_ids)) - if hasattr(cls, "created_at"): - query = query.order_by(cls.created_at) - - result = await db_session.execute(query) - return list(result.scalars()) + if no_refresh: + return items + else: + # Re-query the objects to get them with relationships loaded + item_ids = [item.id for item in items] + query = select(cls).where(cls.id.in_(item_ids)) + if hasattr(cls, "created_at"): + query = query.order_by(cls.created_at) + result = await db_session.execute(query) + return list(result.scalars()) except (DBAPIError, IntegrityError) as e: cls._handle_dbapi_error(e) @@ -854,20 +871,27 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): return self @handle_db_timeout - async def update_async(self, db_session: AsyncSession, actor: "User | None" = None, no_commit: bool = False) -> "SqlalchemyBase": + async def update_async( + self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False, no_refresh: bool = False + ) -> "SqlalchemyBase": """Async version of update function""" - logger.debug(...) + logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}") + if actor: self._set_created_and_updated_by_fields(actor.id) self.set_updated_at() + try: + db_session.add(self) + if no_commit: + await db_session.flush() + else: + await db_session.commit() - db_session.add(self) - if no_commit: - await db_session.flush() - else: - await db_session.commit() - await db_session.refresh(self) - return self + if not no_refresh: + await db_session.refresh(self) + return self + except (DBAPIError, IntegrityError) as e: + self._handle_dbapi_error(e) @classmethod def _size_preprocess(