From 5f6918206326e2073525a97729c9f8d79eb745cc Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 6 Mar 2025 13:26:59 -0800 Subject: [PATCH] feat: Modify multi agent broadcast for partial matching (#1208) --- letta/constants.py | 2 +- letta/functions/function_sets/multi_agent.py | 17 ++-- letta/functions/helpers.py | 39 ++++++++-- letta/services/agent_manager.py | 43 ++++++++++ tests/integration_test_multi_agent.py | 52 ++++++++----- ...manual_test_multi_agent_broadcast_large.py | 6 +- tests/test_managers.py | 78 +++++++++++++++++++ 7 files changed, 197 insertions(+), 40 deletions(-) diff --git a/letta/constants.py b/letta/constants.py index e06984a3..7408d05d 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -50,7 +50,7 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", " # Base memory tools CAN be edited, and are added by default by the server BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] # Multi agent tools -MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"] +MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"] # Set of all built-in Letta tools LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS) diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py index bd8f7a94..1f702b24 100644 --- a/letta/functions/function_sets/multi_agent.py +++ b/letta/functions/function_sets/multi_agent.py @@ -2,7 +2,7 @@ import asyncio from typing import TYPE_CHECKING, List from letta.functions.helpers import ( - _send_message_to_agents_matching_all_tags_async, + _send_message_to_agents_matching_tags_async, execute_send_message_to_agent, fire_and_forget_send_to_agent, ) @@ -70,18 +70,19 @@ def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str return "Successfully sent message" -def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]: +def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]: """ - Sends a message to all agents within the same organization that match all of the specified tags. Messages are dispatched in parallel for improved performance, with retries to handle transient issues and timeouts to ensure responsiveness. This function enforces a limit of 100 agents and does not support pagination (cursor-based queries). Each agent must match all specified tags (`match_all_tags=True`) to be included. + Sends a message to all agents within the same organization that match the specified tag criteria. Agents must possess *all* of the tags in `match_all` and *at least one* of the tags in `match_some` to receive the message. Args: message (str): The content of the message to be sent to each matching agent. - tags (List[str]): A list of tags that an agent must possess to receive the message. + match_all (List[str]): A list of tags that an agent must possess to receive the message. + match_some (List[str]): A list of tags where an agent must have at least one to qualify. Returns: - List[str]: A list of responses from the agents that matched all tags. Each - response corresponds to a single agent. Agents that do not respond will not - have an entry in the returned list. + List[str]: A list of responses from the agents that matched the filtering criteria. Each + response corresponds to a single agent. Agents that do not respond will not have an entry + in the returned list. """ - return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags)) + return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some)) diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 03b27e40..19446e05 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -518,8 +518,16 @@ def fire_and_forget_send_to_agent( run_in_background_thread(background_task()) -async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]: - log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async start", message=message, tags=tags) +async def _send_message_to_agents_matching_tags_async( + sender_agent: "Agent", message: str, match_all: List[str], match_some: List[str] +) -> List[str]: + log_telemetry( + sender_agent.logger, + "_send_message_to_agents_matching_tags_async start", + message=message, + match_all=match_all, + match_some=match_some, + ) server = get_letta_server() augmented_message = ( @@ -529,9 +537,22 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", ) # Retrieve up to 100 matching agents - log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents start", message=message, tags=tags) - matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100) - log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents finish", message=message, tags=tags) + log_telemetry( + sender_agent.logger, + "_send_message_to_agents_matching_tags_async listing agents start", + message=message, + match_all=match_all, + match_some=match_some, + ) + matching_agents = server.agent_manager.list_agents_matching_tags(actor=sender_agent.user, match_all=match_all, match_some=match_some) + + log_telemetry( + sender_agent.logger, + "_send_message_to_agents_matching_tags_async listing agents finish", + message=message, + match_all=match_all, + match_some=match_some, + ) # Create a system message messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)] @@ -559,7 +580,13 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", else: final.append(r) - log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async finish", message=message, tags=tags) + log_telemetry( + sender_agent.logger, + "_send_message_to_agents_matching_tags_async finish", + message=message, + match_all=match_all, + match_some=match_some, + ) return final diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 6b754dd3..8f92eaf2 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -358,6 +358,49 @@ class AgentManager: return [agent.to_pydantic() for agent in agents] + @enforce_types + def list_agents_matching_tags( + self, + actor: PydanticUser, + match_all: List[str], + match_some: List[str], + limit: Optional[int] = 50, + ) -> List[PydanticAgentState]: + """ + Retrieves agents in the same organization that match all specified `match_all` tags + and at least one tag from `match_some`. The query is optimized for efficiency by + leveraging indexed filtering and aggregation. + + Args: + actor (PydanticUser): The user requesting the agent list. + match_all (List[str]): Agents must have all these tags. + match_some (List[str]): Agents must have at least one of these tags. + limit (Optional[int]): Maximum number of agents to return. + + Returns: + List[PydanticAgentState: The filtered list of matching agents. + """ + with self.session_maker() as session: + query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id) + + if match_all: + # Subquery to find agent IDs that contain all match_all tags + subquery = ( + select(AgentsTags.agent_id) + .where(AgentsTags.tag.in_(match_all)) + .group_by(AgentsTags.agent_id) + .having(func.count(AgentsTags.tag) == literal(len(match_all))) + ) + query = query.where(AgentModel.id.in_(subquery)) + + if match_some: + # Ensures agents match at least one tag in match_some + query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some)) + + query = query.group_by(AgentModel.id).limit(limit) + + return list(session.execute(query).scalars()) + @enforce_types def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" diff --git a/tests/integration_test_multi_agent.py b/tests/integration_test_multi_agent.py index 91df2e24..30413d69 100644 --- a/tests/integration_test_multi_agent.py +++ b/tests/integration_test_multi_agent.py @@ -127,54 +127,55 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj): @retry_until_success(max_attempts=3, sleep_time_seconds=2) def test_send_message_to_agents_with_tags_simple(client): - worker_tags = ["worker", "user-456"] + worker_tags_123 = ["worker", "user-123"] + worker_tags_456 = ["worker", "user-456"] # Clean up first from possibly failed tests - prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True) + prev_worker_agents = client.server.agent_manager.list_agents( + client.user, tags=list(set(worker_tags_123 + worker_tags_456)), match_all_tags=True + ) for agent in prev_worker_agents: client.delete_agent(agent.id) secret_word = "banana" # Create "manager" agent - send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") - manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id]) + send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags") + manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id]) manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) # Create 3 non-matching worker agents (These should NOT get the message) - worker_agents = [] - worker_tags = ["worker", "user-123"] + worker_agents_123 = [] for _ in range(3): - worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags) + worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_123) worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents.append(worker_agent) + worker_agents_123.append(worker_agent) # Create 3 worker agents that should get the message - worker_agents = [] - worker_tags = ["worker", "user-456"] + worker_agents_456 = [] for _ in range(3): - worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags) + worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_456) worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents.append(worker_agent) + worker_agents_456.append(worker_agent) # Encourage the manager to send a message to the other agent_obj with the secret string response = client.send_message( agent_id=manager_agent.agent_state.id, role="user", - message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!", + message=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!", ) for m in response.messages: if isinstance(m, ToolReturnMessage): tool_response = eval(json.loads(m.tool_return)["message"]) print(f"\n\nManager agent tool response: \n{tool_response}\n\n") - assert len(tool_response) == len(worker_agents) + assert len(tool_response) == len(worker_agents_456) # We can break after this, the ToolReturnMessage after is not related break # Conversation search the worker agents - for agent in worker_agents: + for agent in worker_agents_456: messages = client.get_messages(agent.agent_state.id) # Check for the presence of system message for m in reversed(messages): @@ -183,13 +184,22 @@ def test_send_message_to_agents_with_tags_simple(client): assert secret_word in m.content break + # Ensure it's NOT in the non matching worker agents + for agent in worker_agents_123: + messages = client.get_messages(agent.agent_state.id) + # Check for the presence of system message + for m in reversed(messages): + print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}") + if isinstance(m, SystemMessage): + assert secret_word not in m.content + # Test that the agent can still receive messages fine response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages])) # Clean up agents client.delete_agent(manager_agent_state.id) - for agent in worker_agents: + for agent in worker_agents_456 + worker_agents_123: client.delete_agent(agent.agent_state.id) @@ -203,8 +213,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too client.delete_agent(agent.id) # Create "manager" agent - send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") - manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id]) + send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags") + manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id]) manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) # Create 3 worker agents @@ -245,8 +255,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too @retry_until_success(max_attempts=3, sleep_time_seconds=2) def test_send_message_to_sub_agents_auto_clear_message_buffer(client): # Create "manager" agent - send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") - manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id]) + send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags") + manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id]) manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) # Create 2 worker agents @@ -260,7 +270,7 @@ def test_send_message_to_sub_agents_auto_clear_message_buffer(client): worker_agents.append(worker_agent) # Encourage the manager to send a message to the other agent_obj with the secret string - broadcast_message = f"Using your tool named `send_message_to_agents_matching_all_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!" + broadcast_message = f"Using your tool named `send_message_to_agents_matching_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!" client.send_message( agent_id=manager_agent.agent_state.id, role="user", diff --git a/tests/manual_test_multi_agent_broadcast_large.py b/tests/manual_test_multi_agent_broadcast_large.py index 2108f03a..70d88f44 100644 --- a/tests/manual_test_multi_agent_broadcast_large.py +++ b/tests/manual_test_multi_agent_broadcast_large.py @@ -65,10 +65,8 @@ def test_multi_agent_large(client, roll_dice_tool, num_workers): client.delete_agent(agent.id) # Create "manager" agent - send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") - manager_agent_state = client.create_agent( - name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags - ) + send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags") + manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id], tags=manager_tags) manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) # Create 3 worker agents diff --git a/tests/test_managers.py b/tests/test_managers.py index ba4f0f93..ec4d2033 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -474,6 +474,45 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent): server.source_manager.delete_source(default_source.id, actor=actor) +@pytest.fixture +def agent_with_tags(server: SyncServer, default_user): + """Fixture to create agents with specific tags.""" + agent1 = server.agent_manager.create_agent( + agent_create=CreateAgent( + name="agent1", + tags=["primary_agent", "benefit_1"], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + memory_blocks=[], + ), + actor=default_user, + ) + + agent2 = server.agent_manager.create_agent( + agent_create=CreateAgent( + name="agent2", + tags=["primary_agent", "benefit_2"], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + memory_blocks=[], + ), + actor=default_user, + ) + + agent3 = server.agent_manager.create_agent( + agent_create=CreateAgent( + name="agent3", + tags=["primary_agent", "benefit_1", "benefit_2"], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + memory_blocks=[], + ), + actor=default_user, + ) + + return [agent1, agent2, agent3] + + # ====================================================================================================================== # AgentManager Tests - Basic # ====================================================================================================================== @@ -777,6 +816,45 @@ def test_list_attached_agents_nonexistent_source(server: SyncServer, default_use # ====================================================================================================================== +def test_list_agents_matching_all_tags(server: SyncServer, default_user, agent_with_tags): + agents = server.agent_manager.list_agents_matching_tags( + actor=default_user, + match_all=["primary_agent", "benefit_1"], + match_some=[], + ) + assert len(agents) == 2 # agent1 and agent3 match + assert {a.name for a in agents} == {"agent1", "agent3"} + + +def test_list_agents_matching_some_tags(server: SyncServer, default_user, agent_with_tags): + agents = server.agent_manager.list_agents_matching_tags( + actor=default_user, + match_all=["primary_agent"], + match_some=["benefit_1", "benefit_2"], + ) + assert len(agents) == 3 # All agents match + assert {a.name for a in agents} == {"agent1", "agent2", "agent3"} + + +def test_list_agents_matching_all_and_some_tags(server: SyncServer, default_user, agent_with_tags): + agents = server.agent_manager.list_agents_matching_tags( + actor=default_user, + match_all=["primary_agent", "benefit_1"], + match_some=["benefit_2", "nonexistent"], + ) + assert len(agents) == 1 # Only agent3 matches + assert agents[0].name == "agent3" + + +def test_list_agents_matching_no_tags(server: SyncServer, default_user, agent_with_tags): + agents = server.agent_manager.list_agents_matching_tags( + actor=default_user, + match_all=["primary_agent", "nonexistent_tag"], + match_some=["benefit_1", "benefit_2"], + ) + assert len(agents) == 0 # No agent should match + + def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user): """Test listing agents that have ALL specified tags.""" # Create agents with multiple tags