diff --git a/letta/client/client.py b/letta/client/client.py index 6456aa3f..a0b6729b 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -3128,7 +3128,7 @@ class LocalClient(AbstractClient): return self.server.get_agent_recall_cursor( user_id=self.user_id, agent_id=agent_id, - cursor=cursor, + before=cursor, limit=limit, reverse=True, ) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 4507f51c..6f8a7644 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum from typing import TYPE_CHECKING, List, Literal, Optional, Type -from sqlalchemy import String, func, select +from sqlalchemy import String, desc, func, or_, select from sqlalchemy.exc import DBAPIError from sqlalchemy.orm import Mapped, Session, mapped_column @@ -60,14 +60,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): end_date: Optional[datetime] = None, limit: Optional[int] = 50, query_text: Optional[str] = None, + ascending: bool = True, **kwargs, ) -> List[Type["SqlalchemyBase"]]: - """List records with advanced filtering and pagination options.""" + """ + List records with cursor-based pagination, ordering by created_at. + Cursor is an ID, but pagination is based on the cursor object's created_at value. 
+ """ if start_date and end_date and start_date > end_date: raise ValueError("start_date must be earlier than or equal to end_date") logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}") with db_session as session: + # If cursor provided, get the reference object + cursor_obj = None + if cursor: + cursor_obj = session.get(cls, cursor) + if not cursor_obj: + raise NoResultFound(f"No {cls.__name__} found with id {cursor}") + query = select(cls) # Apply filtering logic @@ -80,22 +91,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): # Date range filtering if start_date: - query = query.filter(cls.created_at >= start_date) + query = query.filter(cls.created_at > start_date) if end_date: - query = query.filter(cls.created_at <= end_date) + query = query.filter(cls.created_at < end_date) - # Cursor-based pagination - if cursor: - query = query.where(cls.id > cursor) + # Cursor-based pagination using created_at + # TODO: There is a really nasty race condition issue here with Sqlite + # TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason + if cursor_obj: + if ascending: + query = query.where(cls.created_at >= cursor_obj.created_at).where( + or_(cls.created_at > cursor_obj.created_at, cls.id > cursor_obj.id) + ) + else: + query = query.where(cls.created_at <= cursor_obj.created_at).where( + or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id) + ) # Apply text search if query_text: query = query.filter(func.lower(cls.text).contains(func.lower(query_text))) - # Handle ordering and soft deletes + # Handle soft deletes if hasattr(cls, "is_deleted"): query = query.where(cls.is_deleted == False) - query = query.order_by(cls.id).limit(limit) + + # Apply ordering by created_at + if ascending: + query = query.order_by(cls.created_at, cls.id) + else: + query = query.order_by(desc(cls.created_at), desc(cls.id)) + + query = query.limit(limit) return list(session.execute(query).scalars()) diff --git 
a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index e7f68dc9..2a9471d9 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -420,7 +420,7 @@ def get_agent_messages( return server.get_agent_recall_cursor( user_id=actor.id, agent_id=agent_id, - cursor=before, + before=before, limit=limit, reverse=True, return_message_object=msg_object, diff --git a/letta/server/server.py b/letta/server/server.py index 965c9b27..e20afd17 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -101,11 +101,6 @@ class Server(object): """List all available agents to a user""" raise NotImplementedError - @abstractmethod - def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list: - """Paginated query of in-context messages in agent message queue""" - raise NotImplementedError - @abstractmethod def get_agent_memory(self, user_id: str, agent_id: str) -> dict: """Return the memory of an agent (core memory + non-core statistics)""" @@ -1173,55 +1168,6 @@ class SyncServer(Server): message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user) return message - def get_agent_messages( - self, - agent_id: str, - start: int, - count: int, - ) -> Union[List[Message], List[LettaMessage]]: - """Paginated query of all messages in agent message queue""" - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id) - - if start < 0 or count < 0: - raise ValueError("Start and count values should be non-negative") - - if start + count < len(letta_agent._messages): # messages can be returned from whats in memory - # Reverse the list to make it in reverse chronological order - reversed_messages = letta_agent._messages[::-1] - # Check if start is within the range of the list - if start >= len(reversed_messages): - raise IndexError("Start index is out of range") - - # Calculate the end index, ensuring it does 
not exceed the list length - end_index = min(start + count, len(reversed_messages)) - - # Slice the list for pagination - messages = reversed_messages[start:end_index] - - else: - # need to access persistence manager for additional messages - - # get messages using message manager - page = letta_agent.message_manager.list_user_messages_for_agent( - agent_id=agent_id, - actor=self.default_user, - cursor=start, - limit=count, - ) - - messages = page - assert all(isinstance(m, Message) for m in messages) - - ## Convert to json - ## Add a tag indicating in-context or not - # json_messages = [record.to_json() for record in messages] - # in_context_message_ids = [str(m.id) for m in letta_agent._messages] - # for d in json_messages: - # d["in_context"] = True if str(d["id"]) in in_context_message_ids else False - - return messages - def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]: """Paginated query of all messages in agent archival memory""" if self.user_manager.get_user_by_id(user_id=user_id) is None: @@ -1303,7 +1249,8 @@ class SyncServer(Server): self, user_id: str, agent_id: str, - cursor: Optional[str] = None, + after: Optional[str] = None, + before: Optional[str] = None, limit: Optional[int] = 100, reverse: Optional[bool] = False, return_message_object: bool = True, @@ -1320,12 +1267,15 @@ class SyncServer(Server): letta_agent = self.load_agent(agent_id=agent_id) # iterate over records - # TODO: Check "order_by", "order" + start_date = self.message_manager.get_message_by_id(after, actor=actor).created_at if after else None + end_date = self.message_manager.get_message_by_id(before, actor=actor).created_at if before else None records = letta_agent.message_manager.list_messages_for_agent( agent_id=agent_id, actor=actor, - cursor=cursor, + start_date=start_date, + end_date=end_date, limit=limit, + ascending=not reverse, ) assert all(isinstance(m, Message) for m in records) diff --git 
a/letta/services/message_manager.py b/letta/services/message_manager.py index b9932b39..b4151944 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -119,6 +119,7 @@ class MessageManager: limit: Optional[int] = 50, filters: Optional[Dict] = None, query_text: Optional[str] = None, + ascending: bool = True, ) -> List[PydanticMessage]: """List user messages with flexible filtering and pagination options. @@ -159,6 +160,7 @@ class MessageManager: limit: Optional[int] = 50, filters: Optional[Dict] = None, query_text: Optional[str] = None, + ascending: bool = True, ) -> List[PydanticMessage]: """List messages with flexible filtering and pagination options. @@ -188,6 +190,7 @@ class MessageManager: end_date=end_date, limit=limit, query_text=query_text, + ascending=ascending, **message_filters, ) diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py index 4c95184c..2a76a147 100644 --- a/letta/services/tool_execution_sandbox.py +++ b/letta/services/tool_execution_sandbox.py @@ -459,7 +459,7 @@ class ToolExecutionSandbox: Generate the code string to call the function. 
Args: - inject_agent_state (bool): Whether to inject the agent's state as an input into the tool + inject_agent_state (bool): Whether to inject the agent's state as an input into the tool Returns: str: Generated code string for calling the tool diff --git a/tests/test_managers.py b/tests/test_managers.py index dc476938..74675b4c 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1,4 +1,5 @@ import os +import time from datetime import datetime, timedelta import pytest @@ -73,8 +74,8 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( azure_version=None, azure_deployment=None, ) - -using_sqlite = not bool(os.getenv("LETTA_PG_URI")) +CREATE_DELAY_SQLITE = 1 +USING_SQLITE = not bool(os.getenv("LETTA_PG_URI")) @pytest.fixture(autouse=True) @@ -911,6 +912,8 @@ def test_list_sources(server: SyncServer, default_user): """Test listing sources with pagination.""" # Create multiple sources server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) + if USING_SQLITE: + time.sleep(CREATE_DELAY_SQLITE) server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) # List sources without pagination @@ -1004,6 +1007,8 @@ def test_list_files(server: SyncServer, default_user, default_source): PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id), actor=default_user, ) + if USING_SQLITE: + time.sleep(CREATE_DELAY_SQLITE) server.source_manager.create_file( PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id), actor=default_user, @@ -1184,6 +1189,8 @@ def test_list_sandbox_configs(server: SyncServer, default_user): config=LocalSandboxConfig(sandbox_dir=""), ) server.sandbox_config_manager.create_or_update_sandbox_config(config_a, actor=default_user) + if USING_SQLITE: + 
time.sleep(CREATE_DELAY_SQLITE) server.sandbox_config_manager.create_or_update_sandbox_config(config_b, actor=default_user) # List configs without pagination @@ -1239,6 +1246,8 @@ def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, defau env_var_create_a = SandboxEnvironmentVariableCreate(key="VAR1", value="value1") env_var_create_b = SandboxEnvironmentVariableCreate(key="VAR2", value="value2") server.sandbox_config_manager.create_sandbox_env_var(env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + if USING_SQLITE: + time.sleep(CREATE_DELAY_SQLITE) server.sandbox_config_manager.create_sandbox_env_var(env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user) # List env vars without pagination @@ -1299,7 +1308,7 @@ def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agen assert default_block.label not in labels -@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite") +@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite") def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user): with pytest.raises(ForeignKeyConstraintViolationError): server.blocks_agents_manager.add_block_to_agent( @@ -1361,7 +1370,7 @@ def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_u assert len(agent_ids) == 2 -@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite") +@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite") def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block): block_manager = BlockManager() block_manager.delete_block(block_id=default_block.id, actor=default_user) @@ -1401,7 +1410,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent, assert print_tool.name not in names -@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite") +@pytest.mark.skipif(USING_SQLITE, 
reason="Skipped because using SQLite") def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user): with pytest.raises(ForeignKeyConstraintViolationError): server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name") @@ -1447,7 +1456,7 @@ def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_us assert len(agent_ids) == 2 -@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite") +@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite") def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool): tool_manager = ToolManager() tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user) diff --git a/tests/test_server.py b/tests/test_server.py index c85f12ca..8d85cc1c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -161,37 +161,25 @@ def test_user_message(server, user_id, agent_id): # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") -# TODO: Add this back, this is broken on main -# @pytest.mark.order(5) -# def test_get_recall_memory(server, org_id, user_id, agent_id): -# # test recall memory cursor pagination -# messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2) -# cursor1 = messages_1[-1].id -# messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000) -# messages_2[-1].id -# messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000) -# messages_3[-1].id -# assert messages_3[-1].created_at >= messages_3[0].created_at -# assert len(messages_3) == len(messages_1) + len(messages_2) -# messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1) -# assert len(messages_4) == 1 -# -# # test in-context message ids -# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id) 
-# message_ids = [m.id for m in messages_3] -# for message_id in in_context_ids: -# assert message_id in message_ids, f"{message_id} not in {message_ids}" -# -# # test recall memory -# messages_1 = server.get_agent_messages(agent_id=agent_id, start=0, count=1) -# assert len(messages_1) == 1 -# messages_2 = server.get_agent_messages(agent_id=agent_id, start=1, count=1000) -# messages_3 = server.get_agent_messages(agent_id=agent_id, start=1, count=2) -# # not sure exactly how many messages there should be -# assert len(messages_2) > len(messages_3) -# # test safe empty return -# messages_none = server.get_agent_messages(agent_id=agent_id, start=1000, count=1000) -# assert len(messages_none) == 0 +@pytest.mark.order(5) +def test_get_recall_memory(server, org_id, user_id, agent_id): + # test recall memory cursor pagination + messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2) + cursor1 = messages_1[-1].id + messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000) + messages_2[-1].id + messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000) + messages_3[-1].id + assert messages_3[-1].created_at >= messages_3[0].created_at + assert len(messages_3) == len(messages_1) + len(messages_2) + messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1) + assert len(messages_4) == 1 + + # test in-context message ids + in_context_ids = server.get_in_context_message_ids(agent_id=agent_id) + message_ids = [m.id for m in messages_3] + for message_id in in_context_ids: + assert message_id in message_ids, f"{message_id} not in {message_ids}" @pytest.mark.order(6)