fix: fix null precedence and pagination for list agents (#2927)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -2,7 +2,7 @@ import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import Select, and_, asc, desc, func, literal, or_, select, union_all
|
||||
from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all
|
||||
from sqlalchemy.sql.expression import exists
|
||||
|
||||
from letta import system
|
||||
@@ -430,23 +430,47 @@ def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) ->
|
||||
return True
|
||||
|
||||
|
||||
def _cursor_filter(created_at_col, id_col, ref_created_at, ref_id, forward: bool):
|
||||
def _cursor_filter(sort_col, id_col, ref_sort_col, ref_id, forward: bool, nulls_last: bool = False):
|
||||
"""
|
||||
Returns a SQLAlchemy filter expression for cursor-based pagination.
|
||||
|
||||
If `forward` is True, returns records after the reference.
|
||||
If `forward` is False, returns records before the reference.
|
||||
|
||||
Handles NULL values in the sort column properly when nulls_last is True.
|
||||
"""
|
||||
if forward:
|
||||
return or_(
|
||||
created_at_col > ref_created_at,
|
||||
and_(created_at_col == ref_created_at, id_col > ref_id),
|
||||
)
|
||||
if not nulls_last:
|
||||
# Simple case: no special NULL handling needed
|
||||
if forward:
|
||||
return or_(
|
||||
sort_col > ref_sort_col,
|
||||
and_(sort_col == ref_sort_col, id_col > ref_id),
|
||||
)
|
||||
else:
|
||||
return or_(
|
||||
sort_col < ref_sort_col,
|
||||
and_(sort_col == ref_sort_col, id_col < ref_id),
|
||||
)
|
||||
|
||||
# Handle nulls_last case
|
||||
# TODO: add tests to check if this works for ascending order but nulls are stil last?
|
||||
if ref_sort_col is None:
|
||||
# Reference cursor is at a NULL value
|
||||
if forward:
|
||||
# Moving forward (e.g. previous) from NULL: either other NULLs with greater IDs or non-NULLs
|
||||
return or_(and_(sort_col.is_(None), id_col > ref_id), sort_col.isnot(None))
|
||||
else:
|
||||
# Moving backward (e.g. next) from NULL: NULLs with smaller IDs
|
||||
return and_(sort_col.is_(None), id_col < ref_id)
|
||||
else:
|
||||
return or_(
|
||||
created_at_col < ref_created_at,
|
||||
and_(created_at_col == ref_created_at, id_col < ref_id),
|
||||
)
|
||||
# Reference cursor is at a non-NULL value
|
||||
if forward:
|
||||
# Moving forward (e.g. previous) from non-NULL: only greater non-NULL values
|
||||
# (NULLs are at the end, so we don't include them when moving forward from non-NULL)
|
||||
return and_(sort_col.isnot(None), or_(sort_col > ref_sort_col, and_(sort_col == ref_sort_col, id_col > ref_id)))
|
||||
else:
|
||||
# Moving backward (e.g. next) from non-NULL: smaller non-NULL values or NULLs
|
||||
return or_(sort_col.is_(None), or_(sort_col < ref_sort_col, and_(sort_col == ref_sort_col, id_col < ref_id)))
|
||||
|
||||
|
||||
def _apply_pagination(
|
||||
@@ -455,30 +479,30 @@ def _apply_pagination(
|
||||
# Determine the sort column
|
||||
if sort_by == "last_run_completion":
|
||||
sort_column = AgentModel.last_run_completion
|
||||
sort_nulls_last = True # TODO: handle this as a query param eventually
|
||||
else:
|
||||
sort_column = AgentModel.created_at
|
||||
sort_nulls_last = False
|
||||
|
||||
if after:
|
||||
if sort_by == "last_run_completion":
|
||||
result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after)).first()
|
||||
else:
|
||||
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after)).first()
|
||||
result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after)).first()
|
||||
if result:
|
||||
after_sort_value, after_id = result
|
||||
query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending))
|
||||
query = query.where(
|
||||
_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last)
|
||||
)
|
||||
|
||||
if before:
|
||||
if sort_by == "last_run_completion":
|
||||
result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before)).first()
|
||||
else:
|
||||
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before)).first()
|
||||
result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before)).first()
|
||||
if result:
|
||||
before_sort_value, before_id = result
|
||||
query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending))
|
||||
query = query.where(
|
||||
_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last)
|
||||
)
|
||||
|
||||
# Apply ordering
|
||||
order_fn = asc if ascending else desc
|
||||
query = query.order_by(order_fn(sort_column), order_fn(AgentModel.id))
|
||||
query = query.order_by(nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), order_fn(AgentModel.id))
|
||||
return query
|
||||
|
||||
|
||||
@@ -488,30 +512,30 @@ async def _apply_pagination_async(
|
||||
# Determine the sort column
|
||||
if sort_by == "last_run_completion":
|
||||
sort_column = AgentModel.last_run_completion
|
||||
sort_nulls_last = True # TODO: handle this as a query param eventually
|
||||
else:
|
||||
sort_column = AgentModel.created_at
|
||||
sort_nulls_last = False
|
||||
|
||||
if after:
|
||||
if sort_by == "last_run_completion":
|
||||
result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after))).first()
|
||||
else:
|
||||
result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first()
|
||||
result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after))).first()
|
||||
if result:
|
||||
after_sort_value, after_id = result
|
||||
query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending))
|
||||
query = query.where(
|
||||
_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last)
|
||||
)
|
||||
|
||||
if before:
|
||||
if sort_by == "last_run_completion":
|
||||
result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before))).first()
|
||||
else:
|
||||
result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first()
|
||||
result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before))).first()
|
||||
if result:
|
||||
before_sort_value, before_id = result
|
||||
query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending))
|
||||
query = query.where(
|
||||
_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last)
|
||||
)
|
||||
|
||||
# Apply ordering
|
||||
order_fn = asc if ascending else desc
|
||||
query = query.order_by(order_fn(sort_column), order_fn(AgentModel.id))
|
||||
query = query.order_by(nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), order_fn(AgentModel.id))
|
||||
return query
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user