diff --git a/letta/agent.py b/letta/agent.py index e3fc3603..913b46b8 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -12,6 +12,7 @@ from letta.constants import ( FIRST_MESSAGE_ATTEMPTS, FUNC_FAILED_HEARTBEAT_MESSAGE, LETTA_CORE_TOOL_MODULE_NAME, + LETTA_MULTI_AGENT_TOOL_MODULE_NAME, LLM_MAX_TOKENS, MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC, @@ -25,6 +26,7 @@ from letta.interface import AgentInterface from letta.llm_api.helpers import is_context_overflow_error from letta.llm_api.llm_api_tools import create from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages +from letta.log import get_logger from letta.memory import summarize_messages from letta.orm import User from letta.orm.enums import ToolType @@ -143,6 +145,9 @@ class Agent(BaseAgent): # Load last function response from message history self.last_function_response = self.load_last_function_response() + # Logger that the Agent specifically can use, will also report the agent_state ID with the logs + self.logger = get_logger(agent_state.id) + def load_last_function_response(self): """Load the last function response from message history""" in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) @@ -207,6 +212,10 @@ class Agent(BaseAgent): callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) function_args["self"] = self # need to attach self to arg since it's dynamically linked function_response = callable_func(**function_args) + elif target_letta_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE: + callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name) + function_args["self"] = self # need to attach self to arg since it's dynamically linked + function_response = callable_func(**function_args) elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE: callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) agent_state_copy = self.agent_state.__deepcopy__() diff --git a/letta/client/client.py b/letta/client/client.py index c4e1497f..686171e2 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2251,6 +2251,7 @@ class LocalClient(AbstractClient): tool_ids: Optional[List[str]] = None, tool_rules: Optional[List[BaseToolRule]] = None, include_base_tools: Optional[bool] = True, + include_multi_agent_tools: bool = False, # metadata metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, description: Optional[str] = None, @@ -2268,6 +2269,7 @@ class LocalClient(AbstractClient): tools (List[str]): List of tools tool_rules (Optional[List[BaseToolRule]]): List of tool rules include_base_tools (bool): Include base tools + include_multi_agent_tools (bool): Include multi agent tools metadata (Dict): Metadata description (str): Description tags (List[str]): Tags for filtering agents @@ -2277,11 +2279,6 @@ class LocalClient(AbstractClient): """ # construct list of tools tool_ids = tool_ids or [] - tool_names = [] - if include_base_tools: - tool_names += BASE_TOOLS - tool_names += BASE_MEMORY_TOOLS - tool_ids += [self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user).id for name in tool_names] # check if default configs are provided assert embedding_config or self._default_embedding_config, f"Embedding config must be provided" @@ -2304,6 +2301,7 @@ class LocalClient(AbstractClient): "tool_ids": tool_ids, "tool_rules": tool_rules, "include_base_tools": include_base_tools, + "include_multi_agent_tools": include_multi_agent_tools, "system": system, "agent_type": agent_type, "llm_config": llm_config if llm_config else self._default_llm_config, diff --git a/letta/constants.py b/letta/constants.py index d1a18e37..0b46202a 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -12,6 +12,7 @@ COMPOSIO_ENTITY_ENV_VAR_KEY = "COMPOSIO_ENTITY" COMPOSIO_TOOL_TAG_NAME = "composio" LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base" +LETTA_MULTI_AGENT_TOOL_MODULE_NAME = "letta.functions.function_sets.multi_agent" # String in the error message for when the context window is too large # Example full message: @@ -48,6 +49,10 @@ DEFAULT_PRESET = "memgpt_chat" BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"] # Base memory tools CAN be edited, and are added by default by the server BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] +# Multi agent tools +MULTI_AGENT_TOOLS = ["send_message_to_specific_agent", "send_message_to_agents_matching_all_tags"] +MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3 +MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60 # The name of the tool used to send message to the user # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py new file mode 100644 index 00000000..015ac9c1 --- /dev/null +++ b/letta/functions/function_sets/multi_agent.py @@ -0,0 +1,96 @@ +import asyncio +from typing import TYPE_CHECKING, List, Optional + +from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT +from letta.functions.helpers import async_send_message_with_retries +from letta.orm.errors import NoResultFound +from letta.server.rest_api.utils import get_letta_server + +if TYPE_CHECKING: + from letta.agent import Agent + + +def send_message_to_specific_agent(self: "Agent", message: str, other_agent_id: str) -> Optional[str]: + """ + Send a message to a specific Letta agent within the same organization. + + Args: + message (str): The message to be sent to the target Letta agent. + other_agent_id (str): The identifier of the target Letta agent. + + Returns: + Optional[str]: The response from the Letta agent. It's possible that the agent does not respond. + """ + server = get_letta_server() + + # Ensure the target agent is in the same org + try: + server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=self.user) + except NoResultFound: + raise ValueError( + f"The passed-in agent_id {other_agent_id} either does not exist, " + f"or does not belong to the same org ({self.user.organization_id})." + ) + + # Async logic to send a message with retries and timeout + async def async_send_single_agent(): + return await async_send_message_with_retries( + server=server, + sender_agent=self, + target_agent_id=other_agent_id, + message_text=message, + max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, # or your chosen constants + timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, # e.g., 1200 for 20 min + logging_prefix="[send_message_to_specific_agent]", + ) + + # Run in the current event loop or create one if needed + try: + return asyncio.run(async_send_single_agent()) + except RuntimeError: + # e.g., in case there's already an active loop + loop = asyncio.get_event_loop() + if loop.is_running(): + return loop.run_until_complete(async_send_single_agent()) + else: + raise + + +def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]: + """ + Send a message to all agents in the same organization that match ALL of the given tags. + + Messages are sent in parallel for improved performance, with retries on flaky calls and timeouts for long-running requests. + This function does not use a cursor (pagination) and enforces a limit of 100 agents. + + Args: + message (str): The message to be sent to each matching agent. + tags (List[str]): The list of tags that each agent must have (match_all_tags=True). + + Returns: + List[str]: A list of responses from the agents that match all tags. + Each response corresponds to one agent. + """ + server = get_letta_server() + + # Retrieve agents that match ALL specified tags + matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, cursor=None, limit=100) + + async def send_messages_to_all_agents(): + tasks = [ + async_send_message_with_retries( + server=server, + sender_agent=self, + target_agent_id=agent_state.id, + message_text=message, + max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, + timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, + logging_prefix="[send_message_to_agents_matching_all_tags]", + ) + for agent_state in matching_agents + ] + # Run all tasks in parallel + return await asyncio.gather(*tasks) + + # Run the async function and return results + return asyncio.run(send_messages_to_all_agents()) diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index c03751a2..cbdb5001 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -1,10 +1,15 @@ +import json from typing import Any, Optional, Union import humps from composio.constants import DEFAULT_ENTITY_ID from pydantic import BaseModel -from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY +from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage +from letta.schemas.letta_response import LettaResponse +from letta.schemas.message import MessageCreate def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]: @@ -206,3 +211,102 @@ def generate_import_code(module_attr_map: Optional[dict]): code_lines.append(f" # Access the {attr} from the module") code_lines.append(f" {attr} = getattr({module_name}, '{attr}')") return "\n".join(code_lines) + + +def parse_letta_response_for_assistant_message( + letta_response: LettaResponse, + assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, + assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, +) -> Optional[str]: + reasoning_message = "" + for m in letta_response.messages: + if isinstance(m, AssistantMessage): + return m.assistant_message + elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name: + try: + return json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg] + except Exception: # TODO: Make this more specific + continue + elif isinstance(m, ReasoningMessage): + # This is not ideal, but we would like to return something rather than nothing + reasoning_message += f"{m.reasoning}\n" + + return None + + +import asyncio +from random import uniform +from typing import Optional + + +async def async_send_message_with_retries( + server, + sender_agent: "Agent", + target_agent_id: str, + message_text: str, + max_retries: int, + timeout: int, + logging_prefix: Optional[str] = None, +) -> str: + """ + Shared helper coroutine to send a message to an agent with retries and a timeout. + + Args: + server: The Letta server instance (from get_letta_server()). + sender_agent (Agent): The agent initiating the send action. + target_agent_id (str): The ID of the agent to send the message to. + message_text (str): The text to send as the user message. + max_retries (int): Maximum number of retries for the request. + timeout (int): Maximum time to wait for a response (in seconds). + logging_prefix (str): A prefix to append to logging + Returns: + str: The response or an error message. + """ + logging_prefix = logging_prefix or "[async_send_message_with_retries]" + for attempt in range(1, max_retries + 1): + try: + messages = [MessageCreate(role=MessageRole.user, text=message_text, name=sender_agent.agent_state.name)] + # Wrap in a timeout + response = await asyncio.wait_for( + server.send_message_to_agent( + agent_id=target_agent_id, + actor=sender_agent.user, + messages=messages, + stream_steps=False, + stream_tokens=False, + use_assistant_message=True, + assistant_message_tool_name=DEFAULT_MESSAGE_TOOL, + assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, + ), + timeout=timeout, + ) + + # Extract assistant message + assistant_message = parse_letta_response_for_assistant_message( + response, + assistant_message_tool_name=DEFAULT_MESSAGE_TOOL, + assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, + ) + if assistant_message: + msg = f"Agent {target_agent_id} said '{assistant_message}'" + sender_agent.logger.info(f"{logging_prefix} - {msg}") + return msg + else: + msg = f"(No response from agent {target_agent_id})" + sender_agent.logger.info(f"{logging_prefix} - {msg}") + return msg + except asyncio.TimeoutError: + error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})" + sender_agent.logger.warning(f"{logging_prefix} - {error_msg}") + except Exception as e: + error_msg = f"(Error on attempt {attempt}/{max_retries} for agent {target_agent_id}: {e})" + sender_agent.logger.warning(f"{logging_prefix} - {error_msg}") + + # Exponential backoff before retrying + if attempt < max_retries: + backoff = uniform(0.5, 2) * (2**attempt) + sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent to agent send_message...sleeping for {backoff}") + await asyncio.sleep(backoff) + else: + sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}") + return error_msg diff --git a/letta/local_llm/utils.py b/letta/local_llm/utils.py index b0529c35..f5d54174 100644 --- a/letta/local_llm/utils.py +++ b/letta/local_llm/utils.py @@ -122,6 +122,10 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"): for o in v["enum"]: function_tokens += 3 function_tokens += len(encoding.encode(o)) + elif field == "items": + function_tokens += 2 + if isinstance(v["items"], dict) and "type" in v["items"]: + function_tokens += len(encoding.encode(v["items"]["type"])) else: warnings.warn(f"num_tokens_from_functions: Unsupported field {field} in function {function}") function_tokens += 11 diff --git a/letta/orm/enums.py b/letta/orm/enums.py index 5238098d..aa7f800b 100644 --- a/letta/orm/enums.py +++ b/letta/orm/enums.py @@ -5,6 +5,7 @@ class ToolType(str, Enum): CUSTOM = "custom" LETTA_CORE = "letta_core" LETTA_MEMORY_CORE = "letta_memory_core" + LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core" class JobType(str, Enum): diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index e20c6e48..697734bd 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -115,7 +115,12 @@ class CreateAgent(BaseModel, validate_assignment=True): # initial_message_sequence: Optional[List[MessageCreate]] = Field( None, description="The initial set of messages to put in the agent's in-context memory." ) - include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.") + include_base_tools: bool = Field( + True, description="If true, attaches the Letta core tools (e.g. archival_memory and core_memory related functions)." + ) + include_multi_agent_tools: bool = Field( + False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)." + ) description: Optional[str] = Field(None, description="The description of the agent.") metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") llm: Optional[str] = Field( diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index fe5c4cc9..610685b4 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -2,7 +2,12 @@ from typing import Any, Dict, List, Optional from pydantic import Field, model_validator -from letta.constants import COMPOSIO_TOOL_TAG_NAME, FUNCTION_RETURN_CHAR_LIMIT, LETTA_CORE_TOOL_MODULE_NAME +from letta.constants import ( + COMPOSIO_TOOL_TAG_NAME, + FUNCTION_RETURN_CHAR_LIMIT, + LETTA_CORE_TOOL_MODULE_NAME, + LETTA_MULTI_AGENT_TOOL_MODULE_NAME, +) from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper from letta.functions.schema_generator import generate_schema_from_args_schema_v2 @@ -64,6 +69,9 @@ class Tool(BaseTool): elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}: # If it's letta core tool, we generate the json_schema on the fly here self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name) + elif self.tool_type in {ToolType.LETTA_MULTI_AGENT_CORE}: + # If it's letta multi-agent tool, we also generate the json_schema on the fly here + self.json_schema = get_json_schema_from_module(module_name=LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name=self.name) # Derive name from the JSON schema if not provided if not self.name: diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 3ba29ba9..5088c7c7 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional import numpy as np from sqlalchemy import Select, func, literal, select, union_all -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM, MULTI_AGENT_TOOLS from letta.embeddings import embedding_model from letta.log import get_logger from letta.orm import Agent as AgentModel @@ -88,6 +88,8 @@ class AgentManager: tool_names = [] if agent_create.include_base_tools: tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS) + if agent_create.include_multi_agent_tools: + tool_names.extend(MULTI_AGENT_TOOLS) if agent_create.tools: tool_names.extend(agent_create.tools) # Remove duplicates diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index f2672d96..d2192329 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -2,7 +2,7 @@ import importlib import warnings from typing import List, Optional -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS from letta.functions.functions import derive_openai_json_schema, load_function_set from letta.orm.enums import ToolType @@ -133,39 +133,42 @@ class ToolManager: @enforce_types def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: - """Add default tools in base.py""" - module_name = "base" - full_module_name = f"letta.functions.function_sets.{module_name}" - try: - module = importlib.import_module(full_module_name) - except Exception as e: - # Handle other general exceptions - raise e + """Add default tools in base.py and multi_agent.py""" + functions_to_schema = {} + module_names = ["base", "multi_agent"] - functions_to_schema = [] - try: - # Load the function set - functions_to_schema = load_function_set(module) - except ValueError as e: - err = f"Error loading function set '{module_name}': {e}" - warnings.warn(err) + for module_name in module_names: + full_module_name = f"letta.functions.function_sets.{module_name}" + try: + module = importlib.import_module(full_module_name) + except Exception as e: + # Handle other general exceptions + raise e + + try: + # Load the function set + functions_to_schema.update(load_function_set(module)) + except ValueError as e: + err = f"Error loading function set '{module_name}': {e}" + warnings.warn(err) # create tool in db tools = [] for name, schema in functions_to_schema.items(): - if name in BASE_TOOLS + BASE_MEMORY_TOOLS: - tags = [module_name] - if module_name == "base": - tags.append("letta-base") - - # BASE_MEMORY_TOOLS should be executed in an e2b sandbox - # so they should NOT be letta_core tools, instead, treated as custom tools + if name in BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS: if name in BASE_TOOLS: tool_type = ToolType.LETTA_CORE + tags = [tool_type.value] elif name in BASE_MEMORY_TOOLS: tool_type = ToolType.LETTA_MEMORY_CORE + tags = [tool_type.value] + elif name in MULTI_AGENT_TOOLS: + tool_type = ToolType.LETTA_MULTI_AGENT_CORE + tags = [tool_type.value] else: - raise ValueError(f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS}") + raise ValueError( + f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS}" + ) # create to tool tools.append( @@ -180,4 +183,6 @@ class ToolManager: ) ) + # TODO: Delete any base tools that are stale + return tools diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 5b5bec6f..258083d8 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -1,8 +1,13 @@ +import json +import secrets +import string + import pytest import letta.functions.function_sets.base as base_functions from letta import LocalClient, create_client from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.llm_config import LLMConfig @@ -18,7 +23,7 @@ def client(): @pytest.fixture(scope="module") def agent_obj(client: LocalClient): """Create a test agent that we can call functions on""" - agent_state = client.create_agent() + agent_state = client.create_agent(include_multi_agent_tools=True) agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) yield agent_obj @@ -26,6 +31,17 @@ def agent_obj(client: LocalClient): client.delete_agent(agent_obj.agent_state.id) +@pytest.fixture(scope="module") +def other_agent_obj(client: LocalClient): + """Create another test agent that we can call functions on""" + agent_state = client.create_agent(include_multi_agent_tools=False) + + other_agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) + yield other_agent_obj + + client.delete_agent(other_agent_obj.agent_state.id) + + def query_in_search_results(search_results, query): for result in search_results: if query.lower() in result["content"].lower(): @@ -97,3 +113,98 @@ def test_recall(client, agent_obj): # Conversation search result = base_functions.conversation_search(agent_obj, "banana") assert keyword in result + + +def test_send_message_to_agent(client, agent_obj, other_agent_obj): + long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(10)) + + # Encourage the agent to send a message to the other agent_obj with the secret string + client.send_message( + agent_id=agent_obj.agent_state.id, + role="user", + message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} with the secret password={long_random_string}", + ) + + # Conversation search the other agent + result = base_functions.conversation_search(other_agent_obj, long_random_string) + assert long_random_string in result + + # Search the sender agent for the response from another agent + in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user) + found = False + target_snippet = f"Agent {other_agent_obj.agent_state.id} said " + + for m in in_context_messages: + if target_snippet in m.text: + found = True + break + + print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.text for m in in_context_messages[1:]])}") + if not found: + pytest.fail(f"Was not able to find an instance of the target snippet: {target_snippet}") + + # Test that the agent can still receive messages fine + response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?") + print(response.messages) + + +def test_send_message_to_agents_with_tags(client): + worker_tags = ["worker", "user-456"] + + # Clean up first from possibly failed tests + prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True) + for agent in prev_worker_agents: + client.delete_agent(agent.id) + + long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(10)) + + # Create "manager" agent + manager_agent_state = client.create_agent(include_multi_agent_tools=True) + manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) + + # Create 3 worker agents + worker_agents = [] + worker_tags = ["worker", "user-123"] + for _ in range(3): + worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags) + worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) + worker_agents.append(worker_agent) + + # Create 2 worker agents that belong to a different user (These should NOT get the message) + worker_agents = [] + worker_tags = ["worker", "user-456"] + for _ in range(3): + worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags) + worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) + worker_agents.append(worker_agent) + + # Encourage the manager to send a message to the other agent_obj with the secret string + response = client.send_message( + agent_id=manager_agent.agent_state.id, + role="user", + message=f"Send a message to all agents with tags {worker_tags} informing them of the secret password={long_random_string}", + ) + + for m in response.messages: + if isinstance(m, ToolReturnMessage): + tool_response = eval(json.loads(m.tool_return)["message"]) + print(f"\n\nManager agent tool response: \n{tool_response}\n\n") + assert len(tool_response) == len(worker_agents) + + # We can break after this, the ToolReturnMessage after is not related + break + + # Conversation search the worker agents + # TODO: This search if flaky for some reason + # for agent in worker_agents: + # result = base_functions.conversation_search(agent, long_random_string) + # assert long_random_string in result + + # Test that the agent can still receive messages fine + response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") + print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages])) + + # Clean up agents + client.delete_agent(manager_agent_state.id) + for agent in worker_agents: + client.delete_agent(agent.agent_state.id) diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 202adf17..51fe6d54 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -11,7 +11,7 @@ from sqlalchemy import delete from letta import create_client from letta.client.client import LocalClient, RESTClient -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET, MULTI_AGENT_TOOLS from letta.orm import FileMetadata, Source from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig @@ -339,7 +339,7 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): def test_list_tools(client: Union[LocalClient, RESTClient]): tools = client.upsert_base_tools() tool_names = [t.name for t in tools] - expected = BASE_TOOLS + BASE_MEMORY_TOOLS + expected = BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS assert sorted(tool_names) == sorted(expected) diff --git a/tests/test_managers.py b/tests/test_managers.py index 5d9f9aa2..efe736f6 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -7,7 +7,7 @@ from sqlalchemy import delete from sqlalchemy.exc import IntegrityError from letta.config import LettaConfig -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS from letta.embeddings import embedding_model from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.orm import ( @@ -1716,7 +1716,7 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user): def test_upsert_base_tools(server: SyncServer, default_user): tools = server.tool_manager.upsert_base_tools(actor=default_user) - expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS) + expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS) assert sorted([t.name for t in tools]) == expected_tool_names # Call it again to make sure it doesn't create duplicates @@ -1727,8 +1727,12 @@ def test_upsert_base_tools(server: SyncServer, default_user): for t in tools: if t.name in BASE_TOOLS: assert t.tool_type == ToolType.LETTA_CORE - else: + elif t.name in BASE_MEMORY_TOOLS: assert t.tool_type == ToolType.LETTA_MEMORY_CORE + elif t.name in MULTI_AGENT_TOOLS: + assert t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE + else: + pytest.fail(f"The tool name is unrecognized as a base tool: {t.name}") assert t.source_code is None assert t.json_schema