fix: Robust new streaming interface for multi-agent tooling (#907)

This commit is contained in:
Matthew Zhou
2025-02-05 16:20:52 -05:00
committed by GitHub
parent 35c4df1d07
commit 3cd3cd4f5f
7 changed files with 296 additions and 126 deletions

View File

@@ -53,6 +53,7 @@ BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3
MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60
MULTI_AGENT_CONCURRENT_SENDS = 15
# The name of the tool used to send message to the user
# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_message, ...)

View File

@@ -1,11 +1,13 @@
import asyncio
from typing import TYPE_CHECKING, List
from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT
from letta.functions.helpers import async_send_message_with_retries, execute_send_message_to_agent, fire_and_forget_send_to_agent
from letta.functions.helpers import (
_send_message_to_agents_matching_all_tags_async,
execute_send_message_to_agent,
fire_and_forget_send_to_agent,
)
from letta.schemas.enums import MessageRole
from letta.schemas.message import MessageCreate
from letta.server.rest_api.utils import get_letta_server
if TYPE_CHECKING:
from letta.agent import Agent
@@ -22,12 +24,13 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_
Returns:
str: The response from the target agent.
"""
message = (
augmented_message = (
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
)
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
return execute_send_message_to_agent(
sender_agent=self,
messages=messages,
@@ -81,33 +84,4 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags:
have an entry in the returned list.
"""
server = get_letta_server()
message = (
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
)
# Retrieve agents that match ALL specified tags
matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100)
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
async def send_messages_to_all_agents():
tasks = [
async_send_message_with_retries(
server=server,
sender_agent=self,
target_agent_id=agent_state.id,
messages=messages,
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
logging_prefix="[send_message_to_agents_matching_all_tags]",
)
for agent_state in matching_agents
]
# Run all tasks in parallel
return await asyncio.gather(*tasks)
# Run the async function and return results
return asyncio.run(send_messages_to_all_agents())
return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags))

View File

@@ -1,5 +1,4 @@
import asyncio
import json
import threading
from random import uniform
from typing import Any, List, Optional, Union
@@ -12,13 +11,17 @@ from letta.constants import (
COMPOSIO_ENTITY_ENV_VAR_KEY,
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
MULTI_AGENT_CONCURRENT_SENDS,
MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
)
from letta.functions.interface import MultiAgentMessagingInterface
from letta.orm.errors import NoResultFound
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import MessageCreate
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
from letta.server.rest_api.utils import get_letta_server
@@ -249,29 +252,48 @@ def generate_import_code(module_attr_map: Optional[dict]):
def parse_letta_response_for_assistant_message(
target_agent_id: str,
letta_response: LettaResponse,
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> Optional[str]:
messages = []
# This is not ideal, but we would like to return something rather than nothing
fallback_reasoning = []
for m in letta_response.messages:
if isinstance(m, AssistantMessage):
messages.append(m.content)
elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
try:
messages.append(json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg])
except Exception: # TODO: Make this more specific
continue
elif isinstance(m, ReasoningMessage):
fallback_reasoning.append(m.reasoning)
if messages:
messages_str = "\n".join(messages)
return f"Agent {target_agent_id} said: '{messages_str}'"
return f"{target_agent_id} said: '{messages_str}'"
else:
messages_str = "\n".join(fallback_reasoning)
return f"Agent {target_agent_id}'s inner thoughts: '{messages_str}'"
return f"No response from {target_agent_id}"
async def async_execute_send_message_to_agent(
    sender_agent: "Agent",
    messages: List[MessageCreate],
    other_agent_id: str,
    log_prefix: str,
) -> Optional[str]:
    """Validate the target agent, then deliver *messages* to it with retries.

    Steps:
      1. Confirm the target agent exists and belongs to the sender's org.
      2. Delegate the actual delivery to ``async_send_message_with_retries``.

    Args:
        sender_agent: The agent initiating the send.
        messages: The message payload to deliver.
        other_agent_id: ID of the destination agent.
        log_prefix: Prefix used in retry/error log lines.

    Returns:
        The target agent's response text, if any.

    Raises:
        ValueError: If the target agent is missing or belongs to another org.
    """
    server = get_letta_server()

    # Fail fast if the destination agent is unknown or cross-org.
    try:
        server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user)
    except NoResultFound:
        raise ValueError(
            f"Target agent {other_agent_id} either does not exist or is not in org ({sender_agent.user.organization_id})."
        )

    # Hand off to the shared retry/timeout helper.
    return await async_send_message_with_retries(
        server=server,
        sender_agent=sender_agent,
        target_agent_id=other_agent_id,
        messages=messages,
        max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
        timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
        logging_prefix=log_prefix,
    )
