From f52200f434f1768100e7aa02fbd6a6ffb4ad1c0b Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 29 Jan 2025 13:14:15 -1000 Subject: [PATCH] feat: Robustify multi agent tools (#835) --- letta/functions/function_sets/multi_agent.py | 15 ++- letta/functions/helpers.py | 25 ++-- tests/test_base_functions.py | 115 +++++++++++++++++-- 3 files changed, 132 insertions(+), 23 deletions(-) diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py index a8641b2f..ef607713 100644 --- a/letta/functions/function_sets/multi_agent.py +++ b/letta/functions/function_sets/multi_agent.py @@ -22,7 +22,12 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_ Returns: str: The response from the target agent. """ - messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)] + message = ( + f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, " + f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " + f"{message}" + ) + messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)] return execute_send_message_to_agent( sender_agent=self, messages=messages, @@ -78,9 +83,15 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: server = get_letta_server() + message = ( + f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, " + f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " + f"{message}" + ) + # Retrieve agents that match ALL specified tags matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100) - messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)] + messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)] async def send_messages_to_all_agents(): tasks = [ diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 24492119..9ebe9494 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -249,24 +249,29 @@ def generate_import_code(module_attr_map: Optional[dict]): def parse_letta_response_for_assistant_message( + target_agent_id: str, letta_response: LettaResponse, assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, ) -> Optional[str]: - reasoning_message = "" + messages = [] + # This is not ideal, but we would like to return something rather than nothing + fallback_reasoning = [] for m in letta_response.messages: if isinstance(m, AssistantMessage): - return m.content + messages.append(m.content) elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name: try: - return json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg] + messages.append(json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]) except Exception: # TODO: Make this more specific continue elif isinstance(m, ReasoningMessage): - # This is not ideal, but we would like to return something rather than nothing - reasoning_message += f"{m.reasoning}\n" + fallback_reasoning.append(m.reasoning) - return None + if messages: + return f"Agent {target_agent_id} said: '{"\n".join(messages)}'" + else: + return f"Agent {target_agent_id}'s inner thoughts: '{"\n".join(messages)}'" def execute_send_message_to_agent( @@ -364,17 +369,19 @@ async def async_send_message_with_retries( # Extract assistant message assistant_message = parse_letta_response_for_assistant_message( + target_agent_id, response, assistant_message_tool_name=DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, ) if assistant_message: - msg = f"Agent {target_agent_id} said '{assistant_message}'" - sender_agent.logger.info(f"{logging_prefix} - {msg}") - return msg + sender_agent.logger.info(f"{logging_prefix} - {assistant_message}") + return assistant_message else: msg = f"(No response from agent {target_agent_id})" sender_agent.logger.info(f"{logging_prefix} - {msg}") + sender_agent.logger.info(f"{logging_prefix} - raw response: {response.model_dump_json(indent=4)}") + sender_agent.logger.info(f"{logging_prefix} - parsed assistant message: {assistant_message}") return msg except asyncio.TimeoutError: error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})" diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 92c929f9..30ba8ab6 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -4,10 +4,12 @@ import pytest import letta.functions.function_sets.base as base_functions from letta import LocalClient, create_client +from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.letta_message import ToolReturnMessage +from letta.schemas.letta_message import SystemMessage, ToolReturnMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory +from letta.schemas.tool import Tool from tests.helpers.utils import retry_until_success from tests.utils import wait_for_incoming_message @@ -44,6 +46,36 @@ def other_agent_obj(client: LocalClient): client.delete_agent(other_agent_obj.agent_state.id) +@pytest.fixture +def roll_dice_tool(client): + def roll_dice(): + """ + Rolls a 6 sided die. + + Returns: + str: The roll result. + """ + return "Rolled a 5!" + + # Set up tool details + source_code = parse_source_code(roll_dice) + source_type = "python" + description = "test_description" + tags = ["test"] + + tool = Tool(description=description, tags=tags, source_code=source_code, source_type=source_type) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + + derived_name = derived_json_schema["name"] + tool.json_schema = derived_json_schema + tool.name = derived_name + + tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user) + + # Yield the created tool + yield tool + + def query_in_search_results(search_results, query): for result in search_results: if query.lower() in result["content"].lower(): @@ -118,7 +150,7 @@ def test_recall(client, agent_obj): # This test is nondeterministic, so we retry until we get the perfect behavior from the LLM -@retry_until_success(max_attempts=5, sleep_time_seconds=2) +@retry_until_success(max_attempts=2, sleep_time_seconds=2) def test_send_message_to_agent(client, agent_obj, other_agent_obj): secret_word = "banana" @@ -130,13 +162,18 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj): ) # Conversation search the other agent - result = base_functions.conversation_search(other_agent_obj, secret_word) - assert secret_word in result + messages = client.get_messages(other_agent_obj.agent_state.id) + # Check for the presence of system message + for m in reversed(messages): + print(f"\n\n {other_agent_obj.agent_state.id} -> {m.model_dump_json(indent=4)}") + if isinstance(m, SystemMessage): + assert secret_word in m.content + break # Search the sender agent for the response from another agent in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user) found = False - target_snippet = f"Agent {other_agent_obj.agent_state.id} said " + target_snippet = f"Agent {other_agent_obj.agent_state.id} said:" for m in in_context_messages: if target_snippet in m.text: @@ -152,9 +189,8 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj): print(response.messages) -# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_send_message_to_agents_with_tags(client): +@retry_until_success(max_attempts=2, sleep_time_seconds=2) +def test_send_message_to_agents_with_tags_simple(client): worker_tags = ["worker", "user-456"] # Clean up first from possibly failed tests @@ -169,7 +205,7 @@ def test_send_message_to_agents_with_tags(client): manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id]) manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) - # Create 3 worker agents + # Create 3 non-matching worker agents (These should NOT get the message) worker_agents = [] worker_tags = ["worker", "user-123"] for _ in range(3): @@ -177,7 +213,7 @@ def test_send_message_to_agents_with_tags(client): worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) worker_agents.append(worker_agent) - # Create 2 worker agents that belong to a different user (These should NOT get the message) + # Create 3 worker agents that should get the message worker_agents = [] worker_tags = ["worker", "user-456"] for _ in range(3): @@ -203,8 +239,63 @@ def test_send_message_to_agents_with_tags(client): # Conversation search the worker agents for agent in worker_agents: - result = base_functions.conversation_search(agent, secret_word) - assert secret_word in result + messages = client.get_messages(agent.agent_state.id) + # Check for the presence of system message + for m in reversed(messages): + print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}") + if isinstance(m, SystemMessage): + assert secret_word in m.content + break + + # Test that the agent can still receive messages fine + response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") + print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages])) + + # Clean up agents + client.delete_agent(manager_agent_state.id) + for agent in worker_agents: + client.delete_agent(agent.agent_state.id) + + +# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM +@retry_until_success(max_attempts=2, sleep_time_seconds=2) +def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool): + worker_tags = ["dice-rollers"] + + # Clean up first from possibly failed tests + prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True) + for agent in prev_worker_agents: + client.delete_agent(agent.id) + + # Create "manager" agent + send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") + manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id]) + manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) + + # Create 3 worker agents + worker_agents = [] + worker_tags = ["dice-rollers"] + for _ in range(2): + worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id]) + worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) + worker_agents.append(worker_agent) + + # Encourage the manager to send a message to the other agent_obj with the secret string + broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" + response = client.send_message( + agent_id=manager_agent.agent_state.id, + role="user", + message=broadcast_message, + ) + + for m in response.messages: + if isinstance(m, ToolReturnMessage): + tool_response = eval(json.loads(m.tool_return)["message"]) + print(f"\n\nManager agent tool response: \n{tool_response}\n\n") + assert len(tool_response) == len(worker_agents) + + # We can break after this, the ToolReturnMessage after is not related + break # Test that the agent can still receive messages fine response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")