feat: Native agent to agent messaging (#668)
This commit is contained in:
@@ -12,6 +12,7 @@ from letta.constants import (
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
LETTA_CORE_TOOL_MODULE_NAME,
|
||||
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
|
||||
LLM_MAX_TOKENS,
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
@@ -25,6 +26,7 @@ from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
from letta.memory import summarize_messages
|
||||
from letta.orm import User
|
||||
from letta.orm.enums import ToolType
|
||||
@@ -143,6 +145,9 @@ class Agent(BaseAgent):
|
||||
# Load last function response from message history
|
||||
self.last_function_response = self.load_last_function_response()
|
||||
|
||||
# Logger that the Agent specifically can use, will also report the agent_state ID with the logs
|
||||
self.logger = get_logger(agent_state.id)
|
||||
|
||||
def load_last_function_response(self):
|
||||
"""Load the last function response from message history"""
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
@@ -207,6 +212,10 @@ class Agent(BaseAgent):
|
||||
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
function_response = callable_func(**function_args)
|
||||
elif target_letta_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE:
|
||||
callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
function_response = callable_func(**function_args)
|
||||
elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE:
|
||||
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
|
||||
agent_state_copy = self.agent_state.__deepcopy__()
|
||||
|
||||
@@ -2251,6 +2251,7 @@ class LocalClient(AbstractClient):
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
include_base_tools: Optional[bool] = True,
|
||||
include_multi_agent_tools: bool = False,
|
||||
# metadata
|
||||
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
|
||||
description: Optional[str] = None,
|
||||
@@ -2268,6 +2269,7 @@ class LocalClient(AbstractClient):
|
||||
tools (List[str]): List of tools
|
||||
tool_rules (Optional[List[BaseToolRule]]): List of tool rules
|
||||
include_base_tools (bool): Include base tools
|
||||
include_multi_agent_tools (bool): Include multi agent tools
|
||||
metadata (Dict): Metadata
|
||||
description (str): Description
|
||||
tags (List[str]): Tags for filtering agents
|
||||
@@ -2277,11 +2279,6 @@ class LocalClient(AbstractClient):
|
||||
"""
|
||||
# construct list of tools
|
||||
tool_ids = tool_ids or []
|
||||
tool_names = []
|
||||
if include_base_tools:
|
||||
tool_names += BASE_TOOLS
|
||||
tool_names += BASE_MEMORY_TOOLS
|
||||
tool_ids += [self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user).id for name in tool_names]
|
||||
|
||||
# check if default configs are provided
|
||||
assert embedding_config or self._default_embedding_config, f"Embedding config must be provided"
|
||||
@@ -2304,6 +2301,7 @@ class LocalClient(AbstractClient):
|
||||
"tool_ids": tool_ids,
|
||||
"tool_rules": tool_rules,
|
||||
"include_base_tools": include_base_tools,
|
||||
"include_multi_agent_tools": include_multi_agent_tools,
|
||||
"system": system,
|
||||
"agent_type": agent_type,
|
||||
"llm_config": llm_config if llm_config else self._default_llm_config,
|
||||
|
||||
@@ -12,6 +12,7 @@ COMPOSIO_ENTITY_ENV_VAR_KEY = "COMPOSIO_ENTITY"
|
||||
COMPOSIO_TOOL_TAG_NAME = "composio"
|
||||
|
||||
LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base"
|
||||
LETTA_MULTI_AGENT_TOOL_MODULE_NAME = "letta.functions.function_sets.multi_agent"
|
||||
|
||||
# String in the error message for when the context window is too large
|
||||
# Example full message:
|
||||
@@ -48,6 +49,10 @@ DEFAULT_PRESET = "memgpt_chat"
|
||||
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
|
||||
# Base memory tools CAN be edited, and are added by default by the server
|
||||
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
# Multi agent tools
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_specific_agent", "send_message_to_agents_matching_all_tags"]
|
||||
MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3
|
||||
MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60
|
||||
|
||||
# The name of the tool used to send message to the user
|
||||
# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...)
|
||||
|
||||
96
letta/functions/function_sets/multi_agent.py
Normal file
96
letta/functions/function_sets/multi_agent.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT
|
||||
from letta.functions.helpers import async_send_message_with_retries
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agent import Agent
|
||||
|
||||
|
||||
def send_message_to_specific_agent(self: "Agent", message: str, other_agent_id: str) -> Optional[str]:
|
||||
"""
|
||||
Send a message to a specific Letta agent within the same organization.
|
||||
|
||||
Args:
|
||||
message (str): The message to be sent to the target Letta agent.
|
||||
other_agent_id (str): The identifier of the target Letta agent.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The response from the Letta agent. It's possible that the agent does not respond.
|
||||
"""
|
||||
server = get_letta_server()
|
||||
|
||||
# Ensure the target agent is in the same org
|
||||
try:
|
||||
server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=self.user)
|
||||
except NoResultFound:
|
||||
raise ValueError(
|
||||
f"The passed-in agent_id {other_agent_id} either does not exist, "
|
||||
f"or does not belong to the same org ({self.user.organization_id})."
|
||||
)
|
||||
|
||||
# Async logic to send a message with retries and timeout
|
||||
async def async_send_single_agent():
|
||||
return await async_send_message_with_retries(
|
||||
server=server,
|
||||
sender_agent=self,
|
||||
target_agent_id=other_agent_id,
|
||||
message_text=message,
|
||||
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, # or your chosen constants
|
||||
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, # e.g., 1200 for 20 min
|
||||
logging_prefix="[send_message_to_specific_agent]",
|
||||
)
|
||||
|
||||
# Run in the current event loop or create one if needed
|
||||
try:
|
||||
return asyncio.run(async_send_single_agent())
|
||||
except RuntimeError:
|
||||
# e.g., in case there's already an active loop
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
return loop.run_until_complete(async_send_single_agent())
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]:
|
||||
"""
|
||||
Send a message to all agents in the same organization that match ALL of the given tags.
|
||||
|
||||
Messages are sent in parallel for improved performance, with retries on flaky calls and timeouts for long-running requests.
|
||||
This function does not use a cursor (pagination) and enforces a limit of 100 agents.
|
||||
|
||||
Args:
|
||||
message (str): The message to be sent to each matching agent.
|
||||
tags (List[str]): The list of tags that each agent must have (match_all_tags=True).
|
||||
|
||||
Returns:
|
||||
List[str]: A list of responses from the agents that match all tags.
|
||||
Each response corresponds to one agent.
|
||||
"""
|
||||
server = get_letta_server()
|
||||
|
||||
# Retrieve agents that match ALL specified tags
|
||||
matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, cursor=None, limit=100)
|
||||
|
||||
async def send_messages_to_all_agents():
|
||||
tasks = [
|
||||
async_send_message_with_retries(
|
||||
server=server,
|
||||
sender_agent=self,
|
||||
target_agent_id=agent_state.id,
|
||||
message_text=message,
|
||||
max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES,
|
||||
timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT,
|
||||
logging_prefix="[send_message_to_agents_matching_all_tags]",
|
||||
)
|
||||
for agent_state in matching_agents
|
||||
]
|
||||
# Run all tasks in parallel
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Run the async function and return results
|
||||
return asyncio.run(send_messages_to_all_agents())
|
||||
@@ -1,10 +1,15 @@
|
||||
import json
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import humps
|
||||
from composio.constants import DEFAULT_ENTITY_ID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import MessageCreate
|
||||
|
||||
|
||||
def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
|
||||
@@ -206,3 +211,102 @@ def generate_import_code(module_attr_map: Optional[dict]):
|
||||
code_lines.append(f" # Access the {attr} from the module")
|
||||
code_lines.append(f" {attr} = getattr({module_name}, '{attr}')")
|
||||
return "\n".join(code_lines)
|
||||
|
||||
|
||||
def parse_letta_response_for_assistant_message(
|
||||
letta_response: LettaResponse,
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
) -> Optional[str]:
|
||||
reasoning_message = ""
|
||||
for m in letta_response.messages:
|
||||
if isinstance(m, AssistantMessage):
|
||||
return m.assistant_message
|
||||
elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
|
||||
try:
|
||||
return json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]
|
||||
except Exception: # TODO: Make this more specific
|
||||
continue
|
||||
elif isinstance(m, ReasoningMessage):
|
||||
# This is not ideal, but we would like to return something rather than nothing
|
||||
reasoning_message += f"{m.reasoning}\n"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
import asyncio
|
||||
from random import uniform
|
||||
from typing import Optional
|
||||
|
||||
|
||||
async def async_send_message_with_retries(
|
||||
server,
|
||||
sender_agent: "Agent",
|
||||
target_agent_id: str,
|
||||
message_text: str,
|
||||
max_retries: int,
|
||||
timeout: int,
|
||||
logging_prefix: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Shared helper coroutine to send a message to an agent with retries and a timeout.
|
||||
|
||||
Args:
|
||||
server: The Letta server instance (from get_letta_server()).
|
||||
sender_agent (Agent): The agent initiating the send action.
|
||||
target_agent_id (str): The ID of the agent to send the message to.
|
||||
message_text (str): The text to send as the user message.
|
||||
max_retries (int): Maximum number of retries for the request.
|
||||
timeout (int): Maximum time to wait for a response (in seconds).
|
||||
logging_prefix (str): A prefix to append to logging
|
||||
Returns:
|
||||
str: The response or an error message.
|
||||
"""
|
||||
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
messages = [MessageCreate(role=MessageRole.user, text=message_text, name=sender_agent.agent_state.name)]
|
||||
# Wrap in a timeout
|
||||
response = await asyncio.wait_for(
|
||||
server.send_message_to_agent(
|
||||
agent_id=target_agent_id,
|
||||
actor=sender_agent.user,
|
||||
messages=messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
use_assistant_message=True,
|
||||
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract assistant message
|
||||
assistant_message = parse_letta_response_for_assistant_message(
|
||||
response,
|
||||
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
)
|
||||
if assistant_message:
|
||||
msg = f"Agent {target_agent_id} said '{assistant_message}'"
|
||||
sender_agent.logger.info(f"{logging_prefix} - {msg}")
|
||||
return msg
|
||||
else:
|
||||
msg = f"(No response from agent {target_agent_id})"
|
||||
sender_agent.logger.info(f"{logging_prefix} - {msg}")
|
||||
return msg
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})"
|
||||
sender_agent.logger.warning(f"{logging_prefix} - {error_msg}")
|
||||
except Exception as e:
|
||||
error_msg = f"(Error on attempt {attempt}/{max_retries} for agent {target_agent_id}: {e})"
|
||||
sender_agent.logger.warning(f"{logging_prefix} - {error_msg}")
|
||||
|
||||
# Exponential backoff before retrying
|
||||
if attempt < max_retries:
|
||||
backoff = uniform(0.5, 2) * (2**attempt)
|
||||
sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent to agent send_message...sleeping for {backoff}")
|
||||
await asyncio.sleep(backoff)
|
||||
else:
|
||||
sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}")
|
||||
return error_msg
|
||||
|
||||
@@ -122,6 +122,10 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
|
||||
for o in v["enum"]:
|
||||
function_tokens += 3
|
||||
function_tokens += len(encoding.encode(o))
|
||||
elif field == "items":
|
||||
function_tokens += 2
|
||||
if isinstance(v["items"], dict) and "type" in v["items"]:
|
||||
function_tokens += len(encoding.encode(v["items"]["type"]))
|
||||
else:
|
||||
warnings.warn(f"num_tokens_from_functions: Unsupported field {field} in function {function}")
|
||||
function_tokens += 11
|
||||
|
||||
@@ -5,6 +5,7 @@ class ToolType(str, Enum):
|
||||
CUSTOM = "custom"
|
||||
LETTA_CORE = "letta_core"
|
||||
LETTA_MEMORY_CORE = "letta_memory_core"
|
||||
LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core"
|
||||
|
||||
|
||||
class JobType(str, Enum):
|
||||
|
||||
@@ -115,7 +115,12 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
initial_message_sequence: Optional[List[MessageCreate]] = Field(
|
||||
None, description="The initial set of messages to put in the agent's in-context memory."
|
||||
)
|
||||
include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.")
|
||||
include_base_tools: bool = Field(
|
||||
True, description="If true, attaches the Letta core tools (e.g. archival_memory and core_memory related functions)."
|
||||
)
|
||||
include_multi_agent_tools: bool = Field(
|
||||
False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)."
|
||||
)
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
||||
llm: Optional[str] = Field(
|
||||
|
||||
@@ -2,7 +2,12 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from letta.constants import COMPOSIO_TOOL_TAG_NAME, FUNCTION_RETURN_CHAR_LIMIT, LETTA_CORE_TOOL_MODULE_NAME
|
||||
from letta.constants import (
|
||||
COMPOSIO_TOOL_TAG_NAME,
|
||||
FUNCTION_RETURN_CHAR_LIMIT,
|
||||
LETTA_CORE_TOOL_MODULE_NAME,
|
||||
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
|
||||
)
|
||||
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
|
||||
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
|
||||
@@ -64,6 +69,9 @@ class Tool(BaseTool):
|
||||
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}:
|
||||
# If it's letta core tool, we generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
|
||||
elif self.tool_type in {ToolType.LETTA_MULTI_AGENT_CORE}:
|
||||
# If it's letta multi-agent tool, we also generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name=self.name)
|
||||
|
||||
# Derive name from the JSON schema if not provided
|
||||
if not self.name:
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
|
||||
import numpy as np
|
||||
from sqlalchemy import Select, func, literal, select, union_all
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM, MULTI_AGENT_TOOLS
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
@@ -88,6 +88,8 @@ class AgentManager:
|
||||
tool_names = []
|
||||
if agent_create.include_base_tools:
|
||||
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
|
||||
if agent_create.include_multi_agent_tools:
|
||||
tool_names.extend(MULTI_AGENT_TOOLS)
|
||||
if agent_create.tools:
|
||||
tool_names.extend(agent_create.tools)
|
||||
# Remove duplicates
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS
|
||||
from letta.functions.functions import derive_openai_json_schema, load_function_set
|
||||
from letta.orm.enums import ToolType
|
||||
|
||||
@@ -133,39 +133,42 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""Add default tools in base.py"""
|
||||
module_name = "base"
|
||||
full_module_name = f"letta.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
raise e
|
||||
"""Add default tools in base.py and multi_agent.py"""
|
||||
functions_to_schema = {}
|
||||
module_names = ["base", "multi_agent"]
|
||||
|
||||
functions_to_schema = []
|
||||
try:
|
||||
# Load the function set
|
||||
functions_to_schema = load_function_set(module)
|
||||
except ValueError as e:
|
||||
err = f"Error loading function set '{module_name}': {e}"
|
||||
warnings.warn(err)
|
||||
for module_name in module_names:
|
||||
full_module_name = f"letta.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
raise e
|
||||
|
||||
try:
|
||||
# Load the function set
|
||||
functions_to_schema.update(load_function_set(module))
|
||||
except ValueError as e:
|
||||
err = f"Error loading function set '{module_name}': {e}"
|
||||
warnings.warn(err)
|
||||
|
||||
# create tool in db
|
||||
tools = []
|
||||
for name, schema in functions_to_schema.items():
|
||||
if name in BASE_TOOLS + BASE_MEMORY_TOOLS:
|
||||
tags = [module_name]
|
||||
if module_name == "base":
|
||||
tags.append("letta-base")
|
||||
|
||||
# BASE_MEMORY_TOOLS should be executed in an e2b sandbox
|
||||
# so they should NOT be letta_core tools, instead, treated as custom tools
|
||||
if name in BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS:
|
||||
if name in BASE_TOOLS:
|
||||
tool_type = ToolType.LETTA_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BASE_MEMORY_TOOLS:
|
||||
tool_type = ToolType.LETTA_MEMORY_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in MULTI_AGENT_TOOLS:
|
||||
tool_type = ToolType.LETTA_MULTI_AGENT_CORE
|
||||
tags = [tool_type.value]
|
||||
else:
|
||||
raise ValueError(f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS}")
|
||||
raise ValueError(
|
||||
f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS}"
|
||||
)
|
||||
|
||||
# create to tool
|
||||
tools.append(
|
||||
@@ -180,4 +183,6 @@ class ToolManager:
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: Delete any base tools that are stale
|
||||
|
||||
return tools
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
|
||||
import pytest
|
||||
|
||||
import letta.functions.function_sets.base as base_functions
|
||||
from letta import LocalClient, create_client
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
@@ -18,7 +23,7 @@ def client():
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_obj(client: LocalClient):
|
||||
"""Create a test agent that we can call functions on"""
|
||||
agent_state = client.create_agent()
|
||||
agent_state = client.create_agent(include_multi_agent_tools=True)
|
||||
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
yield agent_obj
|
||||
@@ -26,6 +31,17 @@ def agent_obj(client: LocalClient):
|
||||
client.delete_agent(agent_obj.agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def other_agent_obj(client: LocalClient):
|
||||
"""Create another test agent that we can call functions on"""
|
||||
agent_state = client.create_agent(include_multi_agent_tools=False)
|
||||
|
||||
other_agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
yield other_agent_obj
|
||||
|
||||
client.delete_agent(other_agent_obj.agent_state.id)
|
||||
|
||||
|
||||
def query_in_search_results(search_results, query):
|
||||
for result in search_results:
|
||||
if query.lower() in result["content"].lower():
|
||||
@@ -97,3 +113,98 @@ def test_recall(client, agent_obj):
|
||||
# Conversation search
|
||||
result = base_functions.conversation_search(agent_obj, "banana")
|
||||
assert keyword in result
|
||||
|
||||
|
||||
def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(10))
|
||||
|
||||
# Encourage the agent to send a message to the other agent_obj with the secret string
|
||||
client.send_message(
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
role="user",
|
||||
message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} with the secret password={long_random_string}",
|
||||
)
|
||||
|
||||
# Conversation search the other agent
|
||||
result = base_functions.conversation_search(other_agent_obj, long_random_string)
|
||||
assert long_random_string in result
|
||||
|
||||
# Search the sender agent for the response from another agent
|
||||
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
|
||||
found = False
|
||||
target_snippet = f"Agent {other_agent_obj.agent_state.id} said "
|
||||
|
||||
for m in in_context_messages:
|
||||
if target_snippet in m.text:
|
||||
found = True
|
||||
break
|
||||
|
||||
print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.text for m in in_context_messages[1:]])}")
|
||||
if not found:
|
||||
pytest.fail(f"Was not able to find an instance of the target snippet: {target_snippet}")
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?")
|
||||
print(response.messages)
|
||||
|
||||
|
||||
def test_send_message_to_agents_with_tags(client):
|
||||
worker_tags = ["worker", "user-456"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(10))
|
||||
|
||||
# Create "manager" agent
|
||||
manager_agent_state = client.create_agent(include_multi_agent_tools=True)
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 worker agents
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-123"]
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Create 2 worker agents that belong to a different user (These should NOT get the message)
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-456"]
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
response = client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=f"Send a message to all agents with tags {worker_tags} informing them of the secret password={long_random_string}",
|
||||
)
|
||||
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolReturnMessage):
|
||||
tool_response = eval(json.loads(m.tool_return)["message"])
|
||||
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
|
||||
assert len(tool_response) == len(worker_agents)
|
||||
|
||||
# We can break after this, the ToolReturnMessage after is not related
|
||||
break
|
||||
|
||||
# Conversation search the worker agents
|
||||
# TODO: This search if flaky for some reason
|
||||
# for agent in worker_agents:
|
||||
# result = base_functions.conversation_search(agent, long_random_string)
|
||||
# assert long_random_string in result
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
|
||||
|
||||
# Clean up agents
|
||||
client.delete_agent(manager_agent_state.id)
|
||||
for agent in worker_agents:
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy import delete
|
||||
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient, RESTClient
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET, MULTI_AGENT_TOOLS
|
||||
from letta.orm import FileMetadata, Source
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -339,7 +339,7 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
|
||||
def test_list_tools(client: Union[LocalClient, RESTClient]):
|
||||
tools = client.upsert_base_tools()
|
||||
tool_names = [t.name for t in tools]
|
||||
expected = BASE_TOOLS + BASE_MEMORY_TOOLS
|
||||
expected = BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS
|
||||
assert sorted(tool_names) == sorted(expected)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import delete
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.orm import (
|
||||
@@ -1716,7 +1716,7 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
|
||||
def test_upsert_base_tools(server: SyncServer, default_user):
|
||||
tools = server.tool_manager.upsert_base_tools(actor=default_user)
|
||||
expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS)
|
||||
expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS)
|
||||
assert sorted([t.name for t in tools]) == expected_tool_names
|
||||
|
||||
# Call it again to make sure it doesn't create duplicates
|
||||
@@ -1727,8 +1727,12 @@ def test_upsert_base_tools(server: SyncServer, default_user):
|
||||
for t in tools:
|
||||
if t.name in BASE_TOOLS:
|
||||
assert t.tool_type == ToolType.LETTA_CORE
|
||||
else:
|
||||
elif t.name in BASE_MEMORY_TOOLS:
|
||||
assert t.tool_type == ToolType.LETTA_MEMORY_CORE
|
||||
elif t.name in MULTI_AGENT_TOOLS:
|
||||
assert t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE
|
||||
else:
|
||||
pytest.fail(f"The tool name is unrecognized as a base tool: {t.name}")
|
||||
assert t.source_code is None
|
||||
assert t.json_schema
|
||||
|
||||
|
||||
Reference in New Issue
Block a user