diff --git a/letta/agent.py b/letta/agent.py index 9ff0f437..fefca2f5 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -108,9 +108,6 @@ class Agent(BaseAgent): if not isinstance(rule, TerminalToolRule): warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.") break - # add default rule for having send_message be a terminal tool - if agent_state.tool_rules is None: - agent_state.tool_rules = [] self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules) diff --git a/letta/constants.py b/letta/constants.py index ee062cda..acaaca2c 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -50,7 +50,7 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", " # Base memory tools CAN be edited, and are added by default by the server BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] # Multi agent tools -MULTI_AGENT_TOOLS = ["send_message_to_specific_agent", "send_message_to_agents_matching_all_tags"] +MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"] MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3 MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60 diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py index 40202ed9..a8641b2f 100644 --- a/letta/functions/function_sets/multi_agent.py +++ b/letta/functions/function_sets/multi_agent.py @@ -1,80 +1,86 @@ import asyncio -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT -from letta.functions.helpers import async_send_message_with_retries -from letta.orm.errors import NoResultFound +from letta.functions.helpers import async_send_message_with_retries, execute_send_message_to_agent, fire_and_forget_send_to_agent +from letta.schemas.enums import MessageRole +from letta.schemas.message import MessageCreate from letta.server.rest_api.utils import get_letta_server if TYPE_CHECKING: from letta.agent import Agent -def send_message_to_specific_agent(self: "Agent", message: str, other_agent_id: str) -> Optional[str]: +def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_agent_id: str) -> str: """ - Send a message to a specific Letta agent within the same organization. + Sends a message to a specific Letta agent within the same organization and waits for a response. The sender's identity is automatically included, so no explicit introduction is needed in the message. This function is designed for two-way communication where a reply is expected. Args: - message (str): The message to be sent to the target Letta agent. - other_agent_id (str): The identifier of the target Letta agent. + message (str): The content of the message to be sent to the target agent. + other_agent_id (str): The unique identifier of the target Letta agent. Returns: - Optional[str]: The response from the Letta agent. It's possible that the agent does not respond. + str: The response from the target agent. """ - server = get_letta_server() + messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)] + return execute_send_message_to_agent( + sender_agent=self, + messages=messages, + other_agent_id=other_agent_id, + log_prefix="[send_message_to_agent_and_wait_for_reply]", + ) - # Ensure the target agent is in the same org - try: - server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=self.user) - except NoResultFound: - raise ValueError( - f"The passed-in agent_id {other_agent_id} either does not exist, " - f"or does not belong to the same org ({self.user.organization_id})." - ) - # Async logic to send a message with retries and timeout - async def async_send_single_agent(): - return await async_send_message_with_retries( - server=server, - sender_agent=self, - target_agent_id=other_agent_id, - message_text=message, - max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, # or your chosen constants - timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, # e.g., 1200 for 20 min - logging_prefix="[send_message_to_specific_agent]", - ) +def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str) -> str: + """ + Sends a message to a specific Letta agent within the same organization. The sender's identity is automatically included, so no explicit introduction is required in the message. This function does not expect a response from the target agent, making it suitable for notifications or one-way communication. - # Run in the current event loop or create one if needed - try: - return asyncio.run(async_send_single_agent()) - except RuntimeError: - # e.g., in case there's already an active loop - loop = asyncio.get_event_loop() - if loop.is_running(): - return loop.run_until_complete(async_send_single_agent()) - else: - raise + Args: + message (str): The content of the message to be sent to the target agent. + other_agent_id (str): The unique identifier of the target Letta agent. + + Returns: + str: A confirmation message indicating the message was successfully sent. + """ + message = ( + f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, " + f"make sure to use the 'send_message_to_agent_async' tool, or the agent will not receive your message] " + f"{message}" + ) + messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)] + + # Do the actual fire-and-forget + fire_and_forget_send_to_agent( + sender_agent=self, + messages=messages, + other_agent_id=other_agent_id, + log_prefix="[send_message_to_agent_async]", + use_retries=False, # or True if you want to use async_send_message_with_retries + ) + + # Immediately return to caller + return "Successfully sent message" def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]: """ - Send a message to all agents in the same organization that match ALL of the given tags. - - Messages are sent in parallel for improved performance, with retries on flaky calls and timeouts for long-running requests. - This function does not use a cursor (pagination) and enforces a limit of 100 agents. + Sends a message to all agents within the same organization that match all of the specified tags. Messages are dispatched in parallel for improved performance, with retries to handle transient issues and timeouts to ensure responsiveness. This function enforces a limit of 100 agents and does not support pagination (cursor-based queries). Each agent must match all specified tags (`match_all_tags=True`) to be included. Args: - message (str): The message to be sent to each matching agent. - tags (List[str]): The list of tags that each agent must have (match_all_tags=True). + message (str): The content of the message to be sent to each matching agent. + tags (List[str]): A list of tags that an agent must possess to receive the message. Returns: - List[str]: A list of responses from the agents that match all tags. - Each response corresponds to one agent. + List[str]: A list of responses from the agents that matched all tags. Each + response corresponds to a single agent. Agents that do not respond will not + have an entry in the returned list. """ + server = get_letta_server() # Retrieve agents that match ALL specified tags matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100) + messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)] async def send_messages_to_all_agents(): tasks = [ @@ -82,7 +88,7 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: server=server, sender_agent=self, target_agent_id=agent_state.id, - message_text=message, + messages=messages, max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, logging_prefix="[send_message_to_agents_matching_all_tags]", diff --git a/letta/functions/functions.py b/letta/functions/functions.py index 4195cbee..d5e9d088 100644 --- a/letta/functions/functions.py +++ b/letta/functions/functions.py @@ -122,7 +122,6 @@ def get_json_schema_from_module(module_name: str, function_name: str) -> dict: generated_schema = generate_schema(attr) return generated_schema - except ModuleNotFoundError: raise ModuleNotFoundError(f"Module '{module_name}' not found.") except AttributeError: diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 1718ffef..24492119 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -1,15 +1,25 @@ +import asyncio import json -from typing import Any, Optional, Union +import threading +from random import uniform +from typing import Any, List, Optional, Union import humps from composio.constants import DEFAULT_ENTITY_ID from pydantic import BaseModel -from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG -from letta.schemas.enums import MessageRole +from letta.constants import ( + COMPOSIO_ENTITY_ENV_VAR_KEY, + DEFAULT_MESSAGE_TOOL, + DEFAULT_MESSAGE_TOOL_KWARG, + MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, + MULTI_AGENT_SEND_MESSAGE_TIMEOUT, +) +from letta.orm.errors import NoResultFound from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage from letta.schemas.letta_response import LettaResponse from letta.schemas.message import MessageCreate +from letta.server.rest_api.utils import get_letta_server # TODO: This is kind of hacky, as this is used to search up the action later on composio's side @@ -259,16 +269,63 @@ def parse_letta_response_for_assistant_message( return None -import asyncio -from random import uniform -from typing import Optional +def execute_send_message_to_agent( + sender_agent: "Agent", + messages: List[MessageCreate], + other_agent_id: str, + log_prefix: str, +) -> Optional[str]: + """ + Helper function to send a message to a specific Letta agent. + + Args: + sender_agent ("Agent"): The sender agent object. + message (str): The message to send. + other_agent_id (str): The identifier of the target Letta agent. + log_prefix (str): Logging prefix for retries. + + Returns: + Optional[str]: The response from the Letta agent if required by the caller. + """ + server = get_letta_server() + + # Ensure the target agent is in the same org + try: + server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user) + except NoResultFound: + raise ValueError( + f"The passed-in agent_id {other_agent_id} either does not exist, " + f"or does not belong to the same org ({sender_agent.user.organization_id})." + ) + + # Async logic to send a message with retries and timeout + async def async_send(): + return await async_send_message_with_retries( + server=server, + sender_agent=sender_agent, + target_agent_id=other_agent_id, + messages=messages, + max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, + timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, + logging_prefix=log_prefix, + ) + + # Run in the current event loop or create one if needed + try: + return asyncio.run(async_send()) + except RuntimeError: + loop = asyncio.get_event_loop() + if loop.is_running(): + return loop.run_until_complete(async_send()) + else: + raise async def async_send_message_with_retries( server, sender_agent: "Agent", target_agent_id: str, - message_text: str, + messages: List[MessageCreate], max_retries: int, timeout: int, logging_prefix: Optional[str] = None, @@ -290,7 +347,6 @@ async def async_send_message_with_retries( logging_prefix = logging_prefix or "[async_send_message_with_retries]" for attempt in range(1, max_retries + 1): try: - messages = [MessageCreate(role=MessageRole.user, content=message_text, name=sender_agent.agent_state.name)] # Wrap in a timeout response = await asyncio.wait_for( server.send_message_to_agent( @@ -334,4 +390,88 @@ async def async_send_message_with_retries( await asyncio.sleep(backoff) else: sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}") - return error_msg + raise Exception(error_msg) + + +def fire_and_forget_send_to_agent( + sender_agent: "Agent", + messages: List[MessageCreate], + other_agent_id: str, + log_prefix: str, + use_retries: bool = False, +) -> None: + """ + Fire-and-forget send of messages to a specific agent. + Returns immediately in the calling thread, never blocks. + + Args: + sender_agent (Agent): The sender agent object. + server: The Letta server instance + messages (List[MessageCreate]): The messages to send. + other_agent_id (str): The ID of the target agent. + log_prefix (str): Prefix for logging. + use_retries (bool): If True, uses async_send_message_with_retries; + if False, calls server.send_message_to_agent directly. + """ + server = get_letta_server() + + # 1) Validate the target agent (raises ValueError if not in same org) + try: + server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user) + except NoResultFound: + raise ValueError( + f"The passed-in agent_id {other_agent_id} either does not exist, " + f"or does not belong to the same org ({sender_agent.user.organization_id})." + ) + + # 2) Define the async coroutine to run + async def background_task(): + try: + if use_retries: + result = await async_send_message_with_retries( + server=server, + sender_agent=sender_agent, + target_agent_id=other_agent_id, + messages=messages, + max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, + timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, + logging_prefix=log_prefix, + ) + sender_agent.logger.info(f"{log_prefix} fire-and-forget success with retries: {result}") + else: + # Direct call to server.send_message_to_agent, no retry logic + await server.send_message_to_agent( + agent_id=other_agent_id, + actor=sender_agent.user, + messages=messages, + stream_steps=False, + stream_tokens=False, + use_assistant_message=True, + assistant_message_tool_name=DEFAULT_MESSAGE_TOOL, + assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, + ) + sender_agent.logger.info(f"{log_prefix} fire-and-forget success (no retries).") + except Exception as e: + sender_agent.logger.error(f"{log_prefix} fire-and-forget send failed: {e}") + + # 3) Helper to run the coroutine in a brand-new event loop in a separate thread + def run_in_background_thread(coro): + def runner(): + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + loop.run_until_complete(coro) + finally: + loop.close() + + thread = threading.Thread(target=runner, daemon=True) + thread.start() + + # 4) Try to schedule the coroutine in an existing loop, else spawn a thread + try: + loop = asyncio.get_running_loop() + # If we get here, a loop is running; schedule the coroutine in background + loop.create_task(background_task()) + except RuntimeError: + # Means no event loop is running in this thread + run_in_background_thread(background_task()) diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 781ab383..515f77c2 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, List, Optional from sqlalchemy import JSON, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship +from letta.constants import MULTI_AGENT_TOOLS from letta.orm.block import Block from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn from letta.orm.message import Message @@ -15,7 +16,7 @@ from letta.schemas.agent import AgentType from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory -from letta.schemas.tool_rule import ToolRule +from letta.schemas.tool_rule import TerminalToolRule, ToolRule if TYPE_CHECKING: from letta.orm.agents_tags import AgentsTags @@ -114,6 +115,16 @@ class Agent(SqlalchemyBase, OrganizationMixin): def to_pydantic(self) -> PydanticAgentState: """converts to the basic pydantic model counterpart""" + # add default rule for having send_message be a terminal tool + tool_rules = self.tool_rules + if not tool_rules: + tool_rules = [ + TerminalToolRule(tool_name="send_message"), + ] + + for tool_name in MULTI_AGENT_TOOLS: + tool_rules.append(TerminalToolRule(tool_name=tool_name)) + state = { "id": self.id, "organization_id": self.organization_id, @@ -123,7 +134,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): "tools": self.tools, "sources": [source.to_pydantic() for source in self.sources], "tags": [t.tag for t in self.tags], - "tool_rules": self.tool_rules, + "tool_rules": tool_rules, "system": self.system, "agent_type": self.agent_type, "llm_config": self.llm_config, @@ -136,4 +147,5 @@ class Agent(SqlalchemyBase, OrganizationMixin): "updated_at": self.updated_at, "tool_exec_environment_variables": self.tool_exec_environment_variables, } + return self.__pydantic_model__(**state) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 375417f8..8cdd686a 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -1,6 +1,7 @@ from datetime import datetime from enum import Enum from functools import wraps +from pprint import pformat from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union from sqlalchemy import String, and_, func, or_, select @@ -504,7 +505,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): model.metadata = self.metadata_ return model - def to_record(self) -> "BaseModel": - """Deprecated accessor for to_pydantic""" - logger.warning("to_record is deprecated, use to_pydantic instead.") - return self.to_pydantic() + def pretty_print_columns(self) -> str: + """ + Pretty prints all columns of the current SQLAlchemy object along with their values. + """ + if not hasattr(self, "__table__") or not hasattr(self.__table__, "columns"): + raise NotImplementedError("This object does not have a '__table__.columns' attribute.") + + # Iterate over the columns correctly + column_data = {column.name: getattr(self, column.name, None) for column in self.__table__.columns} + + return pformat(column_data, indent=4, sort_dicts=True) diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 05d6653e..70aa86ff 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -97,6 +97,14 @@ class LLMConfig(BaseModel): model_wrapper=None, context_window=128000, ) + elif model_name == "gpt-4o": + return cls( + model="gpt-4o", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model_wrapper=None, + context_window=128000, + ) elif model_name == "letta": return cls( model="memgpt-openai", diff --git a/letta/server/server.py b/letta/server/server.py index 1e0eea9e..7eb14154 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1290,7 +1290,7 @@ class SyncServer(Server): llm_config.model_endpoint_type not in ["openai", "anthropic"] or "inference.memgpt.ai" in llm_config.model_endpoint ): warnings.warn( - "Token streaming is only supported for models with type 'openai', 'anthropic', or `inference.memgpt.ai` in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." + f"Token streaming is only supported for models with type 'openai' or 'anthropic' in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." ) stream_tokens = False diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 01e4c855..2e831f92 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -4,6 +4,7 @@ from typing import List, Optional from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS from letta.functions.functions import derive_openai_json_schema, load_function_set +from letta.log import get_logger from letta.orm.enums import ToolType # TODO: Remove this once we translate all of these to the ORM @@ -14,6 +15,8 @@ from letta.schemas.tool import ToolUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types, printd +logger = get_logger(__name__) + class ToolManager: """Manager class to handle business logic related to Tools.""" @@ -102,7 +105,20 @@ class ToolManager: limit=limit, organization_id=actor.organization_id, ) - return [tool.to_pydantic() for tool in tools] + + # Remove any malformed tools + results = [] + for tool in tools: + try: + pydantic_tool = tool.to_pydantic() + results.append(pydantic_tool) + except (ValueError, ModuleNotFoundError, AttributeError) as e: + logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}") + logger.warning("Deleted tool: ") + logger.warning(tool.pretty_print_columns()) + self.delete_tool_by_id(tool.id, actor=actor) + + return results @enforce_types def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool: diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 8736825b..92c929f9 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -1,6 +1,4 @@ import json -import secrets -import string import pytest @@ -9,30 +7,33 @@ from letta import LocalClient, create_client from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.llm_config import LLMConfig +from letta.schemas.memory import ChatMemory from tests.helpers.utils import retry_until_success +from tests.utils import wait_for_incoming_message -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def client(): client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) + client.set_default_llm_config(LLMConfig.default_config("gpt-4o")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) yield client -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def agent_obj(client: LocalClient): """Create a test agent that we can call functions on""" - agent_state = client.create_agent(include_multi_agent_tools=True) + send_message_to_agent_and_wait_for_reply_tool_id = client.get_tool_id(name="send_message_to_agent_and_wait_for_reply") + agent_state = client.create_agent(tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id]) agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) yield agent_obj - client.delete_agent(agent_obj.agent_state.id) + # client.delete_agent(agent_obj.agent_state.id) -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def other_agent_obj(client: LocalClient): """Create another test agent that we can call functions on""" agent_state = client.create_agent(include_multi_agent_tools=False) @@ -119,18 +120,18 @@ def test_recall(client, agent_obj): # This test is nondeterministic, so we retry until we get the perfect behavior from the LLM @retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_send_message_to_agent(client, agent_obj, other_agent_obj): - long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(10)) + secret_word = "banana" # Encourage the agent to send a message to the other agent_obj with the secret string client.send_message( agent_id=agent_obj.agent_state.id, role="user", - message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} with the secret password={long_random_string}", + message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} to share the secret word: {secret_word}!", ) # Conversation search the other agent - result = base_functions.conversation_search(other_agent_obj, long_random_string) - assert long_random_string in result + result = base_functions.conversation_search(other_agent_obj, secret_word) + assert secret_word in result # Search the sender agent for the response from another agent in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user) @@ -144,7 +145,7 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj): print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.text for m in in_context_messages[1:]])}") if not found: - pytest.fail(f"Was not able to find an instance of the target snippet: {target_snippet}") + raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}") # Test that the agent can still receive messages fine response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?") @@ -161,10 +162,11 @@ def test_send_message_to_agents_with_tags(client): for agent in prev_worker_agents: client.delete_agent(agent.id) - long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(10)) + secret_word = "banana" # Create "manager" agent - manager_agent_state = client.create_agent(include_multi_agent_tools=True) + send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") + manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id]) manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) # Create 3 worker agents @@ -187,7 +189,7 @@ def test_send_message_to_agents_with_tags(client): response = client.send_message( agent_id=manager_agent.agent_state.id, role="user", - message=f"Send a message to all agents with tags {worker_tags} informing them of the secret password={long_random_string}", + message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!", ) for m in response.messages: @@ -201,8 +203,8 @@ def test_send_message_to_agents_with_tags(client): # Conversation search the worker agents for agent in worker_agents: - result = base_functions.conversation_search(agent, long_random_string) - assert long_random_string in result + result = base_functions.conversation_search(agent, secret_word) + assert secret_word in result # Test that the agent can still receive messages fine response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") @@ -212,3 +214,56 @@ def test_send_message_to_agents_with_tags(client): client.delete_agent(manager_agent_state.id) for agent in worker_agents: client.delete_agent(agent.agent_state.id) + + +@retry_until_success(max_attempts=5, sleep_time_seconds=2) +def test_agents_async_simple(client): + """ + Test two agents with multi-agent tools sending messages back and forth to count to 5. + The chain is started by prompting one of the agents. + """ + # Cleanup from potentially failed previous runs + existing_agents = client.server.agent_manager.list_agents(client.user) + for agent in existing_agents: + client.delete_agent(agent.id) + + # Create two agents with multi-agent tools + send_message_to_agent_async_tool_id = client.get_tool_id(name="send_message_to_agent_async") + memory_a = ChatMemory( + human="Chad - I'm interested in hearing poem.", + persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).", + ) + charles_state = client.create_agent(name="charles", memory=memory_a, tool_ids=[send_message_to_agent_async_tool_id]) + charles = client.server.load_agent(agent_id=charles_state.id, actor=client.user) + + memory_b = ChatMemory( + human="No human - you are to only communicate with the other AI agent.", + persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.", + ) + sarah_state = client.create_agent(name="sarah", memory=memory_b, tool_ids=[send_message_to_agent_async_tool_id]) + + # Start the count chain with Agent1 + initial_prompt = f"I want you to talk to the other agent with ID {sarah_state.id} using `send_message_to_agent_async`. Specifically, I want you to ask him for a poem idea, and then craft a poem for me." + client.send_message( + agent_id=charles.agent_state.id, + role="user", + message=initial_prompt, + ) + + found_in_charles = wait_for_incoming_message( + client=client, + agent_id=charles_state.id, + substring="[Incoming message from agent with ID", + max_wait_seconds=10, + sleep_interval=0.5, + ) + assert found_in_charles, "Charles never received the system message from Sarah (timed out)." + + found_in_sarah = wait_for_incoming_message( + client=client, + agent_id=sarah_state.id, + substring="[Incoming message from agent with ID", + max_wait_seconds=10, + sleep_interval=0.5, + ) + assert found_in_sarah, "Sarah never received the system message from Charles (timed out)." diff --git a/tests/utils.py b/tests/utils.py index 19a05a09..46d83ed7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ import datetime import os +import time from datetime import datetime from importlib import util from typing import Dict, Iterator, List, Tuple @@ -8,6 +9,7 @@ import requests from letta.config import LettaConfig from letta.data_sources.connectors import DataConnector +from letta.schemas.enums import MessageRole from letta.schemas.file import FileMetadata from letta.settings import TestSettings @@ -145,3 +147,27 @@ def with_qdrant_storage(storage: list[str]): storage.append("qdrant") return storage + + +def wait_for_incoming_message( + client, + agent_id: str, + substring: str = "[Incoming message from agent with ID", + max_wait_seconds: float = 10.0, + sleep_interval: float = 0.5, +) -> bool: + """ + Polls for up to `max_wait_seconds` to see if the agent's message list + contains a system message with `substring`. + Returns True if found, otherwise False after timeout. + """ + deadline = time.time() + max_wait_seconds + + while time.time() < deadline: + messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id) + # Check for the system message containing `substring` + if any(message.role == MessageRole.system and substring in (message.text or "") for message in messages): + return True + time.sleep(sleep_interval) + + return False