def execute_send_message_to_agent(
@@ -281,53 +303,43 @@ def execute_send_message_to_agent(
log_prefix: str,
) -> Optional[str]:
"""
Helper function to send a message to a specific Letta agent.
Args:
sender_agent ("Agent"): The sender agent object.
message (str): The message to send.
other_agent_id (str): The identifier of the target Letta agent.
log_prefix (str): Logging prefix for retries.
Returns:
Optional[str]: The response from the Letta agent if required by the caller.
Synchronous wrapper that calls `async_execute_send_message_to_agent` using asyncio.run.
This function must be called from a synchronous context (i.e., no running event loop).
"""
server = get_letta_server()
return asyncio.run(async_execute_send_message_to_agent(sender_agent, messages, other_agent_id, log_prefix))
# Ensure the target agent is in the same org
try:
server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user)
except NoResultFound:
raise ValueError(
f"The passed-in agent_id {other_agent_id} either does not exist, "
f"or does not belong to the same org ({sender_agent.user.organization_id})."
)
# Async logic to send a message with retries and timeout
async def async_send():
return await async_send_message_with_retries(
server=server,
sender_agent=sender_agent,
target_agent_id=other_agent_id,
messages=messages,
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
logging_prefix=log_prefix,
)
async def send_message_to_agent_no_stream(
    server: "SyncServer",
    agent_id: str,
    actor: User,
    messages: Union[List[Message], List[MessageCreate]],
    metadata: Optional[dict] = None,
) -> LettaResponse:
    """Send messages to a single agent WITHOUT streaming.

    The blocking ``server.send_messages`` call is offloaded to a worker thread
    so the surrounding event loop stays responsive.

    Returns:
        LettaResponse: The captured final messages plus usage statistics.
    """
    interface = MultiAgentMessagingInterface()
    if metadata:
        interface.metadata = metadata

    # `server.send_messages` is synchronous; run it off the event loop.
    usage_stats = await asyncio.to_thread(
        server.send_messages,
        actor=actor,
        agent_id=agent_id,
        messages=messages,
        interface=interface,
        metadata=metadata,
    )

    captured = interface.get_captured_send_messages()
    return LettaResponse(messages=captured, usage=usage_stats)
