diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index a6c30936..e44864dd 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -146,64 +146,87 @@ async def test_round_robin(server, actor, participant_agent_ids): ), actor=actor, ) - response = await server.send_group_message_to_agent( - group_id=group.id, - actor=actor, - messages=[ - MessageCreate( - role="user", - content="what is everyone up to for the holidays?", - ), - ], - stream_steps=False, - stream_tokens=False, - ) - assert response.usage.step_count == len(participant_agent_ids) - assert len(response.messages) == response.usage.step_count * 2 + try: + response = await server.send_group_message_to_agent( + group_id=group.id, + actor=actor, + messages=[ + MessageCreate( + role="user", + content="what is everyone up to for the holidays?", + ), + ], + stream_steps=False, + stream_tokens=False, + ) + assert response.usage.step_count == len(participant_agent_ids) + assert len(response.messages) == response.usage.step_count * 2 - server.group_manager.delete_group(group_id=group.id, actor=actor) + finally: + server.group_manager.delete_group(group_id=group.id, actor=actor) @pytest.mark.asyncio -async def test_supervisor(server, actor, manager_agent_id, participant_agent_ids): +async def test_supervisor(server, actor, participant_agent_ids): + agent_scrappy = server.create_agent( + request=CreateAgent( + name="shaggy", + memory_blocks=[ + CreateBlock( + label="persona", + value="You are a puppy operations agent for Letta and you help run multi-agent group chats. Your role is to supervise the group, sending messages and aggregating the responses.", + ), + CreateBlock( + label="human", + value="", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-ada-002", + ), + actor=actor, + ) group = server.group_manager.create_group( group=GroupCreate( description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.", agent_ids=participant_agent_ids, manager_config=SupervisorManager( - manager_agent_id=manager_agent_id, + manager_agent_id=agent_scrappy.id, ), ), actor=actor, ) - response = await server.send_group_message_to_agent( - group_id=group.id, - actor=actor, - messages=[ - MessageCreate( - role="user", - content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.", - ), - ], - stream_steps=False, - stream_tokens=False, - ) - assert response.usage.step_count == 2 - assert len(response.messages) == 5 + try: + response = await server.send_group_message_to_agent( + group_id=group.id, + actor=actor, + messages=[ + MessageCreate( + role="user", + content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.", + ), + ], + stream_steps=False, + stream_tokens=False, + ) + assert response.usage.step_count == 2 + assert len(response.messages) == 5 - # verify tool call - assert response.messages[0].message_type == "reasoning_message" - assert ( - response.messages[1].message_type == "tool_call_message" - and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group" - ) - assert response.messages[2].message_type == "tool_return_message" and len(eval(response.messages[2].tool_return)) == len( - participant_agent_ids - ) - assert response.messages[3].message_type == "reasoning_message" - assert response.messages[4].message_type == "assistant_message" + # verify tool call + assert response.messages[0].message_type == "reasoning_message" + assert ( + response.messages[1].message_type == "tool_call_message" + and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group" + ) + assert response.messages[2].message_type == "tool_return_message" and len(eval(response.messages[2].tool_return)) == len( + participant_agent_ids + ) + assert response.messages[3].message_type == "reasoning_message" + assert response.messages[4].message_type == "assistant_message" - server.group_manager.delete_group(group_id=group.id, actor=actor) + finally: + server.group_manager.delete_group(group_id=group.id, actor=actor) + server.agent_manager.delete_agent(agent_id=agent_scrappy.id, actor=actor) @pytest.mark.asyncio @@ -218,16 +241,18 @@ async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_a ), actor=actor, ) - response = await server.send_group_message_to_agent( - group_id=group.id, - actor=actor, - messages=[ - MessageCreate(role="user", content="what is everyone up to for the holidays?"), - ], - stream_steps=False, - stream_tokens=False, - ) - assert response.usage.step_count == len(participant_agent_ids) * 2 - assert len(response.messages) == response.usage.step_count * 2 + try: + response = await server.send_group_message_to_agent( + group_id=group.id, + actor=actor, + messages=[ + MessageCreate(role="user", content="what is everyone up to for the holidays?"), + ], + stream_steps=False, + stream_tokens=False, + ) + assert response.usage.step_count == len(participant_agent_ids) * 2 + assert len(response.messages) == response.usage.step_count * 2 - server.group_manager.delete_group(group_id=group.id, actor=actor) + finally: + server.group_manager.delete_group(group_id=group.id, actor=actor)