fix: fix null precedence and pagination for list agents (#2927)

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-06-20 14:43:34 -07:00
committed by GitHub
parent 61cbcba472
commit 6889835bd7

View File

@@ -2,7 +2,7 @@ import datetime
from typing import List, Literal, Optional
import numpy as np
from sqlalchemy import Select, and_, asc, desc, func, literal, or_, select, union_all
from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all
from sqlalchemy.sql.expression import exists
from letta import system
@@ -430,23 +430,47 @@ def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) ->
return True
def _cursor_filter(created_at_col, id_col, ref_created_at, ref_id, forward: bool):
def _cursor_filter(sort_col, id_col, ref_sort_col, ref_id, forward: bool, nulls_last: bool = False):
"""
Returns a SQLAlchemy filter expression for cursor-based pagination.
If `forward` is True, returns records after the reference.
If `forward` is False, returns records before the reference.
Handles NULL values in the sort column properly when nulls_last is True.
"""
if forward:
return or_(
created_at_col > ref_created_at,
and_(created_at_col == ref_created_at, id_col > ref_id),
)
if not nulls_last:
# Simple case: no special NULL handling needed
if forward:
return or_(
sort_col > ref_sort_col,
and_(sort_col == ref_sort_col, id_col > ref_id),
)
else:
return or_(
sort_col < ref_sort_col,
and_(sort_col == ref_sort_col, id_col < ref_id),
)
# Handle nulls_last case
# TODO: add tests to check if this works for ascending order but nulls are stil last?
if ref_sort_col is None:
# Reference cursor is at a NULL value
if forward:
# Moving forward (e.g. previous) from NULL: either other NULLs with greater IDs or non-NULLs
return or_(and_(sort_col.is_(None), id_col > ref_id), sort_col.isnot(None))
else:
# Moving backward (e.g. next) from NULL: NULLs with smaller IDs
return and_(sort_col.is_(None), id_col < ref_id)
else:
return or_(
created_at_col < ref_created_at,
and_(created_at_col == ref_created_at, id_col < ref_id),
)
# Reference cursor is at a non-NULL value
if forward:
# Moving forward (e.g. previous) from non-NULL: only greater non-NULL values
# (NULLs are at the end, so we don't include them when moving forward from non-NULL)
return and_(sort_col.isnot(None), or_(sort_col > ref_sort_col, and_(sort_col == ref_sort_col, id_col > ref_id)))
else:
# Moving backward (e.g. next) from non-NULL: smaller non-NULL values or NULLs
return or_(sort_col.is_(None), or_(sort_col < ref_sort_col, and_(sort_col == ref_sort_col, id_col < ref_id)))
def _apply_pagination(
@@ -455,30 +479,30 @@ def _apply_pagination(
# Determine the sort column
if sort_by == "last_run_completion":
sort_column = AgentModel.last_run_completion
sort_nulls_last = True # TODO: handle this as a query param eventually
else:
sort_column = AgentModel.created_at
sort_nulls_last = False
if after:
if sort_by == "last_run_completion":
result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after)).first()
else:
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after)).first()
result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after)).first()
if result:
after_sort_value, after_id = result
query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending))
query = query.where(
_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last)
)
if before:
if sort_by == "last_run_completion":
result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before)).first()
else:
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before)).first()
result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before)).first()
if result:
before_sort_value, before_id = result
query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending))
query = query.where(
_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last)
)
# Apply ordering
order_fn = asc if ascending else desc
query = query.order_by(order_fn(sort_column), order_fn(AgentModel.id))
query = query.order_by(nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), order_fn(AgentModel.id))
return query
@@ -488,30 +512,30 @@ async def _apply_pagination_async(
# Determine the sort column
if sort_by == "last_run_completion":
sort_column = AgentModel.last_run_completion
sort_nulls_last = True # TODO: handle this as a query param eventually
else:
sort_column = AgentModel.created_at
sort_nulls_last = False
if after:
if sort_by == "last_run_completion":
result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after))).first()
else:
result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first()
result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after))).first()
if result:
after_sort_value, after_id = result
query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending))
query = query.where(
_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last)
)
if before:
if sort_by == "last_run_completion":
result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before))).first()
else:
result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first()
result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before))).first()
if result:
before_sort_value, before_id = result
query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending))
query = query.where(
_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last)
)
# Apply ordering
order_fn = asc if ascending else desc
query = query.order_by(order_fn(sort_column), order_fn(AgentModel.id))
query = query.order_by(nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), order_fn(AgentModel.id))
return query