async def async_send_message_with_retries(
server,
server: "SyncServer",
sender_agent: "Agent",
target_agent_id: str,
messages: List[MessageCreate],
@@ -335,57 +347,34 @@ async def async_send_message_with_retries(
timeout: int,
logging_prefix: Optional[str] = None,
) -> str:
"""
Shared helper coroutine to send a message to an agent with retries and a timeout.
Args:
server: The Letta server instance (from get_letta_server()).
sender_agent (Agent): The agent initiating the send action.
target_agent_id (str): The ID of the agent to send the message to.
message_text (str): The text to send as the user message.
max_retries (int): Maximum number of retries for the request.
timeout (int): Maximum time to wait for a response (in seconds).
logging_prefix (str): A prefix to append to logging
Returns:
str: The response or an error message.
"""
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
for attempt in range(1, max_retries + 1):
try:
# Wrap in a timeout
response = await asyncio.wait_for(
server.send_message_to_agent(
send_message_to_agent_no_stream(
server=server,
agent_id=target_agent_id,
actor=sender_agent.user,
messages=messages,
stream_steps=False,
stream_tokens=False,
use_assistant_message=True,
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
),
timeout=timeout,
)
# Extract assistant message
assistant_message = parse_letta_response_for_assistant_message(
target_agent_id,
response,
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
)
# Then parse out the assistant message
assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response)
if assistant_message:
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
return assistant_message
else:
msg = f"(No response from agent {target_agent_id})"
sender_agent.logger.info(f"{logging_prefix} - {msg}")
sender_agent.logger.info(f"{logging_prefix} - raw response: {response.model_dump_json(indent=4)}")
sender_agent.logger.info(f"{logging_prefix} - parsed assistant message: {assistant_message}")
return msg
except asyncio.TimeoutError:
error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})"
sender_agent.logger.warning(f"{logging_prefix} - {error_msg}")
except Exception as e:
error_msg = f"(Error on attempt {attempt}/{max_retries} for agent {target_agent_id}: {e})"
sender_agent.logger.warning(f"{logging_prefix} - {error_msg}")
@@ -393,10 +382,10 @@ async def async_send_message_with_retries(
# Exponential backoff before retrying
if attempt < max_retries:
backoff = uniform(0.5, 2) * (2**attempt)
sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent to agent send_message...sleeping for {backoff}")
sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent-to-agent send_message...sleeping for {backoff}")
await asyncio.sleep(backoff)
else:
sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}")
sender_agent.logger.error(f"{logging_prefix} - Fatal error: {error_msg}")
raise Exception(error_msg)
@@ -482,3 +471,43 @@ def fire_and_forget_send_to_agent(
except RuntimeError:
# Means no event loop is running in this thread
run_in_background_thread(background_task())
async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]:
    """Broadcast *message* to every agent matching ALL of *tags*, in parallel.

    Each matching agent receives the message as a system message; sends run
    concurrently but are capped by a semaphore. A failed send for one agent
    does not discard the other agents' replies.

    Args:
        sender_agent: The agent initiating the broadcast.
        message: The raw message text to forward.
        tags: Tags that target agents must ALL have.

    Returns:
        One entry per matching agent: the agent's response, or the stringified
        exception if that send failed.
    """
    server = get_letta_server()

    augmented_message = (
        f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, "
        f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
        f"{message}"
    )

    # Retrieve up to 100 agents that match ALL of the specified tags.
    matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100)

    # Create a system message to deliver to each target.
    messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]

    # Limit concurrency so a large fan-out does not overwhelm the server.
    sem = asyncio.Semaphore(MULTI_AGENT_CONCURRENT_SENDS)

    async def _send_single(agent_state):
        # One bounded send per matching agent.
        async with sem:
            return await async_send_message_with_retries(
                server=server,
                sender_agent=sender_agent,
                target_agent_id=agent_state.id,
                messages=messages,
                # Use the shared constants instead of hard-coded 3/30 so this
                # path retries and times out consistently with the other
                # agent-to-agent send paths in this module.
                max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
                timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
                logging_prefix="[_send_message_to_agents_matching_all_tags_async]",
            )

    tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in matching_agents]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Convert per-agent failures to strings so every agent has an entry.
    return [str(r) if isinstance(r, Exception) else r for r in results]

View File

