feat: add no_refresh flag to sqlalchemy helpers (#3377)
This commit is contained in:
@@ -656,7 +656,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
self._handle_dbapi_error(e)
|
||||
|
||||
@handle_db_timeout
|
||||
async def create_async(self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
|
||||
async def create_async(
|
||||
self,
|
||||
db_session: "AsyncSession",
|
||||
actor: Optional["User"] = None,
|
||||
no_commit: bool = False,
|
||||
no_refresh: bool = False,
|
||||
) -> "SqlalchemyBase":
|
||||
"""Async version of create function"""
|
||||
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
|
||||
@@ -668,7 +674,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
await db_session.flush() # no commit, just flush to get PK
|
||||
else:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(self)
|
||||
|
||||
if not no_refresh:
|
||||
await db_session.refresh(self)
|
||||
return self
|
||||
except (DBAPIError, IntegrityError) as e:
|
||||
self._handle_dbapi_error(e)
|
||||
@@ -717,7 +725,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
async def batch_create_async(
|
||||
cls, items: List["SqlalchemyBase"], db_session: "AsyncSession", actor: Optional["User"] = None
|
||||
cls,
|
||||
items: List["SqlalchemyBase"],
|
||||
db_session: "AsyncSession",
|
||||
actor: Optional["User"] = None,
|
||||
no_commit: bool = False,
|
||||
no_refresh: bool = False,
|
||||
) -> List["SqlalchemyBase"]:
|
||||
"""
|
||||
Async version of batch_create method.
|
||||
@@ -726,10 +739,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
items: List of model instances to create
|
||||
db_session: AsyncSession session
|
||||
actor: Optional user performing the action
|
||||
no_commit: Whether to commit the transaction
|
||||
no_refresh: Whether to refresh the created objects
|
||||
Returns:
|
||||
List of created model instances
|
||||
"""
|
||||
logger.debug(f"Async batch creating {len(items)} {cls.__name__} items with actor={actor}")
|
||||
|
||||
if not items:
|
||||
return []
|
||||
|
||||
@@ -740,21 +756,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
try:
|
||||
db_session.add_all(items)
|
||||
await db_session.flush() # Flush to generate IDs but don't commit yet
|
||||
if no_commit:
|
||||
await db_session.flush()
|
||||
else:
|
||||
await db_session.commit()
|
||||
|
||||
# Collect IDs to fetch the complete objects after commit
|
||||
item_ids = [item.id for item in items]
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
# Re-query the objects to get them with relationships loaded
|
||||
query = select(cls).where(cls.id.in_(item_ids))
|
||||
if hasattr(cls, "created_at"):
|
||||
query = query.order_by(cls.created_at)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
return list(result.scalars())
|
||||
if no_refresh:
|
||||
return items
|
||||
else:
|
||||
# Re-query the objects to get them with relationships loaded
|
||||
item_ids = [item.id for item in items]
|
||||
query = select(cls).where(cls.id.in_(item_ids))
|
||||
if hasattr(cls, "created_at"):
|
||||
query = query.order_by(cls.created_at)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
return list(result.scalars())
|
||||
except (DBAPIError, IntegrityError) as e:
|
||||
cls._handle_dbapi_error(e)
|
||||
|
||||
@@ -854,20 +871,27 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
return self
|
||||
|
||||
@handle_db_timeout
|
||||
async def update_async(self, db_session: AsyncSession, actor: "User | None" = None, no_commit: bool = False) -> "SqlalchemyBase":
|
||||
async def update_async(
|
||||
self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False, no_refresh: bool = False
|
||||
) -> "SqlalchemyBase":
|
||||
"""Async version of update function"""
|
||||
logger.debug(...)
|
||||
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
|
||||
if actor:
|
||||
self._set_created_and_updated_by_fields(actor.id)
|
||||
self.set_updated_at()
|
||||
try:
|
||||
db_session.add(self)
|
||||
if no_commit:
|
||||
await db_session.flush()
|
||||
else:
|
||||
await db_session.commit()
|
||||
|
||||
db_session.add(self)
|
||||
if no_commit:
|
||||
await db_session.flush()
|
||||
else:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(self)
|
||||
return self
|
||||
if not no_refresh:
|
||||
await db_session.refresh(self)
|
||||
return self
|
||||
except (DBAPIError, IntegrityError) as e:
|
||||
self._handle_dbapi_error(e)
|
||||
|
||||
@classmethod
|
||||
def _size_preprocess(
|
||||
|
||||
Reference in New Issue
Block a user