feat: add no_refresh flag to sqlalchemy helpers (#3377)

This commit is contained in:
cthomas
2025-07-17 14:20:50 -07:00
committed by GitHub
parent 7939d2c366
commit f670bd7619

View File

@@ -656,7 +656,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
self._handle_dbapi_error(e)
@handle_db_timeout
async def create_async(self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
async def create_async(
self,
db_session: "AsyncSession",
actor: Optional["User"] = None,
no_commit: bool = False,
no_refresh: bool = False,
) -> "SqlalchemyBase":
"""Async version of create function"""
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -668,7 +674,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
await db_session.flush() # no commit, just flush to get PK
else:
await db_session.commit()
await db_session.refresh(self)
if not no_refresh:
await db_session.refresh(self)
return self
except (DBAPIError, IntegrityError) as e:
self._handle_dbapi_error(e)
@@ -717,7 +725,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
@classmethod
@handle_db_timeout
async def batch_create_async(
cls, items: List["SqlalchemyBase"], db_session: "AsyncSession", actor: Optional["User"] = None
cls,
items: List["SqlalchemyBase"],
db_session: "AsyncSession",
actor: Optional["User"] = None,
no_commit: bool = False,
no_refresh: bool = False,
) -> List["SqlalchemyBase"]:
"""
Async version of batch_create method.
@@ -726,10 +739,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
items: List of model instances to create
db_session: AsyncSession session
actor: Optional user performing the action
no_commit: Whether to commit the transaction
no_refresh: Whether to refresh the created objects
Returns:
List of created model instances
"""
logger.debug(f"Async batch creating {len(items)} {cls.__name__} items with actor={actor}")
if not items:
return []
@@ -740,21 +756,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
try:
db_session.add_all(items)
await db_session.flush() # Flush to generate IDs but don't commit yet
if no_commit:
await db_session.flush()
else:
await db_session.commit()
# Collect IDs to fetch the complete objects after commit
item_ids = [item.id for item in items]
await db_session.commit()
# Re-query the objects to get them with relationships loaded
query = select(cls).where(cls.id.in_(item_ids))
if hasattr(cls, "created_at"):
query = query.order_by(cls.created_at)
result = await db_session.execute(query)
return list(result.scalars())
if no_refresh:
return items
else:
# Re-query the objects to get them with relationships loaded
item_ids = [item.id for item in items]
query = select(cls).where(cls.id.in_(item_ids))
if hasattr(cls, "created_at"):
query = query.order_by(cls.created_at)
result = await db_session.execute(query)
return list(result.scalars())
except (DBAPIError, IntegrityError) as e:
cls._handle_dbapi_error(e)
@@ -854,20 +871,27 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
return self
@handle_db_timeout
async def update_async(self, db_session: AsyncSession, actor: "User | None" = None, no_commit: bool = False) -> "SqlalchemyBase":
async def update_async(
self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False, no_refresh: bool = False
) -> "SqlalchemyBase":
"""Async version of update function"""
logger.debug(...)
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor:
self._set_created_and_updated_by_fields(actor.id)
self.set_updated_at()
try:
db_session.add(self)
if no_commit:
await db_session.flush()
else:
await db_session.commit()
db_session.add(self)
if no_commit:
await db_session.flush()
else:
await db_session.commit()
await db_session.refresh(self)
return self
if not no_refresh:
await db_session.refresh(self)
return self
except (DBAPIError, IntegrityError) as e:
self._handle_dbapi_error(e)
@classmethod
def _size_preprocess(