@@ -0,0 +1,75 @@
import json
from typing import List, Optional
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.interface import AgentInterface
from letta.schemas.letta_message import AssistantMessage, LettaMessage
from letta.schemas.message import Message
class MultiAgentMessagingInterface(AgentInterface):
    """Capture only the agent's 'send_message' tool calls as AssistantMessages.

    Every other interface callback is a no-op. The captured messages can be
    retrieved afterwards via :meth:`get_captured_send_messages`.
    """

    def __init__(self):
        # Messages extracted from 'send_message' tool calls, in arrival order.
        self._captured_messages: List[AssistantMessage] = []
        self.metadata = {}

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        """Ignore internal monologue."""

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Ignore normal assistant messages (only capturing send_message calls)."""

    def function_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Inspect msg_obj.tool_calls and record each 'send_message' call.

        For every tool call named DEFAULT_MESSAGE_TOOL, parse its JSON
        arguments, pull out the DEFAULT_MESSAGE_TOOL_KWARG field, and store
        it as an AssistantMessage. Unparsable arguments are stored verbatim.
        """
        if msg_obj is None or not msg_obj.tool_calls:
            return

        for call in msg_obj.tool_calls:
            # Skip malformed entries and any function other than send_message.
            if not call.function or call.function.name != DEFAULT_MESSAGE_TOOL:
                continue

            raw_args = call.function.arguments or ""
            try:
                parsed = json.loads(raw_args)
            except json.JSONDecodeError:
                # Arguments were not valid JSON: keep the raw string.
                content = raw_args
            else:
                # Fall back to the whole payload if the expected key is absent.
                content = parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, str(parsed))

            self._captured_messages.append(
                AssistantMessage(
                    id=msg_obj.id,
                    date=msg_obj.created_at,
                    content=content,
                )
            )

    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Ignore user messages."""

    def step_complete(self):
        """No streaming => no step boundaries."""

    def step_yield(self):
        """No streaming => no final yield needed."""

    def get_captured_send_messages(self) -> List[LettaMessage]:
        """Return only the messages extracted from 'send_message' calls."""
        return self._captured_messages

View File

@@ -315,7 +315,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# extra prints
self.debug = False
self.timeout = 30
self.timeout = 10 * 60 # 10 minute timeout
def _reset_inner_thoughts_json_reader(self):
# A buffer for accumulating function arguments (we want to buffer keys and run checks on each one)
@@ -330,7 +330,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
while self._active:
try:
# Wait until there is an item in the deque or the stream is deactivated
await asyncio.wait_for(self._event.wait(), timeout=self.timeout) # 30 second timeout
await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
break # Exit the loop if we timeout

View File

@@ -0,0 +1,91 @@
import json
import os
import pytest
from tqdm import tqdm
from letta import create_client
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool import Tool
from tests.integration_test_summarizer import LLM_CONFIG_DIR
@pytest.fixture(scope="function")
def client():
    """Yield a Letta client configured with the haiku LLM and OpenAI embeddings.

    Fix: the config file was previously opened via ``json.load(open(...))``
    and never closed; a ``with`` block now guarantees the handle is released.
    """
    config_path = os.path.join(LLM_CONFIG_DIR, "claude-3-5-haiku.json")
    with open(config_path, "r") as config_file:
        llm_config = LLMConfig(**json.load(config_file))

    letta_client = create_client()
    letta_client.set_default_llm_config(llm_config)
    letta_client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
    yield letta_client
@pytest.fixture
def roll_dice_tool(client):
    """Create (or update) a simple dice-rolling tool and yield it."""

    def roll_dice():
        """
        Rolls a 6 sided die.
        Returns:
        str: The roll result.
        """
        return "Rolled a 5!"

    # Build the tool record from the function's source code.
    tool = Tool(
        description="test_description",
        tags=["test"],
        source_code=parse_source_code(roll_dice),
        source_type="python",
    )

    # Derive the OpenAI JSON schema and take the tool's name from it.
    schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
    tool.json_schema = schema
    tool.name = schema["name"]

    tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user)

    # Yield the created tool
    yield tool
def test_multi_agent_large(client, roll_dice_tool):
    """Smoke test: a manager agent broadcasts to many tagged worker agents.

    Exercises the send_message_to_agents_matching_all_tags fan-out path with
    a large number of workers. Results are inspected manually; there are no
    automated assertions here.
    """
    manager_tags = ["manager"]
    worker_tags = ["helpers"]

    # Clean up first from possibly failed tests
    prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags + manager_tags, match_all_tags=True)
    for agent in prev_worker_agents:
        client.delete_agent(agent.id)

    # Create "manager" agent with the broadcast tool attached
    send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
    manager_agent_state = client.create_agent(
        name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags
    )
    manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)

    # Create `num_workers` worker agents (previous comment wrongly said 3)
    worker_agents = []
    num_workers = 50
    for idx in tqdm(range(num_workers)):
        worker_agent_state = client.create_agent(
            name=f"worker-{idx}", include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id]
        )
        worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
        worker_agents.append(worker_agent)

    # Encourage the manager to send a message to the other agent_obj with the secret string
    broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
    client.send_message(
        agent_id=manager_agent.agent_state.id,
        role="user",
        message=broadcast_message,
    )

    # Please manually inspect the agent results

View File

@@ -173,7 +173,7 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
# Search the sender agent for the response from another agent
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
found = False
target_snippet = f"Agent {other_agent_obj.agent_state.id} said:"
target_snippet = f"{other_agent_obj.agent_state.id} said:"
for m in in_context_messages:
if target_snippet in m.text: