feat: Modify multi agent broadcast for partial matching (#1208)

This commit is contained in:
Matthew Zhou
2025-03-06 13:26:59 -08:00
committed by GitHub
parent e1ea5c3cdc
commit 5f69182063
7 changed files with 197 additions and 40 deletions

View File

@@ -50,7 +50,7 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
# Multi agent tools
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
# Set of all built-in Letta tools
LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS)

View File

@@ -2,7 +2,7 @@ import asyncio
from typing import TYPE_CHECKING, List
from letta.functions.helpers import (
_send_message_to_agents_matching_all_tags_async,
_send_message_to_agents_matching_tags_async,
execute_send_message_to_agent,
fire_and_forget_send_to_agent,
)
@@ -70,18 +70,19 @@ def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str
return "Successfully sent message"
def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]:
def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]:
"""
Sends a message to all agents within the same organization that match all of the specified tags. Messages are dispatched in parallel for improved performance, with retries to handle transient issues and timeouts to ensure responsiveness. This function enforces a limit of 100 agents and does not support pagination (cursor-based queries). Each agent must match all specified tags (`match_all_tags=True`) to be included.
Sends a message to all agents within the same organization that match the specified tag criteria. Agents must possess *all* of the tags in `match_all` and *at least one* of the tags in `match_some` to receive the message.
Args:
message (str): The content of the message to be sent to each matching agent.
tags (List[str]): A list of tags that an agent must possess to receive the message.
match_all (List[str]): A list of tags that an agent must possess to receive the message.
match_some (List[str]): A list of tags where an agent must have at least one to qualify.
Returns:
List[str]: A list of responses from the agents that matched all tags. Each
response corresponds to a single agent. Agents that do not respond will not
have an entry in the returned list.
List[str]: A list of responses from the agents that matched the filtering criteria. Each
response corresponds to a single agent. Agents that do not respond will not have an entry
in the returned list.
"""
return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags))
return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some))

View File

@@ -518,8 +518,16 @@ def fire_and_forget_send_to_agent(
run_in_background_thread(background_task())
async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]:
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async start", message=message, tags=tags)
async def _send_message_to_agents_matching_tags_async(
sender_agent: "Agent", message: str, match_all: List[str], match_some: List[str]
) -> List[str]:
log_telemetry(
sender_agent.logger,
"_send_message_to_agents_matching_tags_async start",
message=message,
match_all=match_all,
match_some=match_some,
)
server = get_letta_server()
augmented_message = (
@@ -529,9 +537,22 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
)
# Retrieve up to 100 matching agents
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents start", message=message, tags=tags)
matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100)
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents finish", message=message, tags=tags)
log_telemetry(
sender_agent.logger,
"_send_message_to_agents_matching_tags_async listing agents start",
message=message,
match_all=match_all,
match_some=match_some,
)
matching_agents = server.agent_manager.list_agents_matching_tags(actor=sender_agent.user, match_all=match_all, match_some=match_some)
log_telemetry(
sender_agent.logger,
"_send_message_to_agents_matching_tags_async listing agents finish",
message=message,
match_all=match_all,
match_some=match_some,
)
# Create a system message
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]
@@ -559,7 +580,13 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
else:
final.append(r)
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async finish", message=message, tags=tags)
log_telemetry(
sender_agent.logger,
"_send_message_to_agents_matching_tags_async finish",
message=message,
match_all=match_all,
match_some=match_some,
)
return final

View File

@@ -358,6 +358,49 @@ class AgentManager:
return [agent.to_pydantic() for agent in agents]
@enforce_types
def list_agents_matching_tags(
self,
actor: PydanticUser,
match_all: List[str],
match_some: List[str],
limit: Optional[int] = 50,
) -> List[PydanticAgentState]:
"""
Retrieves agents in the same organization that match all specified `match_all` tags
and at least one tag from `match_some`. The query is optimized for efficiency by
leveraging indexed filtering and aggregation.
Args:
actor (PydanticUser): The user requesting the agent list.
match_all (List[str]): Agents must have all these tags.
match_some (List[str]): Agents must have at least one of these tags.
limit (Optional[int]): Maximum number of agents to return.
Returns:
List[PydanticAgentState: The filtered list of matching agents.
"""
with self.session_maker() as session:
query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)
if match_all:
# Subquery to find agent IDs that contain all match_all tags
subquery = (
select(AgentsTags.agent_id)
.where(AgentsTags.tag.in_(match_all))
.group_by(AgentsTags.agent_id)
.having(func.count(AgentsTags.tag) == literal(len(match_all)))
)
query = query.where(AgentModel.id.in_(subquery))
if match_some:
# Ensures agents match at least one tag in match_some
query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))
query = query.group_by(AgentModel.id).limit(limit)
return list(session.execute(query).scalars())
@enforce_types
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""

View File

