From b5053d02d525210f1e45e8a9c5b7430811c0200f Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 25 Sep 2025 17:25:05 -0700 Subject: [PATCH] fix: Fix multi agent test (#4947) Fix multi agent --- letta/server/server.py | 6 +++--- letta/services/agent_manager.py | 6 +++--- tests/integration_test_multi_agent.py | 17 ++++++++++------- tests/test_multi_agent.py | 6 ++++-- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/letta/server/server.py b/letta/server/server.py index dbb2dd59..f213241c 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -626,7 +626,7 @@ class SyncServer(object): # delete the passage await self.passage_manager.delete_passage_by_id_async(passage_id=memory_id, actor=actor) - def get_agent_recall( + async def get_agent_recall( self, user_id: str, agent_id: str, @@ -642,9 +642,9 @@ class SyncServer(object): ) -> Union[List[Message], List[LettaMessage]]: # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) + actor = await self.user_manager.get_actor_or_default_async(actor_id=user_id) - records = self.message_manager.list_messages_for_agent( + records = await self.message_manager.list_messages_for_agent_async( agent_id=agent_id, actor=actor, after=after, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index f852e5f6..f2ebf0a8 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1042,9 +1042,9 @@ class AgentManager: # TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query. @enforce_types @trace_method - def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: - message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids - return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor) + async def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: + agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + return await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor) @enforce_types @trace_method diff --git a/tests/integration_test_multi_agent.py b/tests/integration_test_multi_agent.py index 36f25508..f439cbe2 100644 --- a/tests/integration_test_multi_agent.py +++ b/tests/integration_test_multi_agent.py @@ -1,3 +1,4 @@ +import asyncio import json import os import threading @@ -147,7 +148,7 @@ def roll_dice_tool(client): @retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_send_message_to_agent(client, server, agent_obj, other_agent_obj): secret_word = "banana" - actor = server.user_manager.get_user_or_default() + actor = asyncio.run(server.user_manager.get_actor_or_default_async()) # Encourage the agent to send a message to the other agent_obj with the secret string response = client.agents.messages.create( @@ -161,11 +162,13 @@ def test_send_message_to_agent(client, server, agent_obj, other_agent_obj): ) # Conversation search the other agent - messages = server.get_agent_recall( - user_id=actor.id, - agent_id=other_agent_obj.id, - reverse=True, - return_message_object=False, + messages = asyncio.run( + server.get_agent_recall( + user_id=actor.id, + agent_id=other_agent_obj.id, + reverse=True, + return_message_object=False, + ) ) # Check for the presence of system message @@ -176,7 +179,7 @@ def test_send_message_to_agent(client, server, agent_obj, other_agent_obj): break # Search the sender agent for the response from another agent - in_context_messages = AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor) + in_context_messages = asyncio.run(AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor)) found = False target_snippet = f"'agent_id': '{other_agent_obj.id}', 'response': [" diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index 945c881f..28a4ff3a 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from letta.config import LettaConfig @@ -236,7 +238,7 @@ async def test_round_robin(server, default_user, four_participant_agents): assert message.name == four_participant_agents[i // 2].name for agent_id in group.agent_ids: - agent_messages = server.get_agent_recall( + agent_messages = await server.get_agent_recall( user_id=default_user.id, agent_id=agent_id, group_id=group.id, @@ -292,7 +294,7 @@ async def test_round_robin(server, default_user, four_participant_agents): assert message.name == four_participant_agents[::-1][i // 2].name for i in range(len(group.agent_ids)): - agent_messages = server.get_agent_recall( + agent_messages = await server.get_agent_recall( user_id=default_user.id, agent_id=group.agent_ids[i], group_id=group.id,