diff --git a/letta/constants.py b/letta/constants.py index ea42306a..1269dc84 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -53,6 +53,7 @@ BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"] MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3 MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60 +MULTI_AGENT_CONCURRENT_SENDS = 15 # The name of the tool used to send message to the user # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py index ef607713..bd8f7a94 100644 --- a/letta/functions/function_sets/multi_agent.py +++ b/letta/functions/function_sets/multi_agent.py @@ -1,11 +1,13 @@ import asyncio from typing import TYPE_CHECKING, List -from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT -from letta.functions.helpers import async_send_message_with_retries, execute_send_message_to_agent, fire_and_forget_send_to_agent +from letta.functions.helpers import ( + _send_message_to_agents_matching_all_tags_async, + execute_send_message_to_agent, + fire_and_forget_send_to_agent, +) from letta.schemas.enums import MessageRole from letta.schemas.message import MessageCreate -from letta.server.rest_api.utils import get_letta_server if TYPE_CHECKING: from letta.agent import Agent @@ -22,12 +24,13 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_ Returns: str: The response from the target agent. 
""" - message = ( + augmented_message = ( f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, " f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " f"{message}" ) - messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)] + messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)] + return execute_send_message_to_agent( sender_agent=self, messages=messages, @@ -81,33 +84,4 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: have an entry in the returned list. """ - server = get_letta_server() - - message = ( - f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, " - f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " - f"{message}" - ) - - # Retrieve agents that match ALL specified tags - matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100) - messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)] - - async def send_messages_to_all_agents(): - tasks = [ - async_send_message_with_retries( - server=server, - sender_agent=self, - target_agent_id=agent_state.id, - messages=messages, - max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, - timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, - logging_prefix="[send_message_to_agents_matching_all_tags]", - ) - for agent_state in matching_agents - ] - # Run all tasks in parallel - return await asyncio.gather(*tasks) - - # Run the async function and return results - return asyncio.run(send_messages_to_all_agents()) + return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags)) diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 8c232cd5..fe179e4a 100644 --- a/letta/functions/helpers.py 
+++ b/letta/functions/helpers.py @@ -1,5 +1,4 @@ import asyncio -import json import threading from random import uniform from typing import Any, List, Optional, Union @@ -12,13 +11,17 @@ from letta.constants import ( COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, + MULTI_AGENT_CONCURRENT_SENDS, MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT, ) +from letta.functions.interface import MultiAgentMessagingInterface from letta.orm.errors import NoResultFound -from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import AssistantMessage from letta.schemas.letta_response import LettaResponse -from letta.schemas.message import MessageCreate +from letta.schemas.message import Message, MessageCreate +from letta.schemas.user import User from letta.server.rest_api.utils import get_letta_server @@ -249,29 +252,48 @@ def generate_import_code(module_attr_map: Optional[dict]): def parse_letta_response_for_assistant_message( target_agent_id: str, letta_response: LettaResponse, - assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, - assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, ) -> Optional[str]: messages = [] - # This is not ideal, but we would like to return something rather than nothing - fallback_reasoning = [] for m in letta_response.messages: if isinstance(m, AssistantMessage): messages.append(m.content) - elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name: - try: - messages.append(json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]) - except Exception: # TODO: Make this more specific - continue - elif isinstance(m, ReasoningMessage): - fallback_reasoning.append(m.reasoning) if messages: messages_str = "\n".join(messages) - return f"Agent {target_agent_id} said: '{messages_str}'" + return f"{target_agent_id} said: '{messages_str}'" 
else: - messages_str = "\n".join(fallback_reasoning) - return f"Agent {target_agent_id}'s inner thoughts: '{messages_str}'" + return f"No response from {target_agent_id}" + + +async def async_execute_send_message_to_agent( + sender_agent: "Agent", + messages: List[MessageCreate], + other_agent_id: str, + log_prefix: str, +) -> Optional[str]: + """ + Async helper to: + 1) validate the target agent exists & is in the same org, + 2) send a message via async_send_message_with_retries. + """ + server = get_letta_server() + + # 1. Validate target agent + try: + server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user) + except NoResultFound: + raise ValueError(f"Target agent {other_agent_id} either does not exist or is not in org " f"({sender_agent.user.organization_id}).") + + # 2. Use your async retry logic + return await async_send_message_with_retries( + server=server, + sender_agent=sender_agent, + target_agent_id=other_agent_id, + messages=messages, + max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, + timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, + logging_prefix=log_prefix, + ) def execute_send_message_to_agent( @@ -281,53 +303,43 @@ def execute_send_message_to_agent( log_prefix: str, ) -> Optional[str]: """ - Helper function to send a message to a specific Letta agent. - - Args: - sender_agent ("Agent"): The sender agent object. - message (str): The message to send. - other_agent_id (str): The identifier of the target Letta agent. - log_prefix (str): Logging prefix for retries. - - Returns: - Optional[str]: The response from the Letta agent if required by the caller. + Synchronous wrapper that calls `async_execute_send_message_to_agent` using asyncio.run. + This function must be called from a synchronous context (i.e., no running event loop). 
""" - server = get_letta_server() + return asyncio.run(async_execute_send_message_to_agent(sender_agent, messages, other_agent_id, log_prefix)) - # Ensure the target agent is in the same org - try: - server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user) - except NoResultFound: - raise ValueError( - f"The passed-in agent_id {other_agent_id} either does not exist, " - f"or does not belong to the same org ({sender_agent.user.organization_id})." - ) - # Async logic to send a message with retries and timeout - async def async_send(): - return await async_send_message_with_retries( - server=server, - sender_agent=sender_agent, - target_agent_id=other_agent_id, - messages=messages, - max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, - timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, - logging_prefix=log_prefix, - ) +async def send_message_to_agent_no_stream( + server: "SyncServer", + agent_id: str, + actor: User, + messages: Union[List[Message], List[MessageCreate]], + metadata: Optional[dict] = None, +) -> LettaResponse: + """ + A simpler helper to send messages to a single agent WITHOUT streaming. + Returns a LettaResponse containing the final messages. 
+ """ + interface = MultiAgentMessagingInterface() + if metadata: + interface.metadata = metadata - # Run in the current event loop or create one if needed - try: - return asyncio.run(async_send()) - except RuntimeError: - loop = asyncio.get_event_loop() - if loop.is_running(): - return loop.run_until_complete(async_send()) - else: - raise + # Offload the synchronous `send_messages` call + usage_stats = await asyncio.to_thread( + server.send_messages, + actor=actor, + agent_id=agent_id, + messages=messages, + interface=interface, + metadata=metadata, + ) + + final_messages = interface.get_captured_send_messages() + return LettaResponse(messages=final_messages, usage=usage_stats) async def async_send_message_with_retries( - server, + server: "SyncServer", sender_agent: "Agent", target_agent_id: str, messages: List[MessageCreate], @@ -335,57 +347,34 @@ async def async_send_message_with_retries( timeout: int, logging_prefix: Optional[str] = None, ) -> str: - """ - Shared helper coroutine to send a message to an agent with retries and a timeout. - Args: - server: The Letta server instance (from get_letta_server()). - sender_agent (Agent): The agent initiating the send action. - target_agent_id (str): The ID of the agent to send the message to. - message_text (str): The text to send as the user message. - max_retries (int): Maximum number of retries for the request. - timeout (int): Maximum time to wait for a response (in seconds). - logging_prefix (str): A prefix to append to logging - Returns: - str: The response or an error message. 
- """ logging_prefix = logging_prefix or "[async_send_message_with_retries]" for attempt in range(1, max_retries + 1): try: - # Wrap in a timeout response = await asyncio.wait_for( - server.send_message_to_agent( + send_message_to_agent_no_stream( + server=server, agent_id=target_agent_id, actor=sender_agent.user, messages=messages, - stream_steps=False, - stream_tokens=False, - use_assistant_message=True, - assistant_message_tool_name=DEFAULT_MESSAGE_TOOL, - assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, ), timeout=timeout, ) - # Extract assistant message - assistant_message = parse_letta_response_for_assistant_message( - target_agent_id, - response, - assistant_message_tool_name=DEFAULT_MESSAGE_TOOL, - assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, - ) + # Then parse out the assistant message + assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response) if assistant_message: sender_agent.logger.info(f"{logging_prefix} - {assistant_message}") return assistant_message else: msg = f"(No response from agent {target_agent_id})" sender_agent.logger.info(f"{logging_prefix} - {msg}") - sender_agent.logger.info(f"{logging_prefix} - raw response: {response.model_dump_json(indent=4)}") - sender_agent.logger.info(f"{logging_prefix} - parsed assistant message: {assistant_message}") return msg + except asyncio.TimeoutError: error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})" sender_agent.logger.warning(f"{logging_prefix} - {error_msg}") + except Exception as e: error_msg = f"(Error on attempt {attempt}/{max_retries} for agent {target_agent_id}: {e})" sender_agent.logger.warning(f"{logging_prefix} - {error_msg}") @@ -393,10 +382,10 @@ async def async_send_message_with_retries( # Exponential backoff before retrying if attempt < max_retries: backoff = uniform(0.5, 2) * (2**attempt) - sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent to agent send_message...sleeping for 
{backoff}") + sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent-to-agent send_message...sleeping for {backoff}") await asyncio.sleep(backoff) else: - sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}") + sender_agent.logger.error(f"{logging_prefix} - Fatal error: {error_msg}") raise Exception(error_msg) @@ -482,3 +471,43 @@ def fire_and_forget_send_to_agent( except RuntimeError: # Means no event loop is running in this thread run_in_background_thread(background_task()) + + +async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]: + server = get_letta_server() + + augmented_message = ( + f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, " + f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " + f"{message}" + ) + + # Retrieve up to 100 matching agents + matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100) + + # Create a system message + messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)] + + # Possibly limit concurrency to avoid meltdown: + sem = asyncio.Semaphore(MULTI_AGENT_CONCURRENT_SENDS) + + async def _send_single(agent_state): + async with sem: + return await async_send_message_with_retries( + server=server, + sender_agent=sender_agent, + target_agent_id=agent_state.id, + messages=messages, + max_retries=3, + timeout=30, + ) + + tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in matching_agents] + results = await asyncio.gather(*tasks, return_exceptions=True) + final = [] + for r in results: + if isinstance(r, Exception): + final.append(str(r)) + else: + final.append(r) + return final diff --git a/letta/functions/interface.py b/letta/functions/interface.py new file 
mode 100644 index 00000000..82bf229e --- /dev/null +++ b/letta/functions/interface.py @@ -0,0 +1,75 @@ +import json +from typing import List, Optional + +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.interface import AgentInterface +from letta.schemas.letta_message import AssistantMessage, LettaMessage +from letta.schemas.message import Message + + +class MultiAgentMessagingInterface(AgentInterface): + """ + A minimal interface that captures *only* calls to the 'send_message' function + by inspecting msg_obj.tool_calls. We parse out the 'message' field from the + JSON function arguments and store it as an AssistantMessage. + """ + + def __init__(self): + self._captured_messages: List[AssistantMessage] = [] + self.metadata = {} + + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + """Ignore internal monologue.""" + + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): + """Ignore normal assistant messages (only capturing send_message calls).""" + + def function_message(self, msg: str, msg_obj: Optional[Message] = None): + """ + Called whenever the agent logs a function call. 
We'll inspect msg_obj.tool_calls: + - If tool_calls include a function named 'send_message', parse its arguments + - Extract the 'message' field + - Save it as an AssistantMessage in self._captured_messages + """ + if not msg_obj or not msg_obj.tool_calls: + return + + for tool_call in msg_obj.tool_calls: + if not tool_call.function: + continue + if tool_call.function.name != DEFAULT_MESSAGE_TOOL: + # Skip any other function calls + continue + + # Now parse the JSON in tool_call.function.arguments + func_args_str = tool_call.function.arguments or "" + try: + data = json.loads(func_args_str) + # Extract the 'message' key if present + content = data.get(DEFAULT_MESSAGE_TOOL_KWARG, str(data)) + except json.JSONDecodeError: + # If we can't parse, store the raw string + content = func_args_str + + # Store as an AssistantMessage + new_msg = AssistantMessage( + id=msg_obj.id, + date=msg_obj.created_at, + content=content, + ) + self._captured_messages.append(new_msg) + + def user_message(self, msg: str, msg_obj: Optional[Message] = None): + """Ignore user messages.""" + + def step_complete(self): + """No streaming => no step boundaries.""" + + def step_yield(self): + """No streaming => no final yield needed.""" + + def get_captured_send_messages(self) -> List[LettaMessage]: + """ + Returns only the messages extracted from 'send_message' calls. 
+ """ + return self._captured_messages diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index ded9d749..a9e617f7 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -315,7 +315,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # extra prints self.debug = False - self.timeout = 30 + self.timeout = 10 * 60 # 10 minute timeout def _reset_inner_thoughts_json_reader(self): # A buffer for accumulating function arguments (we want to buffer keys and run checks on each one) @@ -330,7 +330,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): while self._active: try: # Wait until there is an item in the deque or the stream is deactivated - await asyncio.wait_for(self._event.wait(), timeout=self.timeout) # 30 second timeout + await asyncio.wait_for(self._event.wait(), timeout=self.timeout) except asyncio.TimeoutError: break # Exit the loop if we timeout diff --git a/tests/manual_test_multi_agent_broadcast_large.py b/tests/manual_test_multi_agent_broadcast_large.py new file mode 100644 index 00000000..4adcfa07 --- /dev/null +++ b/tests/manual_test_multi_agent_broadcast_large.py @@ -0,0 +1,91 @@ +import json +import os + +import pytest +from tqdm import tqdm + +from letta import create_client +from letta.functions.functions import derive_openai_json_schema, parse_source_code +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig +from letta.schemas.tool import Tool +from tests.integration_test_summarizer import LLM_CONFIG_DIR + + +@pytest.fixture(scope="function") +def client(): + filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-haiku.json") + config_data = json.load(open(filename, "r")) + llm_config = LLMConfig(**config_data) + client = create_client() + client.set_default_llm_config(llm_config) + client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) + + yield client + + 
+@pytest.fixture +def roll_dice_tool(client): + def roll_dice(): + """ + Rolls a 6 sided die. + + Returns: + str: The roll result. + """ + return "Rolled a 5!" + + # Set up tool details + source_code = parse_source_code(roll_dice) + source_type = "python" + description = "test_description" + tags = ["test"] + + tool = Tool(description=description, tags=tags, source_code=source_code, source_type=source_type) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + + derived_name = derived_json_schema["name"] + tool.json_schema = derived_json_schema + tool.name = derived_name + + tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user) + + # Yield the created tool + yield tool + + +def test_multi_agent_large(client, roll_dice_tool): + manager_tags = ["manager"] + worker_tags = ["helpers"] + + # Clean up first from possibly failed tests + prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags + manager_tags, match_all_tags=True) + for agent in prev_worker_agents: + client.delete_agent(agent.id) + + # Create "manager" agent + send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") + manager_agent_state = client.create_agent( + name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags + ) + manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) + + # Create the worker agents (how many is set by num_workers below) + worker_agents = [] + num_workers = 50 + for idx in tqdm(range(num_workers)): + worker_agent_state = client.create_agent( + name=f"worker-{idx}", include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id] + ) + worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) + worker_agents.append(worker_agent) + + # Ask the manager to broadcast a dice-roll request to every agent carrying the worker tags +
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" + client.send_message( + agent_id=manager_agent.agent_state.id, + role="user", + message=broadcast_message, + ) + + # Please manually inspect the agent results diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 30ba8ab6..b5ce5104 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -173,7 +173,7 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj): # Search the sender agent for the response from another agent in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user) found = False - target_snippet = f"Agent {other_agent_obj.agent_state.id} said:" + target_snippet = f"{other_agent_obj.agent_state.id} said:" for m in in_context_messages: if target_snippet in m.text: