feat: Make multi agent broadcast directly invoke step (#1355)

This commit is contained in:
Matthew Zhou
2025-03-20 17:05:04 -07:00
committed by GitHub
parent e1b16c5fea
commit 8e6fd8a991
11 changed files with 167 additions and 147 deletions

View File

@@ -522,7 +522,7 @@ class Agent(BaseAgent):
openai_message_dict=response_message.model_dump(),
)
) # extend conversation with assistant's reply
self.logger.info(f"Function call message: {messages[-1]}")
self.logger.debug(f"Function call message: {messages[-1]}")
nonnull_content = False
if response_message.content:
@@ -786,6 +786,7 @@ class Agent(BaseAgent):
total_usage = UsageStatistics()
step_count = 0
function_failed = False
steps_messages = []
while True:
kwargs["first_message"] = False
kwargs["step_count"] = step_count
@@ -800,6 +801,7 @@ class Agent(BaseAgent):
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
usage = step_response.usage
steps_messages.append(step_response.messages)
step_count += 1
total_usage += usage
@@ -859,9 +861,9 @@ class Agent(BaseAgent):
break
if self.agent_state.message_buffer_autoclear:
self.agent_manager.trim_all_in_context_messages_except_system(self.agent_state.id, actor=self.user)
self.agent_state = self.agent_manager.trim_all_in_context_messages_except_system(self.agent_state.id, actor=self.user)
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count, steps_messages=steps_messages)
def inner_step(
self,

View File

@@ -274,7 +274,7 @@ class VoiceAgent(BaseAgent):
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor

View File

@@ -1,16 +1,19 @@
import asyncio
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING, List
from letta.functions.helpers import (
_send_message_to_agents_matching_tags_async,
_send_message_to_all_agents_in_group_async,
execute_send_message_to_agent,
extract_send_message_from_steps_messages,
fire_and_forget_send_to_agent,
)
from letta.helpers.message_helper import prepare_input_message_create
from letta.schemas.enums import MessageRole
from letta.schemas.message import MessageCreate
from letta.server.rest_api.utils import get_letta_server
from letta.utils import log_telemetry
from letta.settings import settings
if TYPE_CHECKING:
from letta.agent import Agent
@@ -87,51 +90,59 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
response corresponds to a single agent. Agents that do not respond will not have an entry
in the returned list.
"""
log_telemetry(
self.logger,
"_send_message_to_agents_matching_tags_async start",
message=message,
match_all=match_all,
match_some=match_some,
)
server = get_letta_server()
augmented_message = (
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
f"[Incoming message from external Letta agent - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
)
# Retrieve up to 100 matching agents
log_telemetry(
self.logger,
"_send_message_to_agents_matching_tags_async listing agents start",
message=message,
match_all=match_all,
match_some=match_some,
)
# Find matching agents
matching_agents = server.agent_manager.list_agents_matching_tags(actor=self.user, match_all=match_all, match_some=match_some)
if not matching_agents:
return []
log_telemetry(
self.logger,
"_send_message_to_agents_matching_tags_async listing agents finish",
message=message,
match_all=match_all,
match_some=match_some,
)
def process_agent(agent_id: str) -> str:
"""Loads an agent, formats the message, and executes .step()"""
actor = self.user # Ensure correct actor context
agent = server.load_agent(agent_id=agent_id, interface=None, actor=actor)
# Create a system message
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
# Prepare the message
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
input_messages = [prepare_input_message_create(m, agent_id) for m in messages]
result = asyncio.run(_send_message_to_agents_matching_tags_async(self, server, messages, matching_agents))
log_telemetry(
self.logger,
"_send_message_to_agents_matching_tags_async finish",
messages=message,
match_all=match_all,
match_some=match_some,
)
return result
# Run .step() and return the response
usage_stats = agent.step(
messages=input_messages,
chaining=True,
max_chaining_steps=None,
stream=False,
skip_verify=True,
metadata=None,
put_inner_thoughts_first=True,
)
send_messages = extract_send_message_from_steps_messages(usage_stats.steps_messages, logger=agent.logger)
response_data = {
"agent_id": agent_id,
"response_messages": send_messages if send_messages else ["<no response>"],
}
return json.dumps(response_data, indent=2)
# Use ThreadPoolExecutor for parallel execution
results = []
with ThreadPoolExecutor(max_workers=settings.multi_agent_concurrent_sends) as executor:
future_to_agent = {executor.submit(process_agent, agent_state.id): agent_state for agent_state in matching_agents}
for future in as_completed(future_to_agent):
try:
results.append(future.result()) # Collect results
except Exception as e:
# Log or handle failure for specific agents if needed
self.logger.exception(f"Error processing agent {future_to_agent[future]}: {e}")
return results
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:

View File

@@ -1,4 +1,6 @@
import asyncio
import json
import logging
import threading
from random import uniform
from typing import Any, Dict, List, Optional, Type, Union
@@ -17,7 +19,6 @@ from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
from letta.server.rest_api.utils import get_letta_server
from letta.settings import settings
from letta.utils import log_telemetry
# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
@@ -386,15 +387,9 @@ async def async_send_message_with_retries(
logging_prefix: Optional[str] = None,
) -> str:
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
log_telemetry(sender_agent.logger, f"async_send_message_with_retries start", target_agent_id=target_agent_id)
for attempt in range(1, max_retries + 1):
try:
log_telemetry(
sender_agent.logger,
f"async_send_message_with_retries -> asyncio wait for send_message_to_agent_no_stream start",
target_agent_id=target_agent_id,
)
response = await asyncio.wait_for(
send_message_to_agent_no_stream(
server=server,
@@ -404,24 +399,15 @@ async def async_send_message_with_retries(
),
timeout=timeout,
)
log_telemetry(
sender_agent.logger,
f"async_send_message_with_retries -> asyncio wait for send_message_to_agent_no_stream finish",
target_agent_id=target_agent_id,
)
# Then parse out the assistant message
assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response)
if assistant_message:
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
log_telemetry(
sender_agent.logger, f"async_send_message_with_retries finish with assistant message", target_agent_id=target_agent_id
)
return assistant_message
else:
msg = f"(No response from agent {target_agent_id})"
sender_agent.logger.info(f"{logging_prefix} - {msg}")
log_telemetry(sender_agent.logger, f"async_send_message_with_retries finish no response", target_agent_id=target_agent_id)
return msg
except asyncio.TimeoutError:
@@ -439,12 +425,6 @@ async def async_send_message_with_retries(
await asyncio.sleep(backoff)
else:
sender_agent.logger.error(f"{logging_prefix} - Fatal error: {error_msg}")
log_telemetry(
sender_agent.logger,
f"async_send_message_with_retries finish fatal error",
target_agent_id=target_agent_id,
error_msg=error_msg,
)
raise Exception(error_msg)
@@ -673,3 +653,27 @@ def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[
else:
return Union[tuple(types)]
raise ValueError(f"Unable to convert pydantic field schema to type: {field_schema}")
def extract_send_message_from_steps_messages(
    steps_messages: Optional[List[List[Message]]],
    agent_send_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
    agent_send_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
    logger: Optional[logging.Logger] = None,
) -> List[str]:
    """Collect the payloads of every send-message tool call found in the given steps.

    Walks each step's messages, looks at each message's tool calls, and for
    every call whose function name equals ``agent_send_message_tool_name``
    decodes the JSON arguments and picks out the value stored under
    ``agent_send_message_tool_kwarg``.

    Args:
        steps_messages: Messages grouped per step. ``None`` or an empty list
            yields an empty result.
        agent_send_message_tool_name: Name of the tool whose calls are extracted.
        agent_send_message_tool_kwarg: Argument key holding the message payload.
        logger: Logger used to report unparsable tool-call arguments. Falls
            back to this module's logger when omitted.

    Returns:
        The extracted values, in the order the tool calls were encountered.
        Tool calls with unparsable or non-object arguments are skipped.
    """
    # The parameter is optional, so never dereference it directly.
    log = logger if logger is not None else logging.getLogger(__name__)

    extracted_messages = []
    for step in steps_messages or []:
        for message in step:
            for tool_call in message.tool_calls or []:
                if tool_call.function.name != agent_send_message_tool_name:
                    continue

                try:
                    # Parse arguments to extract the "message" field
                    arguments = json.loads(tool_call.function.arguments)
                except (json.JSONDecodeError, TypeError):
                    # TypeError covers arguments that are not str/bytes (e.g. None).
                    log.error(f"Failed to parse arguments for tool call: {tool_call.id}")
                    continue

                # Valid JSON is not necessarily an object; only dicts can carry the kwarg.
                if isinstance(arguments, dict) and agent_send_message_tool_kwarg in arguments:
                    extracted_messages.append(arguments[agent_send_message_tool_kwarg])

    return extracted_messages

View File

@@ -0,0 +1,41 @@
from letta import system
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
def prepare_input_message_create(
    message: MessageCreate,
    agent_id: str,
    wrap_user_message: bool = True,
    wrap_system_message: bool = True,
) -> Message:
    """Build a ``Message`` for ``agent_id`` from a ``MessageCreate``.

    The text is taken from ``message.content`` directly when it is a string,
    otherwise from the first element when that element is a ``TextContent``.
    User text is passed through ``system.package_user_message`` when
    ``wrap_user_message`` is set; system text is passed through
    ``system.package_system_message`` when ``wrap_system_message`` is set.

    Raises:
        ValueError: If no text can be extracted from the content, or if the
            role is neither user nor system.
    """
    # TODO: This seems like extra boilerplate with little benefit
    assert isinstance(message, MessageCreate)

    # Pull the raw text out of the content payload.
    raw_content = message.content
    if isinstance(raw_content, str):
        text = raw_content
    elif raw_content and len(raw_content) > 0 and isinstance(raw_content[0], TextContent):
        text = raw_content[0].text
    else:
        raise ValueError("Message content is empty or invalid")

    # Only user and system roles are accepted; wrap each according to its flag.
    role = message.role
    if role not in {MessageRole.user, MessageRole.system}:
        raise ValueError(f"Invalid message role: {message.role}")
    if role == MessageRole.user:
        if wrap_user_message:
            text = system.package_user_message(user_message=text)
    elif wrap_system_message:
        text = system.package_system_message(system_message=text)

    # Empty text produces a message with no content parts.
    content_parts = [TextContent(text=text)] if text else []
    return Message(
        agent_id=agent_id,
        role=role,
        content=content_parts,
        name=message.name,
        model=None,  # assigned later?
        tool_calls=None,  # irrelevant
        tool_call_id=None,
    )

View File

@@ -361,14 +361,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if identifier_set != results_set:
# Construct a detailed error message based on query conditions
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
logger.warning(
f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}"
)
logger.debug(f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}")
return results
# Construct a detailed error message based on query conditions
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
logger.warning(f"{cls.__name__} not found with {conditions_str}")
logger.debug(f"{cls.__name__} not found with {conditions_str}")
return []
@handle_db_timeout

View File

@@ -1,7 +1,9 @@
from typing import Literal
from typing import List, Literal, Optional
from pydantic import BaseModel, Field
from letta.schemas.message import Message
class LettaUsageStatistics(BaseModel):
"""
@@ -19,3 +21,5 @@ class LettaUsageStatistics(BaseModel):
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
step_count: int = Field(0, description="The number of steps taken by the agent.")
# TODO: Optional for now. This field makes everyone's lives easier
steps_messages: Optional[List[List[Message]]] = Field(None, description="The messages generated per step")

View File

@@ -26,6 +26,7 @@ from letta.functions.mcp_client.stdio_client import StdioMCPClient
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.message_helper import prepare_input_message_create
# TODO use custom interface
from letta.interface import AgentInterface # abstract
@@ -48,7 +49,7 @@ from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage, PassageUpdate
from letta.schemas.providers import (
@@ -85,7 +86,6 @@ from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager
from letta.services.passage_manager import PassageManager
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.provider_manager import ProviderManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
@@ -210,9 +210,6 @@ class SyncServer(Server):
self.identity_manager = IdentityManager()
self.group_manager = GroupManager()
# Managers that interface with parallelism
self.per_agent_lock_manager = PerAgentLockManager()
# Make default user and org
if init_with_default_org_and_user:
self.default_org = self.organization_manager.create_default_organization()
@@ -353,21 +350,19 @@ class SyncServer(Server):
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
"""Updated method to load agents from persisted storage"""
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
with agent_lock:
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state.multi_agent_group:
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state.multi_agent_group:
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
interface = interface or self.default_interface_factory()
if agent_state.agent_type == AgentType.memgpt_agent:
agent = Agent(agent_state=agent_state, interface=interface, user=actor, mcp_clients=self.mcp_clients)
elif agent_state.agent_type == AgentType.offline_memory_agent:
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor)
else:
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
interface = interface or self.default_interface_factory()
if agent_state.agent_type == AgentType.memgpt_agent:
agent = Agent(agent_state=agent_state, interface=interface, user=actor, mcp_clients=self.mcp_clients)
elif agent_state.agent_type == AgentType.offline_memory_agent:
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor)
else:
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
return agent
return agent
def load_multi_agent(
self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
@@ -702,63 +697,22 @@ class SyncServer(Server):
actor: User,
agent_id: str,
messages: Union[List[MessageCreate], List[Message]],
# whether or not to wrap user and system message as MemGPT-style stringified JSON
wrap_user_message: bool = True,
wrap_system_message: bool = True,
interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed to getting responses
interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed for responses
metadata: Optional[dict] = None, # Pass through metadata to interface
put_inner_thoughts_first: bool = True,
) -> LettaUsageStatistics:
"""Send a list of messages to the agent
"""Send a list of messages to the agent.
If the messages are of type MessageCreate, we need to turn them into
Message objects first before sending them through step.
Otherwise, we can pass them in directly.
If messages are of type MessageCreate, convert them to Message objects before sending.
"""
message_objects: List[Message] = []
if all(isinstance(m, MessageCreate) for m in messages):
for message in messages:
assert isinstance(message, MessageCreate)
# If wrapping is enabled, wrap with metadata before placing content inside the Message object
if isinstance(message.content, str):
message_content = message.content
elif message.content and len(message.content) > 0 and isinstance(message.content[0], TextContent):
message_content = message.content[0].text
else:
assert message_content is not None, "Message content is empty"
if message.role == MessageRole.user and wrap_user_message:
message_content = system.package_user_message(user_message=message_content)
elif message.role == MessageRole.system and wrap_system_message:
message_content = system.package_system_message(system_message=message_content)
else:
raise ValueError(f"Invalid message role: {message.role}")
# Create the Message object
message_objects.append(
Message(
agent_id=agent_id,
role=message.role,
content=[TextContent(text=message_content)] if message_content else [],
name=message.name,
# assigned later?
model=None,
# irrelevant
tool_calls=None,
tool_call_id=None,
)
)
message_objects = [prepare_input_message_create(m, agent_id, wrap_user_message, wrap_system_message) for m in messages]
elif all(isinstance(m, Message) for m in messages):
for message in messages:
assert isinstance(message, Message)
message_objects.append(message)
message_objects = messages
else:
raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(message) for message in messages]}")
raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(m) for m in messages]}")
# Store metadata in interface if provided
if metadata and hasattr(interface, "metadata"):

View File

@@ -639,7 +639,7 @@ class AgentManager:
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
if len(diff) > 0: # there was a diff
logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# Swap the system message out (only if there is a diff)
message = PydanticMessage.dict_to_message(

View File

@@ -47,7 +47,7 @@ def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
return decorator_retry
def retry_until_success(max_attempts=10, sleep_time_seconds=4):
def retry_until_success(max_attempts=10, sleep_time_seconds=4, flush_tables_in_between: bool = False):
"""
Decorator to retry a function until it succeeds or the maximum number of attempts is reached.
@@ -58,13 +58,25 @@ def retry_until_success(max_attempts=10, sleep_time_seconds=4):
def decorator_retry(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
from letta.orm.base import Base
from letta.server.db import db_context
for attempt in range(1, max_attempts + 1):
try:
return func(*args, **kwargs)
except Exception as e:
print(f"\033[93mAttempt {attempt} failed with error:\n{e}\033[0m")
# Clear tables before retrying
if flush_tables_in_between:
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
session.execute(table.delete()) # Truncate table
session.commit()
if attempt == max_attempts:
raise
time.sleep(sleep_time_seconds)
return wrapper

View File

@@ -10,12 +10,11 @@ from letta.schemas.letta_message import SystemMessage, ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.tool import Tool
from tests.helpers.utils import retry_until_success
from tests.utils import wait_for_incoming_message
@pytest.fixture(autouse=True)
def clear_tables():
def truncate_database():
from letta.server.db import db_context
with db_context() as session:
@@ -86,7 +85,6 @@ def roll_dice_tool(client):
yield tool
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_send_message_to_agent(client, agent_obj, other_agent_obj):
secret_word = "banana"
@@ -125,7 +123,6 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
print(response.messages)
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_simple(client):
worker_tags_123 = ["worker", "user-123"]
worker_tags_456 = ["worker", "user-456"]
@@ -141,20 +138,20 @@ def test_send_message_to_agents_with_tags_simple(client):
# Create "manager" agent
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent_state = client.create_agent(name="manager_agent", tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create non-matching worker agents (These should NOT get the message)
worker_agents_123 = []
for _ in range(3):
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_123)
for idx in range(2):
worker_agent_state = client.create_agent(name=f"not_worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_123)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents_123.append(worker_agent)
# Create worker agents that should get the message
worker_agents_456 = []
for _ in range(3):
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_456)
for idx in range(2):
worker_agent_state = client.create_agent(name=f"worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_456)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents_456.append(worker_agent)
@@ -203,7 +200,6 @@ def test_send_message_to_agents_with_tags_simple(client):
client.delete_agent(agent.agent_state.id)
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
worker_tags = ["dice-rollers"]
@@ -252,7 +248,6 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
client.delete_agent(agent.agent_state.id)
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
# Create "manager" agent
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
@@ -285,7 +280,6 @@ def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
assert "banana" in worker_agent_state.memory.compile().lower()
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_agents_async_simple(client):
"""
Test two agents with multi-agent tools sending messages back and forth to count to 5.