From 6889835bd721befeba64182b9396718d35c28e83 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Fri, 20 Jun 2025 14:43:34 -0700 Subject: [PATCH] fix: fix null precedence and pagination for list agents (#2927) Co-authored-by: Jin Peng --- .../services/helpers/agent_manager_helper.py | 90 ++++++++++++------- 1 file changed, 57 insertions(+), 33 deletions(-) diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index ebe6dc7b..249e807e 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -2,7 +2,7 @@ import datetime from typing import List, Literal, Optional import numpy as np -from sqlalchemy import Select, and_, asc, desc, func, literal, or_, select, union_all +from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all from sqlalchemy.sql.expression import exists from letta import system @@ -430,23 +430,47 @@ def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) -> return True -def _cursor_filter(created_at_col, id_col, ref_created_at, ref_id, forward: bool): +def _cursor_filter(sort_col, id_col, ref_sort_col, ref_id, forward: bool, nulls_last: bool = False): """ Returns a SQLAlchemy filter expression for cursor-based pagination. If `forward` is True, returns records after the reference. If `forward` is False, returns records before the reference. + + Handles NULL values in the sort column properly when nulls_last is True. 
""" - if forward: - return or_( - created_at_col > ref_created_at, - and_(created_at_col == ref_created_at, id_col > ref_id), - ) + if not nulls_last: + # Simple case: no special NULL handling needed + if forward: + return or_( + sort_col > ref_sort_col, + and_(sort_col == ref_sort_col, id_col > ref_id), + ) + else: + return or_( + sort_col < ref_sort_col, + and_(sort_col == ref_sort_col, id_col < ref_id), + ) + + # Handle nulls_last case + # TODO: add tests to check if this works for ascending order but nulls are still last? + if ref_sort_col is None: + # Reference cursor is at a NULL value + if forward: + # Moving forward (e.g. previous) from NULL: either other NULLs with greater IDs or non-NULLs + return or_(and_(sort_col.is_(None), id_col > ref_id), sort_col.isnot(None)) + else: + # Moving backward (e.g. next) from NULL: NULLs with smaller IDs + return and_(sort_col.is_(None), id_col < ref_id) else: - return or_( - created_at_col < ref_created_at, - and_(created_at_col == ref_created_at, id_col < ref_id), - ) + # Reference cursor is at a non-NULL value + if forward: + # Moving forward (e.g. previous) from non-NULL: only greater non-NULL values + # (NULLs are at the end, so we don't include them when moving forward from non-NULL) + return and_(sort_col.isnot(None), or_(sort_col > ref_sort_col, and_(sort_col == ref_sort_col, id_col > ref_id))) + else: + # Moving backward (e.g. 
next) from non-NULL: smaller non-NULL values or NULLs + return or_(sort_col.is_(None), or_(sort_col < ref_sort_col, and_(sort_col == ref_sort_col, id_col < ref_id))) def _apply_pagination( @@ -455,30 +479,30 @@ def _apply_pagination( # Determine the sort column if sort_by == "last_run_completion": sort_column = AgentModel.last_run_completion + sort_nulls_last = True # TODO: handle this as a query param eventually else: sort_column = AgentModel.created_at + sort_nulls_last = False if after: - if sort_by == "last_run_completion": - result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after)).first() - else: - result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after)).first() + result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after)).first() if result: after_sort_value, after_id = result - query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending)) + query = query.where( + _cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last) + ) if before: - if sort_by == "last_run_completion": - result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before)).first() - else: - result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before)).first() + result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before)).first() if result: before_sort_value, before_id = result - query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending)) + query = query.where( + _cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last) + ) # Apply ordering order_fn = asc if ascending else desc - query = query.order_by(order_fn(sort_column), 
order_fn(AgentModel.id)) + query = query.order_by(nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), order_fn(AgentModel.id)) return query @@ -488,30 +512,30 @@ async def _apply_pagination_async( # Determine the sort column if sort_by == "last_run_completion": sort_column = AgentModel.last_run_completion + sort_nulls_last = True # TODO: handle this as a query param eventually else: sort_column = AgentModel.created_at + sort_nulls_last = False if after: - if sort_by == "last_run_completion": - result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after))).first() - else: - result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first() + result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after))).first() if result: after_sort_value, after_id = result - query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending)) + query = query.where( + _cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending, nulls_last=sort_nulls_last) + ) if before: - if sort_by == "last_run_completion": - result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before))).first() - else: - result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first() + result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before))).first() if result: before_sort_value, before_id = result - query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending)) + query = query.where( + _cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending, nulls_last=sort_nulls_last) + ) # Apply ordering order_fn = asc if ascending else desc - query = 
query.order_by(order_fn(sort_column), order_fn(AgentModel.id)) + query = query.order_by(nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column), order_fn(AgentModel.id)) return query