fix: new versions of send_message_to_agent that are async (#725)

Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Charles Packer
2025-01-27 17:11:44 -08:00
committed by GitHub
parent 7eb44280c1
commit ec6e5d153c
12 changed files with 354 additions and 87 deletions

View File

@@ -108,9 +108,6 @@ class Agent(BaseAgent):
if not isinstance(rule, TerminalToolRule): if not isinstance(rule, TerminalToolRule):
warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.") warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.")
break break
# add default rule for having send_message be a terminal tool
if agent_state.tool_rules is None:
agent_state.tool_rules = []
self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules) self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)

View File

@@ -50,7 +50,7 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "
# Base memory tools CAN be edited, and are added by default by the server # Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
# Multi agent tools # Multi agent tools
MULTI_AGENT_TOOLS = ["send_message_to_specific_agent", "send_message_to_agents_matching_all_tags"] MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3 MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3
MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60 MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60

View File

@@ -1,80 +1,86 @@
import asyncio import asyncio
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List
from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT
from letta.functions.helpers import async_send_message_with_retries from letta.functions.helpers import async_send_message_with_retries, execute_send_message_to_agent, fire_and_forget_send_to_agent
from letta.orm.errors import NoResultFound from letta.schemas.enums import MessageRole
from letta.schemas.message import MessageCreate
from letta.server.rest_api.utils import get_letta_server from letta.server.rest_api.utils import get_letta_server
if TYPE_CHECKING: if TYPE_CHECKING:
from letta.agent import Agent from letta.agent import Agent
def send_message_to_specific_agent(self: "Agent", message: str, other_agent_id: str) -> Optional[str]: def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_agent_id: str) -> str:
""" """
Send a message to a specific Letta agent within the same organization. Sends a message to a specific Letta agent within the same organization and waits for a response. The sender's identity is automatically included, so no explicit introduction is needed in the message. This function is designed for two-way communication where a reply is expected.
Args: Args:
message (str): The message to be sent to the target Letta agent. message (str): The content of the message to be sent to the target agent.
other_agent_id (str): The identifier of the target Letta agent. other_agent_id (str): The unique identifier of the target Letta agent.
Returns: Returns:
Optional[str]: The response from the Letta agent. It's possible that the agent does not respond. str: The response from the target agent.
""" """
server = get_letta_server() messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)]
return execute_send_message_to_agent(
sender_agent=self,
messages=messages,
other_agent_id=other_agent_id,
log_prefix="[send_message_to_agent_and_wait_for_reply]",
)
# Ensure the target agent is in the same org
try:
server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=self.user)
except NoResultFound:
raise ValueError(
f"The passed-in agent_id {other_agent_id} either does not exist, "
f"or does not belong to the same org ({self.user.organization_id})."
)
# Async logic to send a message with retries and timeout def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str) -> str:
async def async_send_single_agent(): """
return await async_send_message_with_retries( Sends a message to a specific Letta agent within the same organization. The sender's identity is automatically included, so no explicit introduction is required in the message. This function does not expect a response from the target agent, making it suitable for notifications or one-way communication.
server=server,
sender_agent=self,
target_agent_id=other_agent_id,
message_text=message,
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, # or your chosen constants
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, # e.g., 1200 for 20 min
logging_prefix="[send_message_to_specific_agent]",
)
# Run in the current event loop or create one if needed Args:
try: message (str): The content of the message to be sent to the target agent.
return asyncio.run(async_send_single_agent()) other_agent_id (str): The unique identifier of the target Letta agent.
except RuntimeError:
# e.g., in case there's already an active loop Returns:
loop = asyncio.get_event_loop() str: A confirmation message indicating the message was successfully sent.
if loop.is_running(): """
return loop.run_until_complete(async_send_single_agent()) message = (
else: f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
raise f"make sure to use the 'send_message_to_agent_async' tool, or the agent will not receive your message] "
f"{message}"
)
messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)]
# Do the actual fire-and-forget
fire_and_forget_send_to_agent(
sender_agent=self,
messages=messages,
other_agent_id=other_agent_id,
log_prefix="[send_message_to_agent_async]",
use_retries=False, # or True if you want to use async_send_message_with_retries
)
# Immediately return to caller
return "Successfully sent message"
def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]: def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]:
""" """
Send a message to all agents in the same organization that match ALL of the given tags. Sends a message to all agents within the same organization that match all of the specified tags. Messages are dispatched in parallel for improved performance, with retries to handle transient issues and timeouts to ensure responsiveness. This function enforces a limit of 100 agents and does not support pagination (cursor-based queries). Each agent must match all specified tags (`match_all_tags=True`) to be included.
Messages are sent in parallel for improved performance, with retries on flaky calls and timeouts for long-running requests.
This function does not use a cursor (pagination) and enforces a limit of 100 agents.
Args: Args:
message (str): The message to be sent to each matching agent. message (str): The content of the message to be sent to each matching agent.
tags (List[str]): The list of tags that each agent must have (match_all_tags=True). tags (List[str]): A list of tags that an agent must possess to receive the message.
Returns: Returns:
List[str]: A list of responses from the agents that match all tags. List[str]: A list of responses from the agents that matched all tags. Each
Each response corresponds to one agent. response corresponds to a single agent. Agents that do not respond will not
have an entry in the returned list.
""" """
server = get_letta_server() server = get_letta_server()
# Retrieve agents that match ALL specified tags # Retrieve agents that match ALL specified tags
matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100) matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100)
messages = [MessageCreate(role=MessageRole.user, content=message, name=self.agent_state.name)]
async def send_messages_to_all_agents(): async def send_messages_to_all_agents():
tasks = [ tasks = [
@@ -82,7 +88,7 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags:
server=server, server=server,
sender_agent=self, sender_agent=self,
target_agent_id=agent_state.id, target_agent_id=agent_state.id,
message_text=message, messages=messages,
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
logging_prefix="[send_message_to_agents_matching_all_tags]", logging_prefix="[send_message_to_agents_matching_all_tags]",

View File

@@ -122,7 +122,6 @@ def get_json_schema_from_module(module_name: str, function_name: str) -> dict:
generated_schema = generate_schema(attr) generated_schema = generate_schema(attr)
return generated_schema return generated_schema
except ModuleNotFoundError: except ModuleNotFoundError:
raise ModuleNotFoundError(f"Module '{module_name}' not found.") raise ModuleNotFoundError(f"Module '{module_name}' not found.")
except AttributeError: except AttributeError:

View File

@@ -1,15 +1,25 @@
import asyncio
import json import json
from typing import Any, Optional, Union import threading
from random import uniform
from typing import Any, List, Optional, Union
import humps import humps
from composio.constants import DEFAULT_ENTITY_ID from composio.constants import DEFAULT_ENTITY_ID
from pydantic import BaseModel from pydantic import BaseModel
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.constants import (
from letta.schemas.enums import MessageRole COMPOSIO_ENTITY_ENV_VAR_KEY,
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
)
from letta.orm.errors import NoResultFound
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage
from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import MessageCreate from letta.schemas.message import MessageCreate
from letta.server.rest_api.utils import get_letta_server
# TODO: This is kind of hacky, as this is used to search up the action later on composio's side # TODO: This is kind of hacky, as this is used to search up the action later on composio's side
@@ -259,16 +269,63 @@ def parse_letta_response_for_assistant_message(
return None return None
import asyncio def execute_send_message_to_agent(
from random import uniform sender_agent: "Agent",
from typing import Optional messages: List[MessageCreate],
other_agent_id: str,
log_prefix: str,
) -> Optional[str]:
"""
Helper function to send a message to a specific Letta agent.
Args:
sender_agent ("Agent"): The sender agent object.
message (str): The message to send.
other_agent_id (str): The identifier of the target Letta agent.
log_prefix (str): Logging prefix for retries.
Returns:
Optional[str]: The response from the Letta agent if required by the caller.
"""
server = get_letta_server()
# Ensure the target agent is in the same org
try:
server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user)
except NoResultFound:
raise ValueError(
f"The passed-in agent_id {other_agent_id} either does not exist, "
f"or does not belong to the same org ({sender_agent.user.organization_id})."
)
# Async logic to send a message with retries and timeout
async def async_send():
return await async_send_message_with_retries(
server=server,
sender_agent=sender_agent,
target_agent_id=other_agent_id,
messages=messages,
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
logging_prefix=log_prefix,
)
# Run in the current event loop or create one if needed
try:
return asyncio.run(async_send())
except RuntimeError:
loop = asyncio.get_event_loop()
if loop.is_running():
return loop.run_until_complete(async_send())
else:
raise
async def async_send_message_with_retries( async def async_send_message_with_retries(
server, server,
sender_agent: "Agent", sender_agent: "Agent",
target_agent_id: str, target_agent_id: str,
message_text: str, messages: List[MessageCreate],
max_retries: int, max_retries: int,
timeout: int, timeout: int,
logging_prefix: Optional[str] = None, logging_prefix: Optional[str] = None,
@@ -290,7 +347,6 @@ async def async_send_message_with_retries(
logging_prefix = logging_prefix or "[async_send_message_with_retries]" logging_prefix = logging_prefix or "[async_send_message_with_retries]"
for attempt in range(1, max_retries + 1): for attempt in range(1, max_retries + 1):
try: try:
messages = [MessageCreate(role=MessageRole.user, content=message_text, name=sender_agent.agent_state.name)]
# Wrap in a timeout # Wrap in a timeout
response = await asyncio.wait_for( response = await asyncio.wait_for(
server.send_message_to_agent( server.send_message_to_agent(
@@ -334,4 +390,88 @@ async def async_send_message_with_retries(
await asyncio.sleep(backoff) await asyncio.sleep(backoff)
else: else:
sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}") sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}")
return error_msg raise Exception(error_msg)
def fire_and_forget_send_to_agent(
sender_agent: "Agent",
messages: List[MessageCreate],
other_agent_id: str,
log_prefix: str,
use_retries: bool = False,
) -> None:
"""
Fire-and-forget send of messages to a specific agent.
Returns immediately in the calling thread, never blocks.
Args:
sender_agent (Agent): The sender agent object.
server: The Letta server instance
messages (List[MessageCreate]): The messages to send.
other_agent_id (str): The ID of the target agent.
log_prefix (str): Prefix for logging.
use_retries (bool): If True, uses async_send_message_with_retries;
if False, calls server.send_message_to_agent directly.
"""
server = get_letta_server()
# 1) Validate the target agent (raises ValueError if not in same org)
try:
server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user)
except NoResultFound:
raise ValueError(
f"The passed-in agent_id {other_agent_id} either does not exist, "
f"or does not belong to the same org ({sender_agent.user.organization_id})."
)
# 2) Define the async coroutine to run
async def background_task():
try:
if use_retries:
result = await async_send_message_with_retries(
server=server,
sender_agent=sender_agent,
target_agent_id=other_agent_id,
messages=messages,
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
logging_prefix=log_prefix,
)
sender_agent.logger.info(f"{log_prefix} fire-and-forget success with retries: {result}")
else:
# Direct call to server.send_message_to_agent, no retry logic
await server.send_message_to_agent(
agent_id=other_agent_id,
actor=sender_agent.user,
messages=messages,
stream_steps=False,
stream_tokens=False,
use_assistant_message=True,
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
)
sender_agent.logger.info(f"{log_prefix} fire-and-forget success (no retries).")
except Exception as e:
sender_agent.logger.error(f"{log_prefix} fire-and-forget send failed: {e}")
# 3) Helper to run the coroutine in a brand-new event loop in a separate thread
def run_in_background_thread(coro):
def runner():
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
loop.run_until_complete(coro)
finally:
loop.close()
thread = threading.Thread(target=runner, daemon=True)
thread.start()
# 4) Try to schedule the coroutine in an existing loop, else spawn a thread
try:
loop = asyncio.get_running_loop()
# If we get here, a loop is running; schedule the coroutine in background
loop.create_task(background_task())
except RuntimeError:
# Means no event loop is running in this thread
run_in_background_thread(background_task())

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, Index, String from sqlalchemy import JSON, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.constants import MULTI_AGENT_TOOLS
from letta.orm.block import Block from letta.orm.block import Block
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
from letta.orm.message import Message from letta.orm.message import Message
@@ -15,7 +16,7 @@ from letta.schemas.agent import AgentType
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory from letta.schemas.memory import Memory
from letta.schemas.tool_rule import ToolRule from letta.schemas.tool_rule import TerminalToolRule, ToolRule
if TYPE_CHECKING: if TYPE_CHECKING:
from letta.orm.agents_tags import AgentsTags from letta.orm.agents_tags import AgentsTags
@@ -114,6 +115,16 @@ class Agent(SqlalchemyBase, OrganizationMixin):
def to_pydantic(self) -> PydanticAgentState: def to_pydantic(self) -> PydanticAgentState:
"""converts to the basic pydantic model counterpart""" """converts to the basic pydantic model counterpart"""
# add default rule for having send_message be a terminal tool
tool_rules = self.tool_rules
if not tool_rules:
tool_rules = [
TerminalToolRule(tool_name="send_message"),
]
for tool_name in MULTI_AGENT_TOOLS:
tool_rules.append(TerminalToolRule(tool_name=tool_name))
state = { state = {
"id": self.id, "id": self.id,
"organization_id": self.organization_id, "organization_id": self.organization_id,
@@ -123,7 +134,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"tools": self.tools, "tools": self.tools,
"sources": [source.to_pydantic() for source in self.sources], "sources": [source.to_pydantic() for source in self.sources],
"tags": [t.tag for t in self.tags], "tags": [t.tag for t in self.tags],
"tool_rules": self.tool_rules, "tool_rules": tool_rules,
"system": self.system, "system": self.system,
"agent_type": self.agent_type, "agent_type": self.agent_type,
"llm_config": self.llm_config, "llm_config": self.llm_config,
@@ -136,4 +147,5 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"updated_at": self.updated_at, "updated_at": self.updated_at,
"tool_exec_environment_variables": self.tool_exec_environment_variables, "tool_exec_environment_variables": self.tool_exec_environment_variables,
} }
return self.__pydantic_model__(**state) return self.__pydantic_model__(**state)

View File

@@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from pprint import pformat
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union
from sqlalchemy import String, and_, func, or_, select from sqlalchemy import String, and_, func, or_, select
@@ -504,7 +505,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
model.metadata = self.metadata_ model.metadata = self.metadata_
return model return model
def to_record(self) -> "BaseModel": def pretty_print_columns(self) -> str:
"""Deprecated accessor for to_pydantic""" """
logger.warning("to_record is deprecated, use to_pydantic instead.") Pretty prints all columns of the current SQLAlchemy object along with their values.
return self.to_pydantic() """
if not hasattr(self, "__table__") or not hasattr(self.__table__, "columns"):
raise NotImplementedError("This object does not have a '__table__.columns' attribute.")
# Iterate over the columns correctly
column_data = {column.name: getattr(self, column.name, None) for column in self.__table__.columns}
return pformat(column_data, indent=4, sort_dicts=True)

View File

@@ -97,6 +97,14 @@ class LLMConfig(BaseModel):
model_wrapper=None, model_wrapper=None,
context_window=128000, context_window=128000,
) )
elif model_name == "gpt-4o":
return cls(
model="gpt-4o",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
model_wrapper=None,
context_window=128000,
)
elif model_name == "letta": elif model_name == "letta":
return cls( return cls(
model="memgpt-openai", model="memgpt-openai",

View File

@@ -1290,7 +1290,7 @@ class SyncServer(Server):
llm_config.model_endpoint_type not in ["openai", "anthropic"] or "inference.memgpt.ai" in llm_config.model_endpoint llm_config.model_endpoint_type not in ["openai", "anthropic"] or "inference.memgpt.ai" in llm_config.model_endpoint
): ):
warnings.warn( warnings.warn(
"Token streaming is only supported for models with type 'openai', 'anthropic', or `inference.memgpt.ai` in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." f"Token streaming is only supported for models with type 'openai' or 'anthropic' in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
) )
stream_tokens = False stream_tokens = False

View File

@@ -4,6 +4,7 @@ from typing import List, Optional
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS
from letta.functions.functions import derive_openai_json_schema, load_function_set from letta.functions.functions import derive_openai_json_schema, load_function_set
from letta.log import get_logger
from letta.orm.enums import ToolType from letta.orm.enums import ToolType
# TODO: Remove this once we translate all of these to the ORM # TODO: Remove this once we translate all of these to the ORM
@@ -14,6 +15,8 @@ from letta.schemas.tool import ToolUpdate
from letta.schemas.user import User as PydanticUser from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types, printd from letta.utils import enforce_types, printd
logger = get_logger(__name__)
class ToolManager: class ToolManager:
"""Manager class to handle business logic related to Tools.""" """Manager class to handle business logic related to Tools."""
@@ -102,7 +105,20 @@ class ToolManager:
limit=limit, limit=limit,
organization_id=actor.organization_id, organization_id=actor.organization_id,
) )
return [tool.to_pydantic() for tool in tools]
# Remove any malformed tools
results = []
for tool in tools:
try:
pydantic_tool = tool.to_pydantic()
results.append(pydantic_tool)
except (ValueError, ModuleNotFoundError, AttributeError) as e:
logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
logger.warning("Deleted tool: ")
logger.warning(tool.pretty_print_columns())
self.delete_tool_by_id(tool.id, actor=actor)
return results
@enforce_types @enforce_types
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool: def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool:

View File

@@ -1,6 +1,4 @@
import json import json
import secrets
import string
import pytest import pytest
@@ -9,30 +7,33 @@ from letta import LocalClient, create_client
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from tests.helpers.utils import retry_until_success from tests.helpers.utils import retry_until_success
from tests.utils import wait_for_incoming_message
@pytest.fixture(scope="module") @pytest.fixture(scope="function")
def client(): def client():
client = create_client() client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_llm_config(LLMConfig.default_config("gpt-4o"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client yield client
@pytest.fixture(scope="module") @pytest.fixture(scope="function")
def agent_obj(client: LocalClient): def agent_obj(client: LocalClient):
"""Create a test agent that we can call functions on""" """Create a test agent that we can call functions on"""
agent_state = client.create_agent(include_multi_agent_tools=True) send_message_to_agent_and_wait_for_reply_tool_id = client.get_tool_id(name="send_message_to_agent_and_wait_for_reply")
agent_state = client.create_agent(tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id])
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
yield agent_obj yield agent_obj
client.delete_agent(agent_obj.agent_state.id) # client.delete_agent(agent_obj.agent_state.id)
@pytest.fixture(scope="module") @pytest.fixture(scope="function")
def other_agent_obj(client: LocalClient): def other_agent_obj(client: LocalClient):
"""Create another test agent that we can call functions on""" """Create another test agent that we can call functions on"""
agent_state = client.create_agent(include_multi_agent_tools=False) agent_state = client.create_agent(include_multi_agent_tools=False)
@@ -119,18 +120,18 @@ def test_recall(client, agent_obj):
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM # This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
@retry_until_success(max_attempts=5, sleep_time_seconds=2) @retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agent(client, agent_obj, other_agent_obj): def test_send_message_to_agent(client, agent_obj, other_agent_obj):
long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(10)) secret_word = "banana"
# Encourage the agent to send a message to the other agent_obj with the secret string # Encourage the agent to send a message to the other agent_obj with the secret string
client.send_message( client.send_message(
agent_id=agent_obj.agent_state.id, agent_id=agent_obj.agent_state.id,
role="user", role="user",
message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} with the secret password={long_random_string}", message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} to share the secret word: {secret_word}!",
) )
# Conversation search the other agent # Conversation search the other agent
result = base_functions.conversation_search(other_agent_obj, long_random_string) result = base_functions.conversation_search(other_agent_obj, secret_word)
assert long_random_string in result assert secret_word in result
# Search the sender agent for the response from another agent # Search the sender agent for the response from another agent
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user) in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
@@ -144,7 +145,7 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.text for m in in_context_messages[1:]])}") print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.text for m in in_context_messages[1:]])}")
if not found: if not found:
pytest.fail(f"Was not able to find an instance of the target snippet: {target_snippet}") raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}")
# Test that the agent can still receive messages fine # Test that the agent can still receive messages fine
response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?") response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?")
@@ -161,10 +162,11 @@ def test_send_message_to_agents_with_tags(client):
for agent in prev_worker_agents: for agent in prev_worker_agents:
client.delete_agent(agent.id) client.delete_agent(agent.id)
long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(10)) secret_word = "banana"
# Create "manager" agent # Create "manager" agent
manager_agent_state = client.create_agent(include_multi_agent_tools=True) send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 worker agents # Create 3 worker agents
@@ -187,7 +189,7 @@ def test_send_message_to_agents_with_tags(client):
response = client.send_message( response = client.send_message(
agent_id=manager_agent.agent_state.id, agent_id=manager_agent.agent_state.id,
role="user", role="user",
message=f"Send a message to all agents with tags {worker_tags} informing them of the secret password={long_random_string}", message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!",
) )
for m in response.messages: for m in response.messages:
@@ -201,8 +203,8 @@ def test_send_message_to_agents_with_tags(client):
# Conversation search the worker agents # Conversation search the worker agents
for agent in worker_agents: for agent in worker_agents:
result = base_functions.conversation_search(agent, long_random_string) result = base_functions.conversation_search(agent, secret_word)
assert long_random_string in result assert secret_word in result
# Test that the agent can still receive messages fine # Test that the agent can still receive messages fine
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
@@ -212,3 +214,56 @@ def test_send_message_to_agents_with_tags(client):
client.delete_agent(manager_agent_state.id) client.delete_agent(manager_agent_state.id)
for agent in worker_agents: for agent in worker_agents:
client.delete_agent(agent.agent_state.id) client.delete_agent(agent.agent_state.id)
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_agents_async_simple(client):
"""
Test two agents with multi-agent tools sending messages back and forth to count to 5.
The chain is started by prompting one of the agents.
"""
# Cleanup from potentially failed previous runs
existing_agents = client.server.agent_manager.list_agents(client.user)
for agent in existing_agents:
client.delete_agent(agent.id)
# Create two agents with multi-agent tools
send_message_to_agent_async_tool_id = client.get_tool_id(name="send_message_to_agent_async")
memory_a = ChatMemory(
human="Chad - I'm interested in hearing poem.",
persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).",
)
charles_state = client.create_agent(name="charles", memory=memory_a, tool_ids=[send_message_to_agent_async_tool_id])
charles = client.server.load_agent(agent_id=charles_state.id, actor=client.user)
memory_b = ChatMemory(
human="No human - you are to only communicate with the other AI agent.",
persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.",
)
sarah_state = client.create_agent(name="sarah", memory=memory_b, tool_ids=[send_message_to_agent_async_tool_id])
# Start the count chain with Agent1
initial_prompt = f"I want you to talk to the other agent with ID {sarah_state.id} using `send_message_to_agent_async`. Specifically, I want you to ask him for a poem idea, and then craft a poem for me."
client.send_message(
agent_id=charles.agent_state.id,
role="user",
message=initial_prompt,
)
found_in_charles = wait_for_incoming_message(
client=client,
agent_id=charles_state.id,
substring="[Incoming message from agent with ID",
max_wait_seconds=10,
sleep_interval=0.5,
)
assert found_in_charles, "Charles never received the system message from Sarah (timed out)."
found_in_sarah = wait_for_incoming_message(
client=client,
agent_id=sarah_state.id,
substring="[Incoming message from agent with ID",
max_wait_seconds=10,
sleep_interval=0.5,
)
assert found_in_sarah, "Sarah never received the system message from Charles (timed out)."