@@ -127,54 +127,55 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_simple(client):
worker_tags = ["worker", "user-456"]
worker_tags_123 = ["worker", "user-123"]
worker_tags_456 = ["worker", "user-456"]
# Clean up first from possibly failed tests
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
prev_worker_agents = client.server.agent_manager.list_agents(
client.user, tags=list(set(worker_tags_123 + worker_tags_456)), match_all_tags=True
)
for agent in prev_worker_agents:
client.delete_agent(agent.id)
secret_word = "banana"
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 non-matching worker agents (These should NOT get the message)
worker_agents = []
worker_tags = ["worker", "user-123"]
worker_agents_123 = []
for _ in range(3):
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_123)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents.append(worker_agent)
worker_agents_123.append(worker_agent)
# Create 3 worker agents that should get the message
worker_agents = []
worker_tags = ["worker", "user-456"]
worker_agents_456 = []
for _ in range(3):
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_456)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents.append(worker_agent)
worker_agents_456.append(worker_agent)
# Encourage the manager to send a message to the other agent_obj with the secret string
response = client.send_message(
agent_id=manager_agent.agent_state.id,
role="user",
message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!",
message=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
)
for m in response.messages:
if isinstance(m, ToolReturnMessage):
tool_response = eval(json.loads(m.tool_return)["message"])
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
assert len(tool_response) == len(worker_agents)
assert len(tool_response) == len(worker_agents_456)
# We can break after this, the ToolReturnMessage after is not related
break
# Conversation search the worker agents
for agent in worker_agents:
for agent in worker_agents_456:
messages = client.get_messages(agent.agent_state.id)
# Check for the presence of system message
for m in reversed(messages):
@@ -183,13 +184,22 @@ def test_send_message_to_agents_with_tags_simple(client):
assert secret_word in m.content
break
# Ensure it's NOT in the non matching worker agents
for agent in worker_agents_123:
messages = client.get_messages(agent.agent_state.id)
# Check for the presence of system message
for m in reversed(messages):
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word not in m.content
# Test that the agent can still receive messages fine
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
# Clean up agents
client.delete_agent(manager_agent_state.id)
for agent in worker_agents:
for agent in worker_agents_456 + worker_agents_123:
client.delete_agent(agent.agent_state.id)
@@ -203,8 +213,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
client.delete_agent(agent.id)
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 worker agents
@@ -245,8 +255,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 2 worker agents
@@ -260,7 +270,7 @@ def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
worker_agents.append(worker_agent)
# Encourage the manager to send a message to the other agent_obj with the secret string
broadcast_message = f"Using your tool named `send_message_to_agents_matching_all_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!"
broadcast_message = f"Using your tool named `send_message_to_agents_matching_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!"
client.send_message(
agent_id=manager_agent.agent_state.id,
role="user",

View File

@@ -65,10 +65,8 @@ def test_multi_agent_large(client, roll_dice_tool, num_workers):
client.delete_agent(agent.id)
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(
name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags
)
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id], tags=manager_tags)
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 worker agents

View File

@@ -474,6 +474,45 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
server.source_manager.delete_source(default_source.id, actor=actor)
@pytest.fixture
def agent_with_tags(server: SyncServer, default_user):
"""Fixture to create agents with specific tags."""
agent1 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="agent1",
tags=["primary_agent", "benefit_1"],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
),
actor=default_user,
)
agent2 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="agent2",
tags=["primary_agent", "benefit_2"],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
),
actor=default_user,
)
agent3 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="agent3",
tags=["primary_agent", "benefit_1", "benefit_2"],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
),
actor=default_user,
)
return [agent1, agent2, agent3]
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -777,6 +816,45 @@ def test_list_attached_agents_nonexistent_source(server: SyncServer, default_use
# ======================================================================================================================
def test_list_agents_matching_all_tags(server: SyncServer, default_user, agent_with_tags):
agents = server.agent_manager.list_agents_matching_tags(
actor=default_user,
match_all=["primary_agent", "benefit_1"],
match_some=[],
)
assert len(agents) == 2 # agent1 and agent3 match
assert {a.name for a in agents} == {"agent1", "agent3"}
def test_list_agents_matching_some_tags(server: SyncServer, default_user, agent_with_tags):
agents = server.agent_manager.list_agents_matching_tags(
actor=default_user,
match_all=["primary_agent"],
match_some=["benefit_1", "benefit_2"],
)
assert len(agents) == 3 # All agents match
assert {a.name for a in agents} == {"agent1", "agent2", "agent3"}
def test_list_agents_matching_all_and_some_tags(server: SyncServer, default_user, agent_with_tags):
agents = server.agent_manager.list_agents_matching_tags(
actor=default_user,
match_all=["primary_agent", "benefit_1"],
match_some=["benefit_2", "nonexistent"],
)
assert len(agents) == 1 # Only agent3 matches
assert agents[0].name == "agent3"
def test_list_agents_matching_no_tags(server: SyncServer, default_user, agent_with_tags):
agents = server.agent_manager.list_agents_matching_tags(
actor=default_user,
match_all=["primary_agent", "nonexistent_tag"],
match_some=["benefit_1", "benefit_2"],
)
assert len(agents) == 0 # No agent should match
def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user):
"""Test listing agents that have ALL specified tags."""
# Create agents with multiple tags