From 44afd54c5cfe303074cbc950ff413f3f6bf3de99 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 20 May 2025 16:52:11 -0700 Subject: [PATCH] feat(asyncify): migrate list messages (#2272) --- letta/agents/voice_agent.py | 2 +- letta/groups/sleeptime_multi_agent_v2.py | 2 +- letta/server/rest_api/routers/v1/agents.py | 8 +- letta/server/server.py | 38 ++++++++ letta/services/message_manager.py | 101 +++++++++++++++++++++ tests/test_server.py | 2 +- 6 files changed, 146 insertions(+), 7 deletions(-) diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 3926ce5a..959a25a9 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -457,7 +457,7 @@ class VoiceAgent(BaseAgent): keyword_results = {} if convo_keyword_queries: for keyword in convo_keyword_queries: - messages = self.message_manager.list_messages_for_agent( + messages = await self.message_manager.list_messages_for_agent_async( agent_id=self.agent_id, actor=self.actor, query_text=keyword, diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index e2910e5b..9cd2cede 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -190,7 +190,7 @@ class SleeptimeMultiAgentV2(BaseAgent): prior_messages = [] if self.group.sleeptime_agent_frequency: try: - prior_messages = self.message_manager.list_messages_for_agent( + prior_messages = await self.message_manager.list_messages_for_agent_async( agent_id=foreground_agent_id, actor=self.actor, after=last_processed_message_id, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 7e0ec54a..5672fc33 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -567,7 +567,7 @@ AgentMessagesResponse = Annotated[ @router.get("/{agent_id}/messages", response_model=AgentMessagesResponse, operation_id="list_messages") -def list_messages( +async def list_messages( agent_id: str, server: "SyncServer" = Depends(get_letta_server), after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."), @@ -582,10 +582,9 @@ def list_messages( """ Retrieve message history for an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return server.get_agent_recall( - user_id=actor.id, + return await server.get_agent_recall_async( agent_id=agent_id, after=after, before=before, @@ -596,6 +595,7 @@ def list_messages( use_assistant_message=use_assistant_message, assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_kwarg=assistant_message_tool_kwarg, + actor=actor, ) diff --git a/letta/server/server.py b/letta/server/server.py index 512dd4d9..bf579a22 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1096,6 +1096,44 @@ class SyncServer(Server): return records + async def get_agent_recall_async( + self, + agent_id: str, + actor: User, + after: Optional[str] = None, + before: Optional[str] = None, + limit: Optional[int] = 100, + group_id: Optional[str] = None, + reverse: Optional[bool] = False, + return_message_object: bool = True, + use_assistant_message: bool = True, + assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, + assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, + ) -> Union[List[Message], List[LettaMessage]]: + records = await self.message_manager.list_messages_for_agent_async( + agent_id=agent_id, + actor=actor, + after=after, + before=before, + limit=limit, + ascending=not reverse, + group_id=group_id, + ) + + if not return_message_object: + records = Message.to_letta_messages_from_list( + messages=records, + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + reverse=reverse, + ) + + if reverse: + records = records[::-1] + + return records + def get_server_config(self, include_defaults: bool = False) -> dict: """Return the base config""" diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index a364e108..55510f08 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -429,6 +429,107 @@ class MessageManager: results = query.all() return [msg.to_pydantic() for msg in results] + @enforce_types + async def list_messages_for_agent_async( + self, + agent_id: str, + actor: PydanticUser, + after: Optional[str] = None, + before: Optional[str] = None, + query_text: Optional[str] = None, + roles: Optional[Sequence[MessageRole]] = None, + limit: Optional[int] = 50, + ascending: bool = True, + group_id: Optional[str] = None, + ) -> List[PydanticMessage]: + """ + Most performant query to list messages for an agent by directly querying the Message table. + + This function filters by the agent_id (leveraging the index on messages.agent_id) + and applies pagination using sequence_id as the cursor. + If query_text is provided, it will filter messages whose text content partially matches the query. + If role is provided, it will filter messages by the specified role. + + Args: + agent_id: The ID of the agent whose messages are queried. + actor: The user performing the action (used for permission checks). + after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned. + before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned. + query_text: Optional string to partially match the message text content. + roles: Optional MessageRole to filter messages by role. + limit: Maximum number of messages to return. + ascending: If True, sort by sequence_id ascending; if False, sort descending. + group_id: Optional group ID to filter messages by group_id. + + Returns: + List[PydanticMessage]: A list of messages (converted via .to_pydantic()). + + Raises: + NoResultFound: If the provided after/before message IDs do not exist. + """ + + async with db_registry.async_session() as session: + # Permission check: raise if the agent doesn't exist or actor is not allowed. + await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # Build a query that directly filters the Message table by agent_id. + query = select(MessageModel).where(MessageModel.agent_id == agent_id) + + # If group_id is provided, filter messages by group_id. + if group_id: + query = query.where(MessageModel.group_id == group_id) + + # If query_text is provided, filter messages using subquery + json_array_elements. + if query_text: + content_element = func.json_array_elements(MessageModel.content).alias("content_element") + query = query.where( + exists( + select(1) + .select_from(content_element) + .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text")) + .params(query_text=f"%{query_text}%") + ) + ) + + # If role(s) are provided, filter messages by those roles. + if roles: + role_values = [r.value for r in roles] + query = query.where(MessageModel.role.in_(role_values)) + + # Apply 'after' pagination if specified. + if after: + after_query = select(MessageModel.sequence_id).where(MessageModel.id == after) + after_result = await session.execute(after_query) + after_ref = after_result.one_or_none() + if not after_ref: + raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.") + # Filter out any messages with a sequence_id <= after_ref.sequence_id + query = query.where(MessageModel.sequence_id > after_ref.sequence_id) + + # Apply 'before' pagination if specified. + if before: + before_query = select(MessageModel.sequence_id).where(MessageModel.id == before) + before_result = await session.execute(before_query) + before_ref = before_result.one_or_none() + if not before_ref: + raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.") + # Filter out any messages with a sequence_id >= before_ref.sequence_id + query = query.where(MessageModel.sequence_id < before_ref.sequence_id) + + # Apply ordering based on the ascending flag. + if ascending: + query = query.order_by(MessageModel.sequence_id.asc()) + else: + query = query.order_by(MessageModel.sequence_id.desc()) + + # Limit the number of results. + query = query.limit(limit) + + # Execute and convert each Message to its Pydantic representation. + result = await session.execute(query) + results = result.scalars().all() + return [msg.to_pydantic() for msg in results] + @enforce_types def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int: """ diff --git a/tests/test_server.py b/tests/test_server.py index 0b43a12f..a832bc92 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -936,7 +936,7 @@ def test_composio_client_simple(server): assert len(actions) > 0 -def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools): +async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools): """Test that the memory rebuild is generating the correct number of role=system messages""" actor = user # create agent