fix: Modify the list ORM function (#2208)

This commit is contained in:
Matthew Zhou
2024-12-09 19:35:58 -08:00
committed by GitHub
parent d61b2f9545
commit 2125421bd8
8 changed files with 83 additions and 106 deletions

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, List, Literal, Optional, Type
from sqlalchemy import String, func, select
from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -60,14 +60,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
ascending: bool = True,
**kwargs,
) -> List[Type["SqlalchemyBase"]]:
"""List records with advanced filtering and pagination options."""
"""
List records with cursor-based pagination, ordering by created_at.
Cursor is an ID, but pagination is based on the cursor object's created_at value.
"""
if start_date and end_date and start_date > end_date:
raise ValueError("start_date must be earlier than or equal to end_date")
logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
with db_session as session:
# If cursor provided, get the reference object
cursor_obj = None
if cursor:
cursor_obj = session.get(cls, cursor)
if not cursor_obj:
raise NoResultFound(f"No {cls.__name__} found with id {cursor}")
query = select(cls)
# Apply filtering logic
@@ -80,22 +91,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# Date range filtering
if start_date:
query = query.filter(cls.created_at >= start_date)
query = query.filter(cls.created_at > start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
query = query.filter(cls.created_at < end_date)
# Cursor-based pagination
if cursor:
query = query.where(cls.id > cursor)
# Cursor-based pagination using created_at
# TODO: There is a really nasty race condition issue here with Sqlite
# TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason
if cursor_obj:
if ascending:
query = query.where(cls.created_at >= cursor_obj.created_at).where(
or_(cls.created_at > cursor_obj.created_at, cls.id > cursor_obj.id)
)
else:
query = query.where(cls.created_at <= cursor_obj.created_at).where(
or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
)
# Apply text search
if query_text:
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
# Handle ordering and soft deletes
# Handle soft deletes
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
query = query.order_by(cls.id).limit(limit)
# Apply ordering by created_at
if ascending:
query = query.order_by(cls.created_at, cls.id)
else:
query = query.order_by(desc(cls.created_at), desc(cls.id))
query = query.limit(limit)
return list(session.execute(query).scalars())