feat: Robustify multi agent tools (#835)
This commit is contained in:
@@ -22,7 +22,12 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_
|
||||
Returns:
|
||||
str: The response from the target agent.
|
||||
"""
|
||||
messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)]
|
||||
message = (
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
|
||||
return execute_send_message_to_agent(
|
||||
sender_agent=self,
|
||||
messages=messages,
|
||||
@@ -78,9 +83,15 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags:
|
||||
|
||||
server = get_letta_server()
|
||||
|
||||
message = (
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
# Retrieve agents that match ALL specified tags
|
||||
matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100)
|
||||
messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)]
|
||||
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
|
||||
|
||||
async def send_messages_to_all_agents():
|
||||
tasks = [
|
||||
|
||||
@@ -249,24 +249,29 @@ def generate_import_code(module_attr_map: Optional[dict]):
|
||||
|
||||
|
||||
def parse_letta_response_for_assistant_message(
    target_agent_id: str,
    letta_response: LettaResponse,
    assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
    assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> Optional[str]:
    """Build a one-string summary of what the target agent replied.

    Every assistant-visible message in ``letta_response.messages`` is collected:
    the ``content`` of each ``AssistantMessage``, plus the
    ``assistant_message_tool_kwarg`` argument of each ``ToolCallMessage`` whose
    tool name equals ``assistant_message_tool_name``. Reasoning messages are
    gathered separately and only used when no assistant message was found.

    Args:
        target_agent_id: ID of the agent that produced the response; embedded in
            the returned text.
        letta_response: The response whose ``messages`` are scanned.
        assistant_message_tool_name: Tool name whose calls count as assistant
            messages.
        assistant_message_tool_kwarg: Key in the tool-call arguments JSON that
            holds the message text.

    Returns:
        ``"Agent <id> said: '<messages>'"`` when at least one assistant message
        was found, ``"Agent <id>'s inner thoughts: '<reasoning>'"`` when only
        reasoning was found, and ``None`` when the response held neither.
    """
    messages = []
    # This is not ideal, but we would like to return something rather than nothing
    fallback_reasoning = []
    for m in letta_response.messages:
        if isinstance(m, AssistantMessage):
            messages.append(m.content)
        elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
            try:
                messages.append(json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg])
            except (ValueError, KeyError, TypeError):
                # Malformed JSON (JSONDecodeError is a ValueError), a missing kwarg,
                # or a non-string / non-mapping payload: skip this tool call.
                continue
        elif isinstance(m, ReasoningMessage):
            fallback_reasoning.append(m.reasoning)

    # Join outside the f-string: a backslash or a reused quote inside an f-string
    # expression is a syntax error before Python 3.12.
    if messages:
        joined_messages = "\n".join(messages)
        return f"Agent {target_agent_id} said: '{joined_messages}'"

    if fallback_reasoning:
        joined_reasoning = "\n".join(fallback_reasoning)
        return f"Agent {target_agent_id}'s inner thoughts: '{joined_reasoning}'"

    # Nothing usable in the response; a falsy result lets callers report "no response".
    return None
|
||||
|
||||
|
||||
def execute_send_message_to_agent(
|
||||
@@ -364,17 +369,19 @@ async def async_send_message_with_retries(
|
||||
|
||||
# Extract assistant message
|
||||
assistant_message = parse_letta_response_for_assistant_message(
|
||||
target_agent_id,
|
||||
response,
|
||||
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
)
|
||||
if assistant_message:
|
||||
msg = f"Agent {target_agent_id} said '{assistant_message}'"
|
||||
sender_agent.logger.info(f"{logging_prefix} - {msg}")
|
||||
return msg
|
||||
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
|
||||
return assistant_message
|
||||
else:
|
||||
msg = f"(No response from agent {target_agent_id})"
|
||||
sender_agent.logger.info(f"{logging_prefix} - {msg}")
|
||||
sender_agent.logger.info(f"{logging_prefix} - raw response: {response.model_dump_json(indent=4)}")
|
||||
sender_agent.logger.info(f"{logging_prefix} - parsed assistant message: {assistant_message}")
|
||||
return msg
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})"
|
||||
|
||||
@@ -4,10 +4,12 @@ import pytest
|
||||
|
||||
import letta.functions.function_sets.base as base_functions
|
||||
from letta import LocalClient, create_client
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.letta_message import SystemMessage, ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.tool import Tool
|
||||
from tests.helpers.utils import retry_until_success
|
||||
from tests.utils import wait_for_incoming_message
|
||||
|
||||
@@ -44,6 +46,36 @@ def other_agent_obj(client: LocalClient):
|
||||
client.delete_agent(other_agent_obj.agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture
def roll_dice_tool(client):
    """Register a fixed-result dice-rolling tool through the client's server and yield it."""

    def roll_dice():
        """
        Rolls a 6 sided die.

        Returns:
            str: The roll result.
        """
        return "Rolled a 5!"

    # Build the tool record straight from the function's own source text.
    dice_tool = Tool(
        description="test_description",
        tags=["test"],
        source_code=parse_source_code(roll_dice),
        source_type="python",
    )

    # The schema is derived from the stored source; its "name" entry becomes the tool name.
    schema = derive_openai_json_schema(source_code=dice_tool.source_code, name=dice_tool.name)
    dice_tool.json_schema = schema
    dice_tool.name = schema["name"]

    # Create or update the tool as the client's user and hand the result to the test.
    yield client.server.tool_manager.create_or_update_tool(dice_tool, actor=client.user)
|
||||
|
||||
|
||||
def query_in_search_results(search_results, query):
|
||||
for result in search_results:
|
||||
if query.lower() in result["content"].lower():
|
||||
@@ -118,7 +150,7 @@ def test_recall(client, agent_obj):
|
||||
|
||||
|
||||
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
|
||||
def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
secret_word = "banana"
|
||||
|
||||
@@ -130,13 +162,18 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
)
|
||||
|
||||
# Conversation search the other agent
|
||||
result = base_functions.conversation_search(other_agent_obj, secret_word)
|
||||
assert secret_word in result
|
||||
messages = client.get_messages(other_agent_obj.agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {other_agent_obj.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word in m.content
|
||||
break
|
||||
|
||||
# Search the sender agent for the response from another agent
|
||||
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
|
||||
found = False
|
||||
target_snippet = f"Agent {other_agent_obj.agent_state.id} said "
|
||||
target_snippet = f"Agent {other_agent_obj.agent_state.id} said:"
|
||||
|
||||
for m in in_context_messages:
|
||||
if target_snippet in m.text:
|
||||
@@ -152,9 +189,8 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
print(response.messages)
|
||||
|
||||
|
||||
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags(client):
|
||||
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_simple(client):
|
||||
worker_tags = ["worker", "user-456"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
@@ -169,7 +205,7 @@ def test_send_message_to_agents_with_tags(client):
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 worker agents
|
||||
# Create 3 non-matching worker agents (These should NOT get the message)
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-123"]
|
||||
for _ in range(3):
|
||||
@@ -177,7 +213,7 @@ def test_send_message_to_agents_with_tags(client):
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Create 2 worker agents that belong to a different user (These should NOT get the message)
|
||||
# Create 3 worker agents that should get the message
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-456"]
|
||||
for _ in range(3):
|
||||
@@ -203,8 +239,63 @@ def test_send_message_to_agents_with_tags(client):
|
||||
|
||||
# Conversation search the worker agents
|
||||
for agent in worker_agents:
|
||||
result = base_functions.conversation_search(agent, secret_word)
|
||||
assert secret_word in result
|
||||
messages = client.get_messages(agent.agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word in m.content
|
||||
break
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
|
||||
|
||||
# Clean up agents
|
||||
client.delete_agent(manager_agent_state.id)
|
||||
for agent in worker_agents:
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
|
||||
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
|
||||
worker_tags = ["dice-rollers"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 2 worker agents
|
||||
worker_agents = []
|
||||
worker_tags = ["dice-rollers"]
|
||||
for _ in range(2):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id])
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Ask the manager to message every agent matching the worker tags and request a dice roll
|
||||
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
|
||||
response = client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=broadcast_message,
|
||||
)
|
||||
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolReturnMessage):
|
||||
tool_response = eval(json.loads(m.tool_return)["message"])
|
||||
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
|
||||
assert len(tool_response) == len(worker_agents)
|
||||
|
||||
# We can break after this, the ToolReturnMessage after is not related
|
||||
break
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
|
||||
Reference in New Issue
Block a user