diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index c92de777..d602a049 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -18,7 +18,6 @@ from sqlalchemy import ( asc, create_engine, desc, - func, or_, select, text, @@ -314,19 +313,13 @@ class SQLStorageConnector(StorageConnector): # cursor logic: filter records based on before/after ID if after: after_value = getattr(self.get(id=after), order_by) - if reverse: # if reverse, then we want to get records that are less than the after_value - sort_exp = getattr(self.db_model, order_by) < after_value - else: # otherwise, we want to get records that are greater than the after_value - sort_exp = getattr(self.db_model, order_by) > after_value + sort_exp = getattr(self.db_model, order_by) > after_value query = query.filter( or_(sort_exp, and_(getattr(self.db_model, order_by) == after_value, self.db_model.id > after)) # tiebreaker case ) if before: before_value = getattr(self.get(id=before), order_by) - if reverse: - sort_exp = getattr(self.db_model, order_by) > before_value - else: - sort_exp = getattr(self.db_model, order_by) < before_value + sort_exp = getattr(self.db_model, order_by) < before_value query = query.filter(or_(sort_exp, and_(getattr(self.db_model, order_by) == before_value, self.db_model.id < before))) # get records diff --git a/tests/test_server.py b/tests/test_server.py index 9b79f451..bcae5b90 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -154,14 +154,14 @@ def test_user_message(server, user_id, agent_id): @pytest.mark.order(5) def test_get_recall_memory(server, user_id, agent_id): # test recall memory cursor pagination - cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=2) - cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, after=cursor1, limit=1000) - cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=1000) - [m["id"] for m in messages_3] - [m["id"] for m in messages_2] + cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2) + cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000) + cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000) + # [m["id"] for m in messages_3] + # [m["id"] for m in messages_2] timestamps = [m["created_at"] for m in messages_3] print("timestamps", timestamps) - assert messages_3[-1]["created_at"] < messages_3[0]["created_at"] + assert messages_3[-1]["created_at"] >= messages_3[0]["created_at"] assert len(messages_3) == len(messages_1) + len(messages_2) cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1) assert len(messages_4) == 1