feat: Modify multi agent broadcast for partial matching (#1208)
This commit is contained in:
@@ -50,7 +50,7 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "
|
||||
# Base memory tools CAN be edited, and are added by default by the server
|
||||
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
# Multi agent tools
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
|
||||
# Set of all built-in Letta tools
|
||||
LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from letta.functions.helpers import (
|
||||
_send_message_to_agents_matching_all_tags_async,
|
||||
_send_message_to_agents_matching_tags_async,
|
||||
execute_send_message_to_agent,
|
||||
fire_and_forget_send_to_agent,
|
||||
)
|
||||
@@ -70,18 +70,19 @@ def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str
|
||||
return "Successfully sent message"
|
||||
|
||||
|
||||
def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]:
|
||||
def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]:
|
||||
"""
|
||||
Sends a message to all agents within the same organization that match all of the specified tags. Messages are dispatched in parallel for improved performance, with retries to handle transient issues and timeouts to ensure responsiveness. This function enforces a limit of 100 agents and does not support pagination (cursor-based queries). Each agent must match all specified tags (`match_all_tags=True`) to be included.
|
||||
Sends a message to all agents within the same organization that match the specified tag criteria. Agents must possess *all* of the tags in `match_all` and *at least one* of the tags in `match_some` to receive the message.
|
||||
|
||||
Args:
|
||||
message (str): The content of the message to be sent to each matching agent.
|
||||
tags (List[str]): A list of tags that an agent must possess to receive the message.
|
||||
match_all (List[str]): A list of tags that an agent must possess to receive the message.
|
||||
match_some (List[str]): A list of tags where an agent must have at least one to qualify.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of responses from the agents that matched all tags. Each
|
||||
response corresponds to a single agent. Agents that do not respond will not
|
||||
have an entry in the returned list.
|
||||
List[str]: A list of responses from the agents that matched the filtering criteria. Each
|
||||
response corresponds to a single agent. Agents that do not respond will not have an entry
|
||||
in the returned list.
|
||||
"""
|
||||
|
||||
return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags))
|
||||
return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some))
|
||||
|
||||
@@ -518,8 +518,16 @@ def fire_and_forget_send_to_agent(
|
||||
run_in_background_thread(background_task())
|
||||
|
||||
|
||||
async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]:
|
||||
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async start", message=message, tags=tags)
|
||||
async def _send_message_to_agents_matching_tags_async(
|
||||
sender_agent: "Agent", message: str, match_all: List[str], match_some: List[str]
|
||||
) -> List[str]:
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
"_send_message_to_agents_matching_tags_async start",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
server = get_letta_server()
|
||||
|
||||
augmented_message = (
|
||||
@@ -529,9 +537,22 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
|
||||
)
|
||||
|
||||
# Retrieve up to 100 matching agents
|
||||
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents start", message=message, tags=tags)
|
||||
matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100)
|
||||
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents finish", message=message, tags=tags)
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
"_send_message_to_agents_matching_tags_async listing agents start",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
matching_agents = server.agent_manager.list_agents_matching_tags(actor=sender_agent.user, match_all=match_all, match_some=match_some)
|
||||
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
"_send_message_to_agents_matching_tags_async listing agents finish",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
|
||||
# Create a system message
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]
|
||||
@@ -559,7 +580,13 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
|
||||
else:
|
||||
final.append(r)
|
||||
|
||||
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async finish", message=message, tags=tags)
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
"_send_message_to_agents_matching_tags_async finish",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
return final
|
||||
|
||||
|
||||
|
||||
@@ -358,6 +358,49 @@ class AgentManager:
|
||||
|
||||
return [agent.to_pydantic() for agent in agents]
|
||||
|
||||
@enforce_types
|
||||
def list_agents_matching_tags(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
match_all: List[str],
|
||||
match_some: List[str],
|
||||
limit: Optional[int] = 50,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
Retrieves agents in the same organization that match all specified `match_all` tags
|
||||
and at least one tag from `match_some`. The query is optimized for efficiency by
|
||||
leveraging indexed filtering and aggregation.
|
||||
|
||||
Args:
|
||||
actor (PydanticUser): The user requesting the agent list.
|
||||
match_all (List[str]): Agents must have all these tags.
|
||||
match_some (List[str]): Agents must have at least one of these tags.
|
||||
limit (Optional[int]): Maximum number of agents to return.
|
||||
|
||||
Returns:
|
||||
List[PydanticAgentState: The filtered list of matching agents.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)
|
||||
|
||||
if match_all:
|
||||
# Subquery to find agent IDs that contain all match_all tags
|
||||
subquery = (
|
||||
select(AgentsTags.agent_id)
|
||||
.where(AgentsTags.tag.in_(match_all))
|
||||
.group_by(AgentsTags.agent_id)
|
||||
.having(func.count(AgentsTags.tag) == literal(len(match_all)))
|
||||
)
|
||||
query = query.where(AgentModel.id.in_(subquery))
|
||||
|
||||
if match_some:
|
||||
# Ensures agents match at least one tag in match_some
|
||||
query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))
|
||||
|
||||
query = query.group_by(AgentModel.id).limit(limit)
|
||||
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@enforce_types
|
||||
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
|
||||
@@ -127,54 +127,55 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_simple(client):
|
||||
worker_tags = ["worker", "user-456"]
|
||||
worker_tags_123 = ["worker", "user-123"]
|
||||
worker_tags_456 = ["worker", "user-456"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(
|
||||
client.user, tags=list(set(worker_tags_123 + worker_tags_456)), match_all_tags=True
|
||||
)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
secret_word = "banana"
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 non-matching worker agents (These should NOT get the message)
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-123"]
|
||||
worker_agents_123 = []
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_123)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
worker_agents_123.append(worker_agent)
|
||||
|
||||
# Create 3 worker agents that should get the message
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-456"]
|
||||
worker_agents_456 = []
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_456)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
worker_agents_456.append(worker_agent)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
response = client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!",
|
||||
message=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
|
||||
)
|
||||
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolReturnMessage):
|
||||
tool_response = eval(json.loads(m.tool_return)["message"])
|
||||
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
|
||||
assert len(tool_response) == len(worker_agents)
|
||||
assert len(tool_response) == len(worker_agents_456)
|
||||
|
||||
# We can break after this, the ToolReturnMessage after is not related
|
||||
break
|
||||
|
||||
# Conversation search the worker agents
|
||||
for agent in worker_agents:
|
||||
for agent in worker_agents_456:
|
||||
messages = client.get_messages(agent.agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
@@ -183,13 +184,22 @@ def test_send_message_to_agents_with_tags_simple(client):
|
||||
assert secret_word in m.content
|
||||
break
|
||||
|
||||
# Ensure it's NOT in the non matching worker agents
|
||||
for agent in worker_agents_123:
|
||||
messages = client.get_messages(agent.agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word not in m.content
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
|
||||
|
||||
# Clean up agents
|
||||
client.delete_agent(manager_agent_state.id)
|
||||
for agent in worker_agents:
|
||||
for agent in worker_agents_456 + worker_agents_123:
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
@@ -203,8 +213,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 worker agents
|
||||
@@ -245,8 +255,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 2 worker agents
|
||||
@@ -260,7 +270,7 @@ def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
broadcast_message = f"Using your tool named `send_message_to_agents_matching_all_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!"
|
||||
broadcast_message = f"Using your tool named `send_message_to_agents_matching_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!"
|
||||
client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
|
||||
@@ -65,10 +65,8 @@ def test_multi_agent_large(client, roll_dice_tool, num_workers):
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(
|
||||
name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags
|
||||
)
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id], tags=manager_tags)
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 worker agents
|
||||
|
||||
@@ -474,6 +474,45 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
server.source_manager.delete_source(default_source.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_with_tags(server: SyncServer, default_user):
|
||||
"""Fixture to create agents with specific tags."""
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="agent1",
|
||||
tags=["primary_agent", "benefit_1"],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="agent2",
|
||||
tags=["primary_agent", "benefit_2"],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent3 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="agent3",
|
||||
tags=["primary_agent", "benefit_1", "benefit_2"],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
return [agent1, agent2, agent3]
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -777,6 +816,45 @@ def test_list_attached_agents_nonexistent_source(server: SyncServer, default_use
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_list_agents_matching_all_tags(server: SyncServer, default_user, agent_with_tags):
|
||||
agents = server.agent_manager.list_agents_matching_tags(
|
||||
actor=default_user,
|
||||
match_all=["primary_agent", "benefit_1"],
|
||||
match_some=[],
|
||||
)
|
||||
assert len(agents) == 2 # agent1 and agent3 match
|
||||
assert {a.name for a in agents} == {"agent1", "agent3"}
|
||||
|
||||
|
||||
def test_list_agents_matching_some_tags(server: SyncServer, default_user, agent_with_tags):
|
||||
agents = server.agent_manager.list_agents_matching_tags(
|
||||
actor=default_user,
|
||||
match_all=["primary_agent"],
|
||||
match_some=["benefit_1", "benefit_2"],
|
||||
)
|
||||
assert len(agents) == 3 # All agents match
|
||||
assert {a.name for a in agents} == {"agent1", "agent2", "agent3"}
|
||||
|
||||
|
||||
def test_list_agents_matching_all_and_some_tags(server: SyncServer, default_user, agent_with_tags):
|
||||
agents = server.agent_manager.list_agents_matching_tags(
|
||||
actor=default_user,
|
||||
match_all=["primary_agent", "benefit_1"],
|
||||
match_some=["benefit_2", "nonexistent"],
|
||||
)
|
||||
assert len(agents) == 1 # Only agent3 matches
|
||||
assert agents[0].name == "agent3"
|
||||
|
||||
|
||||
def test_list_agents_matching_no_tags(server: SyncServer, default_user, agent_with_tags):
|
||||
agents = server.agent_manager.list_agents_matching_tags(
|
||||
actor=default_user,
|
||||
match_all=["primary_agent", "nonexistent_tag"],
|
||||
match_some=["benefit_1", "benefit_2"],
|
||||
)
|
||||
assert len(agents) == 0 # No agent should match
|
||||
|
||||
|
||||
def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
"""Test listing agents that have ALL specified tags."""
|
||||
# Create agents with multiple tags
|
||||
|
||||
Reference in New Issue
Block a user