fix: sort order (#1574)
Co-authored-by: robingotz <tug29225@temple.edu>
This commit is contained in:
@@ -18,7 +18,6 @@ from sqlalchemy import (
|
||||
asc,
|
||||
create_engine,
|
||||
desc,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
text,
|
||||
@@ -314,19 +313,13 @@ class SQLStorageConnector(StorageConnector):
|
||||
# cursor logic: filter records based on before/after ID
|
||||
if after:
|
||||
after_value = getattr(self.get(id=after), order_by)
|
||||
if reverse: # if reverse, then we want to get records that are less than the after_value
|
||||
sort_exp = getattr(self.db_model, order_by) < after_value
|
||||
else: # otherwise, we want to get records that are greater than the after_value
|
||||
sort_exp = getattr(self.db_model, order_by) > after_value
|
||||
sort_exp = getattr(self.db_model, order_by) > after_value
|
||||
query = query.filter(
|
||||
or_(sort_exp, and_(getattr(self.db_model, order_by) == after_value, self.db_model.id > after)) # tiebreaker case
|
||||
)
|
||||
if before:
|
||||
before_value = getattr(self.get(id=before), order_by)
|
||||
if reverse:
|
||||
sort_exp = getattr(self.db_model, order_by) > before_value
|
||||
else:
|
||||
sort_exp = getattr(self.db_model, order_by) < before_value
|
||||
sort_exp = getattr(self.db_model, order_by) < before_value
|
||||
query = query.filter(or_(sort_exp, and_(getattr(self.db_model, order_by) == before_value, self.db_model.id < before)))
|
||||
|
||||
# get records
|
||||
|
||||
@@ -154,14 +154,14 @@ def test_user_message(server, user_id, agent_id):
|
||||
@pytest.mark.order(5)
|
||||
def test_get_recall_memory(server, user_id, agent_id):
|
||||
# test recall memory cursor pagination
|
||||
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=2)
|
||||
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, after=cursor1, limit=1000)
|
||||
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=1000)
|
||||
[m["id"] for m in messages_3]
|
||||
[m["id"] for m in messages_2]
|
||||
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
|
||||
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
|
||||
# [m["id"] for m in messages_3]
|
||||
# [m["id"] for m in messages_2]
|
||||
timestamps = [m["created_at"] for m in messages_3]
|
||||
print("timestamps", timestamps)
|
||||
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
|
||||
assert messages_3[-1]["created_at"] >= messages_3[0]["created_at"]
|
||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
Reference in New Issue
Block a user