fix: Robust new streaming interface for multi-agent tooling (#907)
This commit is contained in:
@@ -53,6 +53,7 @@ BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
|
||||
MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3
|
||||
MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60
|
||||
MULTI_AGENT_CONCURRENT_SENDS = 15
|
||||
|
||||
# The name of the tool used to send message to the user
|
||||
# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT
|
||||
from letta.functions.helpers import async_send_message_with_retries, execute_send_message_to_agent, fire_and_forget_send_to_agent
|
||||
from letta.functions.helpers import (
|
||||
_send_message_to_agents_matching_all_tags_async,
|
||||
execute_send_message_to_agent,
|
||||
fire_and_forget_send_to_agent,
|
||||
)
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agent import Agent
|
||||
@@ -22,12 +24,13 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_
|
||||
Returns:
|
||||
str: The response from the target agent.
|
||||
"""
|
||||
message = (
|
||||
augmented_message = (
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
|
||||
|
||||
return execute_send_message_to_agent(
|
||||
sender_agent=self,
|
||||
messages=messages,
|
||||
@@ -81,33 +84,4 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags:
|
||||
have an entry in the returned list.
|
||||
"""
|
||||
|
||||
server = get_letta_server()
|
||||
|
||||
message = (
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
# Retrieve agents that match ALL specified tags
|
||||
matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100)
|
||||
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
|
||||
|
||||
async def send_messages_to_all_agents():
|
||||
tasks = [
|
||||
async_send_message_with_retries(
|
||||
server=server,
|
||||
sender_agent=self,
|
||||
target_agent_id=agent_state.id,
|
||||
messages=messages,
|
||||
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
|
||||
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
|
||||
logging_prefix="[send_message_to_agents_matching_all_tags]",
|
||||
)
|
||||
for agent_state in matching_agents
|
||||
]
|
||||
# Run all tasks in parallel
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Run the async function and return results
|
||||
return asyncio.run(send_messages_to_all_agents())
|
||||
return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags))
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from random import uniform
|
||||
from typing import Any, List, Optional, Union
|
||||
@@ -12,13 +11,17 @@ from letta.constants import (
|
||||
COMPOSIO_ENTITY_ENV_VAR_KEY,
|
||||
DEFAULT_MESSAGE_TOOL,
|
||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
MULTI_AGENT_CONCURRENT_SENDS,
|
||||
MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
|
||||
MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
|
||||
)
|
||||
from letta.functions.interface import MultiAgentMessagingInterface
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import AssistantMessage
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
|
||||
@@ -249,29 +252,48 @@ def generate_import_code(module_attr_map: Optional[dict]):
|
||||
def parse_letta_response_for_assistant_message(
|
||||
target_agent_id: str,
|
||||
letta_response: LettaResponse,
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
) -> Optional[str]:
|
||||
messages = []
|
||||
# This is not ideal, but we would like to return something rather than nothing
|
||||
fallback_reasoning = []
|
||||
for m in letta_response.messages:
|
||||
if isinstance(m, AssistantMessage):
|
||||
messages.append(m.content)
|
||||
elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
|
||||
try:
|
||||
messages.append(json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg])
|
||||
except Exception: # TODO: Make this more specific
|
||||
continue
|
||||
elif isinstance(m, ReasoningMessage):
|
||||
fallback_reasoning.append(m.reasoning)
|
||||
|
||||
if messages:
|
||||
messages_str = "\n".join(messages)
|
||||
return f"Agent {target_agent_id} said: '{messages_str}'"
|
||||
return f"{target_agent_id} said: '{messages_str}'"
|
||||
else:
|
||||
messages_str = "\n".join(fallback_reasoning)
|
||||
return f"Agent {target_agent_id}'s inner thoughts: '{messages_str}'"
|
||||
return f"No response from {target_agent_id}"
|
||||
|
||||
|
||||
async def async_execute_send_message_to_agent(
|
||||
sender_agent: "Agent",
|
||||
messages: List[MessageCreate],
|
||||
other_agent_id: str,
|
||||
log_prefix: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Async helper to:
|
||||
1) validate the target agent exists & is in the same org,
|
||||
2) send a message via async_send_message_with_retries.
|
||||
"""
|
||||
server = get_letta_server()
|
||||
|
||||
# 1. Validate target agent
|
||||
try:
|
||||
server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Target agent {other_agent_id} either does not exist or is not in org " f"({sender_agent.user.organization_id}).")
|
||||
|
||||
# 2. Use your async retry logic
|
||||
return await async_send_message_with_retries(
|
||||
server=server,
|
||||
sender_agent=sender_agent,
|
||||
target_agent_id=other_agent_id,
|
||||
messages=messages,
|
||||
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
|
||||
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
|
||||
logging_prefix=log_prefix,
|
||||
)
|
||||
|
||||
|
||||
def execute_send_message_to_agent(
|
||||
@@ -281,53 +303,43 @@ def execute_send_message_to_agent(
|
||||
log_prefix: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Helper function to send a message to a specific Letta agent.
|
||||
|
||||
Args:
|
||||
sender_agent ("Agent"): The sender agent object.
|
||||
message (str): The message to send.
|
||||
other_agent_id (str): The identifier of the target Letta agent.
|
||||
log_prefix (str): Logging prefix for retries.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The response from the Letta agent if required by the caller.
|
||||
Synchronous wrapper that calls `async_execute_send_message_to_agent` using asyncio.run.
|
||||
This function must be called from a synchronous context (i.e., no running event loop).
|
||||
"""
|
||||
server = get_letta_server()
|
||||
return asyncio.run(async_execute_send_message_to_agent(sender_agent, messages, other_agent_id, log_prefix))
|
||||
|
||||
# Ensure the target agent is in the same org
|
||||
try:
|
||||
server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user)
|
||||
except NoResultFound:
|
||||
raise ValueError(
|
||||
f"The passed-in agent_id {other_agent_id} either does not exist, "
|
||||
f"or does not belong to the same org ({sender_agent.user.organization_id})."
|
||||
)
|
||||
|
||||
# Async logic to send a message with retries and timeout
|
||||
async def async_send():
|
||||
return await async_send_message_with_retries(
|
||||
server=server,
|
||||
sender_agent=sender_agent,
|
||||
target_agent_id=other_agent_id,
|
||||
messages=messages,
|
||||
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
|
||||
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
|
||||
logging_prefix=log_prefix,
|
||||
)
|
||||
async def send_message_to_agent_no_stream(
|
||||
server: "SyncServer",
|
||||
agent_id: str,
|
||||
actor: User,
|
||||
messages: Union[List[Message], List[MessageCreate]],
|
||||
metadata: Optional[dict] = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
A simpler helper to send messages to a single agent WITHOUT streaming.
|
||||
Returns a LettaResponse containing the final messages.
|
||||
"""
|
||||
interface = MultiAgentMessagingInterface()
|
||||
if metadata:
|
||||
interface.metadata = metadata
|
||||
|
||||
# Run in the current event loop or create one if needed
|
||||
try:
|
||||
return asyncio.run(async_send())
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
return loop.run_until_complete(async_send())
|
||||
else:
|
||||
raise
|
||||
# Offload the synchronous `send_messages` call
|
||||
usage_stats = await asyncio.to_thread(
|
||||
server.send_messages,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
messages=messages,
|
||||
interface=interface,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
final_messages = interface.get_captured_send_messages()
|
||||
return LettaResponse(messages=final_messages, usage=usage_stats)
|
||||
|
||||
|
||||
async def async_send_message_with_retries(
|
||||
server,
|
||||
server: "SyncServer",
|
||||
sender_agent: "Agent",
|
||||
target_agent_id: str,
|
||||
messages: List[MessageCreate],
|
||||
@@ -335,57 +347,34 @@ async def async_send_message_with_retries(
|
||||
timeout: int,
|
||||
logging_prefix: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Shared helper coroutine to send a message to an agent with retries and a timeout.
|
||||
|
||||
Args:
|
||||
server: The Letta server instance (from get_letta_server()).
|
||||
sender_agent (Agent): The agent initiating the send action.
|
||||
target_agent_id (str): The ID of the agent to send the message to.
|
||||
message_text (str): The text to send as the user message.
|
||||
max_retries (int): Maximum number of retries for the request.
|
||||
timeout (int): Maximum time to wait for a response (in seconds).
|
||||
logging_prefix (str): A prefix to append to logging
|
||||
Returns:
|
||||
str: The response or an error message.
|
||||
"""
|
||||
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# Wrap in a timeout
|
||||
response = await asyncio.wait_for(
|
||||
server.send_message_to_agent(
|
||||
send_message_to_agent_no_stream(
|
||||
server=server,
|
||||
agent_id=target_agent_id,
|
||||
actor=sender_agent.user,
|
||||
messages=messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
use_assistant_message=True,
|
||||
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract assistant message
|
||||
assistant_message = parse_letta_response_for_assistant_message(
|
||||
target_agent_id,
|
||||
response,
|
||||
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
)
|
||||
# Then parse out the assistant message
|
||||
assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response)
|
||||
if assistant_message:
|
||||
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
|
||||
return assistant_message
|
||||
else:
|
||||
msg = f"(No response from agent {target_agent_id})"
|
||||
sender_agent.logger.info(f"{logging_prefix} - {msg}")
|
||||
sender_agent.logger.info(f"{logging_prefix} - raw response: {response.model_dump_json(indent=4)}")
|
||||
sender_agent.logger.info(f"{logging_prefix} - parsed assistant message: {assistant_message}")
|
||||
return msg
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})"
|
||||
sender_agent.logger.warning(f"{logging_prefix} - {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"(Error on attempt {attempt}/{max_retries} for agent {target_agent_id}: {e})"
|
||||
sender_agent.logger.warning(f"{logging_prefix} - {error_msg}")
|
||||
@@ -393,10 +382,10 @@ async def async_send_message_with_retries(
|
||||
# Exponential backoff before retrying
|
||||
if attempt < max_retries:
|
||||
backoff = uniform(0.5, 2) * (2**attempt)
|
||||
sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent to agent send_message...sleeping for {backoff}")
|
||||
sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent-to-agent send_message...sleeping for {backoff}")
|
||||
await asyncio.sleep(backoff)
|
||||
else:
|
||||
sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}")
|
||||
sender_agent.logger.error(f"{logging_prefix} - Fatal error: {error_msg}")
|
||||
raise Exception(error_msg)
|
||||
|
||||
|
||||
@@ -482,3 +471,43 @@ def fire_and_forget_send_to_agent(
|
||||
except RuntimeError:
|
||||
# Means no event loop is running in this thread
|
||||
run_in_background_thread(background_task())
|
||||
|
||||
|
||||
async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]:
|
||||
server = get_letta_server()
|
||||
|
||||
augmented_message = (
|
||||
f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
# Retrieve up to 100 matching agents
|
||||
matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100)
|
||||
|
||||
# Create a system message
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]
|
||||
|
||||
# Possibly limit concurrency to avoid meltdown:
|
||||
sem = asyncio.Semaphore(MULTI_AGENT_CONCURRENT_SENDS)
|
||||
|
||||
async def _send_single(agent_state):
|
||||
async with sem:
|
||||
return await async_send_message_with_retries(
|
||||
server=server,
|
||||
sender_agent=sender_agent,
|
||||
target_agent_id=agent_state.id,
|
||||
messages=messages,
|
||||
max_retries=3,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in matching_agents]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
final = []
|
||||
for r in results:
|
||||
if isinstance(r, Exception):
|
||||
final.append(str(r))
|
||||
else:
|
||||
final.append(r)
|
||||
return final
|
||||
|
||||
75
letta/functions/interface.py
Normal file
75
letta/functions/interface.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.interface import AgentInterface
|
||||
from letta.schemas.letta_message import AssistantMessage, LettaMessage
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
class MultiAgentMessagingInterface(AgentInterface):
|
||||
"""
|
||||
A minimal interface that captures *only* calls to the 'send_message' function
|
||||
by inspecting msg_obj.tool_calls. We parse out the 'message' field from the
|
||||
JSON function arguments and store it as an AssistantMessage.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._captured_messages: List[AssistantMessage] = []
|
||||
self.metadata = {}
|
||||
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""Ignore internal monologue."""
|
||||
|
||||
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""Ignore normal assistant messages (only capturing send_message calls)."""
|
||||
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""
|
||||
Called whenever the agent logs a function call. We'll inspect msg_obj.tool_calls:
|
||||
- If tool_calls include a function named 'send_message', parse its arguments
|
||||
- Extract the 'message' field
|
||||
- Save it as an AssistantMessage in self._captured_messages
|
||||
"""
|
||||
if not msg_obj or not msg_obj.tool_calls:
|
||||
return
|
||||
|
||||
for tool_call in msg_obj.tool_calls:
|
||||
if not tool_call.function:
|
||||
continue
|
||||
if tool_call.function.name != DEFAULT_MESSAGE_TOOL:
|
||||
# Skip any other function calls
|
||||
continue
|
||||
|
||||
# Now parse the JSON in tool_call.function.arguments
|
||||
func_args_str = tool_call.function.arguments or ""
|
||||
try:
|
||||
data = json.loads(func_args_str)
|
||||
# Extract the 'message' key if present
|
||||
content = data.get(DEFAULT_MESSAGE_TOOL_KWARG, str(data))
|
||||
except json.JSONDecodeError:
|
||||
# If we can't parse, store the raw string
|
||||
content = func_args_str
|
||||
|
||||
# Store as an AssistantMessage
|
||||
new_msg = AssistantMessage(
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
content=content,
|
||||
)
|
||||
self._captured_messages.append(new_msg)
|
||||
|
||||
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""Ignore user messages."""
|
||||
|
||||
def step_complete(self):
|
||||
"""No streaming => no step boundaries."""
|
||||
|
||||
def step_yield(self):
|
||||
"""No streaming => no final yield needed."""
|
||||
|
||||
def get_captured_send_messages(self) -> List[LettaMessage]:
|
||||
"""
|
||||
Returns only the messages extracted from 'send_message' calls.
|
||||
"""
|
||||
return self._captured_messages
|
||||
@@ -315,7 +315,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
|
||||
# extra prints
|
||||
self.debug = False
|
||||
self.timeout = 30
|
||||
self.timeout = 10 * 60 # 10 minute timeout
|
||||
|
||||
def _reset_inner_thoughts_json_reader(self):
|
||||
# A buffer for accumulating function arguments (we want to buffer keys and run checks on each one)
|
||||
@@ -330,7 +330,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
while self._active:
|
||||
try:
|
||||
# Wait until there is an item in the deque or the stream is deactivated
|
||||
await asyncio.wait_for(self._event.wait(), timeout=self.timeout) # 30 second timeout
|
||||
await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
break # Exit the loop if we timeout
|
||||
|
||||
|
||||
91
tests/manual_test_multi_agent_broadcast_large.py
Normal file
91
tests/manual_test_multi_agent_broadcast_large.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from tqdm import tqdm
|
||||
|
||||
from letta import create_client
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from tests.integration_test_summarizer import LLM_CONFIG_DIR
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client():
|
||||
filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-haiku.json")
|
||||
config_data = json.load(open(filename, "r"))
|
||||
llm_config = LLMConfig(**config_data)
|
||||
client = create_client()
|
||||
client.set_default_llm_config(llm_config)
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def roll_dice_tool(client):
|
||||
def roll_dice():
|
||||
"""
|
||||
Rolls a 6 sided die.
|
||||
|
||||
Returns:
|
||||
str: The roll result.
|
||||
"""
|
||||
return "Rolled a 5!"
|
||||
|
||||
# Set up tool details
|
||||
source_code = parse_source_code(roll_dice)
|
||||
source_type = "python"
|
||||
description = "test_description"
|
||||
tags = ["test"]
|
||||
|
||||
tool = Tool(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user)
|
||||
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
def test_multi_agent_large(client, roll_dice_tool):
|
||||
manager_tags = ["manager"]
|
||||
worker_tags = ["helpers"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags + manager_tags, match_all_tags=True)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(
|
||||
name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags
|
||||
)
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 worker agents
|
||||
worker_agents = []
|
||||
num_workers = 50
|
||||
for idx in tqdm(range(num_workers)):
|
||||
worker_agent_state = client.create_agent(
|
||||
name=f"worker-{idx}", include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id]
|
||||
)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
|
||||
client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=broadcast_message,
|
||||
)
|
||||
|
||||
# Please manually inspect the agent results
|
||||
@@ -173,7 +173,7 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
# Search the sender agent for the response from another agent
|
||||
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
|
||||
found = False
|
||||
target_snippet = f"Agent {other_agent_obj.agent_state.id} said:"
|
||||
target_snippet = f"{other_agent_obj.agent_state.id} said:"
|
||||
|
||||
for m in in_context_messages:
|
||||
if target_snippet in m.text:
|
||||
|
||||
Reference in New Issue
Block a user