feat: Robustify multi agent tools (#835)

This commit is contained in:
Matthew Zhou
2025-01-29 13:14:15 -10:00
committed by GitHub
parent 29feb4c55c
commit 986397e7d0
3 changed files with 132 additions and 23 deletions

View File

@@ -22,7 +22,12 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_
Returns:
str: The response from the target agent.
"""
messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)]
message = (
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
)
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
return execute_send_message_to_agent(
sender_agent=self,
messages=messages,
@@ -78,9 +83,15 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags:
server = get_letta_server()
message = (
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
)
# Retrieve agents that match ALL specified tags
matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100)
messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)]
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
async def send_messages_to_all_agents():
tasks = [

View File

@@ -249,24 +249,29 @@ def generate_import_code(module_attr_map: Optional[dict]):
def parse_letta_response_for_assistant_message(
    target_agent_id: str,
    letta_response: LettaResponse,
    assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
    assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> Optional[str]:
    """Summarize what the target agent said in ``letta_response``.

    Every ``AssistantMessage`` and every call to the assistant-message tool is
    collected (in order) and joined with newlines. If the agent produced no
    such message, its ``ReasoningMessage`` contents are returned instead so the
    sender still gets something to work with.

    Args:
        target_agent_id: ID of the agent that produced ``letta_response``; it
            is embedded in the returned text.
        letta_response: The response whose ``messages`` are scanned.
        assistant_message_tool_name: Name of the tool call that carries an
            assistant message.
        assistant_message_tool_kwarg: Key inside the tool call's JSON
            arguments that holds the message text.

    Returns:
        ``"Agent <id> said: '<...>'"`` when at least one assistant message was
        found, ``"Agent <id>'s inner thoughts: '<...>'"`` when only reasoning
        was found, or ``None`` when the response contained neither.
    """
    messages = []
    # This is not ideal, but we would like to return something rather than nothing
    fallback_reasoning = []

    for m in letta_response.messages:
        if isinstance(m, AssistantMessage):
            messages.append(m.content)
        elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
            try:
                messages.append(json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg])
            except (ValueError, KeyError, TypeError):
                # Malformed JSON, missing kwarg, or non-dict arguments: skip this
                # tool call and keep scanning the remaining messages.
                continue
        elif isinstance(m, ReasoningMessage):
            fallback_reasoning.append(m.reasoning)

    # The join is done outside the f-string: a backslash or a reused quote
    # inside an f-string replacement field is a SyntaxError before Python 3.12.
    if messages:
        joined_messages = "\n".join(messages)
        return f"Agent {target_agent_id} said: '{joined_messages}'"

    if fallback_reasoning:
        joined_reasoning = "\n".join(fallback_reasoning)
        return f"Agent {target_agent_id}'s inner thoughts: '{joined_reasoning}'"

    # Nothing usable in the response; callers treat None as "no response".
    return None
def execute_send_message_to_agent(
@@ -364,17 +369,19 @@ async def async_send_message_with_retries(
# Extract assistant message
assistant_message = parse_letta_response_for_assistant_message(
target_agent_id,
response,
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
)
if assistant_message:
msg = f"Agent {target_agent_id} said '{assistant_message}'"
sender_agent.logger.info(f"{logging_prefix} - {msg}")
return msg
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
return assistant_message
else:
msg = f"(No response from agent {target_agent_id})"
sender_agent.logger.info(f"{logging_prefix} - {msg}")
sender_agent.logger.info(f"{logging_prefix} - raw response: {response.model_dump_json(indent=4)}")
sender_agent.logger.info(f"{logging_prefix} - parsed assistant message: {assistant_message}")
return msg
except asyncio.TimeoutError:
error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})"

View File

@@ -4,10 +4,12 @@ import pytest
import letta.functions.function_sets.base as base_functions
from letta import LocalClient, create_client
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.letta_message import SystemMessage, ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.tool import Tool
from tests.helpers.utils import retry_until_success
from tests.utils import wait_for_incoming_message
@@ -44,6 +46,36 @@ def other_agent_obj(client: LocalClient):
client.delete_agent(other_agent_obj.agent_state.id)
@pytest.fixture
def roll_dice_tool(client):
    """Create a deterministic dice-rolling tool and yield the stored result.

    The tool's name and JSON schema are derived from the source of the nested
    ``roll_dice`` function, then the tool is registered through the client's
    tool manager on behalf of ``client.user``.
    """

    def roll_dice():
        """
        Rolls a 6 sided die.

        Returns:
            str: The roll result.
        """
        return "Rolled a 5!"

    # Assemble the tool from the function's source
    dice_tool = Tool(
        description="test_description",
        tags=["test"],
        source_code=parse_source_code(roll_dice),
        source_type="python",
    )

    # Fill in schema and name from what the source code declares
    schema = derive_openai_json_schema(source_code=dice_tool.source_code, name=dice_tool.name)
    dice_tool.json_schema = schema
    dice_tool.name = schema["name"]

    # Hand the registered tool to the test
    yield client.server.tool_manager.create_or_update_tool(dice_tool, actor=client.user)
def query_in_search_results(search_results, query):
for result in search_results:
if query.lower() in result["content"].lower():
@@ -118,7 +150,7 @@ def test_recall(client, agent_obj):
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
def test_send_message_to_agent(client, agent_obj, other_agent_obj):
secret_word = "banana"
@@ -130,13 +162,18 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
)
# Conversation search the other agent
result = base_functions.conversation_search(other_agent_obj, secret_word)
assert secret_word in result
messages = client.get_messages(other_agent_obj.agent_state.id)
# Check for the presence of system message
for m in reversed(messages):
print(f"\n\n {other_agent_obj.agent_state.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word in m.content
break
# Search the sender agent for the response from another agent
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
found = False
target_snippet = f"Agent {other_agent_obj.agent_state.id} said "
target_snippet = f"Agent {other_agent_obj.agent_state.id} said:"
for m in in_context_messages:
if target_snippet in m.text:
@@ -152,9 +189,8 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
print(response.messages)
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags(client):
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_simple(client):
worker_tags = ["worker", "user-456"]
# Clean up first from possibly failed tests
@@ -169,7 +205,7 @@ def test_send_message_to_agents_with_tags(client):
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 worker agents
# Create 3 non-matching worker agents (These should NOT get the message)
worker_agents = []
worker_tags = ["worker", "user-123"]
for _ in range(3):
@@ -177,7 +213,7 @@ def test_send_message_to_agents_with_tags(client):
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents.append(worker_agent)
# Create 2 worker agents that belong to a different user (These should NOT get the message)
# Create 3 worker agents that should get the message
worker_agents = []
worker_tags = ["worker", "user-456"]
for _ in range(3):
@@ -203,8 +239,63 @@ def test_send_message_to_agents_with_tags(client):
# Conversation search the worker agents
for agent in worker_agents:
result = base_functions.conversation_search(agent, secret_word)
assert secret_word in result
messages = client.get_messages(agent.agent_state.id)
# Check for the presence of system message
for m in reversed(messages):
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word in m.content
break
# Test that the agent can still receive messages fine
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
# Clean up agents
client.delete_agent(manager_agent_state.id)
for agent in worker_agents:
client.delete_agent(agent.agent_state.id)
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
worker_tags = ["dice-rollers"]
# Clean up first from possibly failed tests
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
for agent in prev_worker_agents:
client.delete_agent(agent.id)
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 2 worker agents
worker_agents = []
worker_tags = ["dice-rollers"]
for _ in range(2):
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id])
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents.append(worker_agent)
# Encourage the manager to ask every agent with the worker tags to roll a dice
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
response = client.send_message(
agent_id=manager_agent.agent_state.id,
role="user",
message=broadcast_message,
)
for m in response.messages:
if isinstance(m, ToolReturnMessage):
tool_response = eval(json.loads(m.tool_return)["message"])
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
assert len(tool_response) == len(worker_agents)
# We can break after this, the ToolReturnMessage after is not related
break
# Test that the agent can still receive messages fine
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")