fix: Fix multi agent test (#4947)

Fix multi agent
This commit is contained in:
Matthew Zhou
2025-09-25 17:25:05 -07:00
committed by Caren Thomas
parent 81bf132712
commit b5053d02d5
4 changed files with 20 additions and 15 deletions

View File

@@ -626,7 +626,7 @@ class SyncServer(object):
# delete the passage
await self.passage_manager.delete_passage_by_id_async(passage_id=memory_id, actor=actor)
def get_agent_recall(
async def get_agent_recall(
self,
user_id: str,
agent_id: str,
@@ -642,9 +642,9 @@ class SyncServer(object):
) -> Union[List[Message], List[LettaMessage]]:
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
actor = await self.user_manager.get_actor_or_default_async(actor_id=user_id)
records = self.message_manager.list_messages_for_agent(
records = await self.message_manager.list_messages_for_agent_async(
agent_id=agent_id,
actor=actor,
after=after,

View File

@@ -1042,9 +1042,9 @@ class AgentManager:
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
@enforce_types
@trace_method
def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor)
async def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
return await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
@enforce_types
@trace_method

View File

@@ -1,3 +1,4 @@
import asyncio
import json
import os
import threading
@@ -147,7 +148,7 @@ def roll_dice_tool(client):
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agent(client, server, agent_obj, other_agent_obj):
secret_word = "banana"
actor = server.user_manager.get_user_or_default()
actor = asyncio.run(server.user_manager.get_actor_or_default_async())
# Encourage the agent to send a message to the other agent_obj with the secret string
response = client.agents.messages.create(
@@ -161,11 +162,13 @@ def test_send_message_to_agent(client, server, agent_obj, other_agent_obj):
)
# Conversation search the other agent
messages = server.get_agent_recall(
user_id=actor.id,
agent_id=other_agent_obj.id,
reverse=True,
return_message_object=False,
messages = asyncio.run(
server.get_agent_recall(
user_id=actor.id,
agent_id=other_agent_obj.id,
reverse=True,
return_message_object=False,
)
)
# Check for the presence of system message
@@ -176,7 +179,7 @@ def test_send_message_to_agent(client, server, agent_obj, other_agent_obj):
break
# Search the sender agent for the response from another agent
in_context_messages = AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor)
in_context_messages = asyncio.run(AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor))
found = False
target_snippet = f"'agent_id': '{other_agent_obj.id}', 'response': ["

View File

@@ -1,3 +1,5 @@
import asyncio
import pytest
from letta.config import LettaConfig
@@ -236,7 +238,7 @@ async def test_round_robin(server, default_user, four_participant_agents):
assert message.name == four_participant_agents[i // 2].name
for agent_id in group.agent_ids:
agent_messages = server.get_agent_recall(
agent_messages = await server.get_agent_recall(
user_id=default_user.id,
agent_id=agent_id,
group_id=group.id,
@@ -292,7 +294,7 @@ async def test_round_robin(server, default_user, four_participant_agents):
assert message.name == four_participant_agents[::-1][i // 2].name
for i in range(len(group.agent_ids)):
agent_messages = server.get_agent_recall(
agent_messages = await server.get_agent_recall(
user_id=default_user.id,
agent_id=group.agent_ids[i],
group_id=group.id,