fix: sort order (#1574)

Co-authored-by: robingotz <tug29225@temple.edu>
This commit is contained in:
Sarah Wooders
2024-07-26 14:22:38 -07:00
committed by GitHub
parent 4d758a3073
commit 7c5d68be86
2 changed files with 8 additions and 15 deletions

View File

@@ -18,7 +18,6 @@ from sqlalchemy import (
asc,
create_engine,
desc,
func,
or_,
select,
text,
@@ -314,19 +313,13 @@ class SQLStorageConnector(StorageConnector):
# cursor logic: filter records based on before/after ID
if after:
after_value = getattr(self.get(id=after), order_by)
if reverse: # if reverse, then we want to get records that are less than the after_value
sort_exp = getattr(self.db_model, order_by) < after_value
else: # otherwise, we want to get records that are greater than the after_value
sort_exp = getattr(self.db_model, order_by) > after_value
sort_exp = getattr(self.db_model, order_by) > after_value
query = query.filter(
or_(sort_exp, and_(getattr(self.db_model, order_by) == after_value, self.db_model.id > after)) # tiebreaker case
)
if before:
before_value = getattr(self.get(id=before), order_by)
if reverse:
sort_exp = getattr(self.db_model, order_by) > before_value
else:
sort_exp = getattr(self.db_model, order_by) < before_value
sort_exp = getattr(self.db_model, order_by) < before_value
query = query.filter(or_(sort_exp, and_(getattr(self.db_model, order_by) == before_value, self.db_model.id < before)))
# get records

View File

@@ -154,14 +154,14 @@ def test_user_message(server, user_id, agent_id):
@pytest.mark.order(5)
def test_get_recall_memory(server, user_id, agent_id):
# test recall memory cursor pagination
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=2)
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, after=cursor1, limit=1000)
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=1000)
[m["id"] for m in messages_3]
[m["id"] for m in messages_2]
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
# [m["id"] for m in messages_3]
# [m["id"] for m in messages_2]
timestamps = [m["created_at"] for m in messages_3]
print("timestamps", timestamps)
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
assert messages_3[-1]["created_at"] >= messages_3[0]["created_at"]
assert len(messages_3) == len(messages_1) + len(messages_2)
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
assert len(messages_4) == 1