View File

@@ -1,5 +1,6 @@
import datetime import datetime
import os import os
import time
from datetime import datetime from datetime import datetime
from importlib import util from importlib import util
from typing import Dict, Iterator, List, Tuple from typing import Dict, Iterator, List, Tuple
@@ -8,6 +9,7 @@ import requests
from letta.config import LettaConfig from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector from letta.data_sources.connectors import DataConnector
from letta.schemas.enums import MessageRole
from letta.schemas.file import FileMetadata from letta.schemas.file import FileMetadata
from letta.settings import TestSettings from letta.settings import TestSettings
@@ -145,3 +147,27 @@ def with_qdrant_storage(storage: list[str]):
storage.append("qdrant") storage.append("qdrant")
return storage return storage
def wait_for_incoming_message(
client,
agent_id: str,
substring: str = "[Incoming message from agent with ID",
max_wait_seconds: float = 10.0,
sleep_interval: float = 0.5,
) -> bool:
"""
Polls for up to `max_wait_seconds` to see if the agent's message list
contains a system message with `substring`.
Returns True if found, otherwise False after timeout.
"""
deadline = time.time() + max_wait_seconds
while time.time() < deadline:
messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id)
# Check for the system message containing `substring`
if any(message.role == MessageRole.system and substring in (message.text or "") for message in messages):
return True
time.sleep(sleep_interval)
return False