feat: Tool executors (#1370)

This commit is contained in:
Matthew Zhou
2025-03-23 15:56:54 -07:00
committed by GitHub
parent 08218112d1
commit 10803b52cd
5 changed files with 266 additions and 119 deletions

View File

@@ -3,28 +3,21 @@ import time
import traceback
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from openai.types.beta.function_tool import FunctionTool as OpenAITool
from letta.constants import (
CLI_WARNING_PREFIX,
COMPOSIO_ENTITY_ENV_VAR_KEY,
ERROR_MESSAGE_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
LETTA_CORE_TOOL_MODULE_NAME,
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
LLM_MAX_TOKENS,
REQ_HEARTBEAT_MESSAGE,
)
from letta.errors import ContextWindowExceededError
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.functions import get_function_from_module
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.functions.mcp_client.base_client import BaseMCPClient
from letta.helpers import ToolRulesSolver
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.interface import AgentInterface
@@ -35,7 +28,6 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes
from letta.log import get_logger
from letta.memory import summarize_messages
from letta.orm import User
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
@@ -46,8 +38,6 @@ from letta.schemas.message import Message, ToolReturn
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
@@ -58,7 +48,7 @@ from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.step_manager import StepManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_executor.tool_executor import ToolExecutionManager
from letta.services.tool_manager import ToolManager
from letta.settings import summarizer_settings
from letta.streaming_interface import StreamingRefreshCLIInterface
@@ -210,107 +200,6 @@ class Agent(BaseAgent):
return True
return False
# TODO: Refactor into separate class v.s. large if/elses here
def execute_tool_and_persist_state(
self, function_name: str, function_args: dict, target_letta_tool: Tool
) -> tuple[Any, Optional[SandboxRunResult]]:
"""
Execute tool modifications and persist the state of the agent.
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
"""
# TODO: add agent manager here
orig_memory_str = self.agent_state.memory.compile()
# TODO: need to have an AgentState object that actually has full access to the block data
# this is because the sandbox tools need to be able to access block.value to edit this data
try:
if target_letta_tool.tool_type == ToolType.LETTA_CORE:
# base tools are allowed to access the `Agent` object and run on the database
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
elif target_letta_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE:
callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE:
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
agent_state_copy = self.agent_state.__deepcopy__()
function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
self.update_memory_if_changed(agent_state_copy.memory)
elif target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
action_name = generate_composio_action_from_func_name(target_letta_tool.name)
# Get entity ID from the agent_state
entity_id = None
for env_var in self.agent_state.tool_exec_environment_variables:
if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
entity_id = env_var.value
# Get composio_api_key
composio_api_key = get_composio_api_key(actor=self.user, logger=self.logger)
function_response = execute_composio_action(
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
)
elif target_letta_tool.tool_type == ToolType.EXTERNAL_MCP:
# Get the server name from the tool tag
# TODO make a property instead?
server_name = target_letta_tool.tags[0].split(":")[1]
# Get the MCPClient from the server's handle
# TODO these don't get raised properly
if not self.mcp_clients:
raise ValueError(f"No MCP client available to use")
if server_name not in self.mcp_clients:
raise ValueError(f"Unknown MCP server name: {server_name}")
mcp_client = self.mcp_clients[server_name]
if not isinstance(mcp_client, BaseMCPClient):
raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}")
# Check that tool exists
available_tools = mcp_client.list_tools()
available_tool_names = [t.name for t in available_tools]
if function_name not in available_tool_names:
raise ValueError(
f"{function_name} is not available in MCP server {server_name}. Please check your `~/.letta/mcp_config.json` file."
)
function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
return function_response, sandbox_run_result
else:
try:
# Parse the source code to extract function annotations
annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
# Coerce the function arguments to the correct types based on the annotations
function_args = coerce_dict_args_by_annotations(function_args, annotations)
except ValueError as e:
self.logger.debug(f"Error coercing function arguments: {e}")
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
# TODO: This is only temporary, can remove after we publish a pip package with this object
agent_state_copy = self.agent_state.__deepcopy__()
agent_state_copy.tools = []
agent_state_copy.tool_rules = []
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
agent_state=agent_state_copy
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
if updated_agent_state is not None:
self.update_memory_if_changed(updated_agent_state.memory)
return function_response, sandbox_run_result
except Exception as e:
# Need to catch error here, or else trunction wont happen
# TODO: modify to function execution error
function_response = get_friendly_error_msg(
function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
)
return function_response, SandboxRunResult(status="error")
return function_response, None
def _handle_function_error_response(
self,
error_msg: str,
@@ -613,7 +502,12 @@ class Agent(BaseAgent):
},
)
function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
# TODO: Make this at the __init__ level
# TODO: Add refresh agent_state logic to ToolExecutionManager, either by passing in or retreiving from db
tool_execution_manager = ToolExecutionManager(self)
function_response, sandbox_run_result = tool_execution_manager.execute_tool(
function_name=function_name, function_args=function_args, tool=target_letta_tool
)
log_event(
"tool_call_ended",

View File

@@ -14,7 +14,8 @@ def send_message(self: "Agent", message: str) -> Optional[str]:
Optional[str]: None is always returned as this function does not produce a response.
"""
# FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference
self.interface.assistant_message(message) # , msg_obj=self._messages[-1])
if self.interface:
self.interface.assistant_message(message) # , msg_obj=self._messages[-1])
return None

View File

View File

@@ -0,0 +1,245 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, LETTA_CORE_TOOL_MODULE_NAME, LETTA_MULTI_AGENT_TOOL_MODULE_NAME
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.functions import get_function_from_module
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.functions.mcp_client.base_client import BaseMCPClient
from letta.helpers.composio_helpers import get_composio_api_key
from letta.orm.enums import ToolType
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.tool import Tool
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.utils import get_friendly_error_msg
class ToolExecutor(ABC):
"""Abstract base class for tool executors."""
@abstractmethod
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
"""Execute the tool and return the result."""
class LettaCoreToolExecutor(ToolExecutor):
"""Executor for LETTA core tools."""
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
function_args["self"] = agent # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
return function_response, None
class LettaMultiAgentToolExecutor(ToolExecutor):
"""Executor for LETTA multi-agent core tools."""
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
function_args["self"] = agent # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
return function_response, None
class LettaMemoryToolExecutor(ToolExecutor):
"""Executor for LETTA memory core tools."""
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
agent_state_copy = agent.agent_state.__deepcopy__()
function_args["agent_state"] = agent_state_copy
function_response = callable_func(**function_args)
agent.update_memory_if_changed(agent_state_copy.memory)
return function_response, None
class ExternalComposioToolExecutor(ToolExecutor):
"""Executor for external Composio tools."""
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
action_name = generate_composio_action_from_func_name(tool.name)
# Get entity ID from the agent_state
entity_id = self._get_entity_id(agent)
# Get composio_api_key
composio_api_key = get_composio_api_key(actor=agent.user, logger=agent.logger)
# TODO (matt): Roll in execute_composio_action into this class
function_response = execute_composio_action(
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
)
return function_response, None
def _get_entity_id(self, agent: "Agent") -> Optional[str]:
"""Extract the entity ID from environment variables."""
for env_var in agent.agent_state.tool_exec_environment_variables:
if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
return env_var.value
return None
class ExternalMCPToolExecutor(ToolExecutor):
"""Executor for external MCP tools."""
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
# Get the server name from the tool tag
server_name = self._extract_server_name(tool)
# Get the MCPClient
mcp_client = self._get_mcp_client(agent, server_name)
# Validate tool exists
self._validate_tool_exists(mcp_client, function_name, server_name)
# Execute the tool
function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
return function_response, sandbox_run_result
def _extract_server_name(self, tool: Tool) -> str:
"""Extract server name from tool tags."""
return tool.tags[0].split(":")[1]
def _get_mcp_client(self, agent: "Agent", server_name: str):
"""Get the MCP client for the given server name."""
if not agent.mcp_clients:
raise ValueError("No MCP client available to use")
if server_name not in agent.mcp_clients:
raise ValueError(f"Unknown MCP server name: {server_name}")
mcp_client = agent.mcp_clients[server_name]
if not isinstance(mcp_client, BaseMCPClient):
raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}")
return mcp_client
def _validate_tool_exists(self, mcp_client, function_name: str, server_name: str):
"""Validate that the tool exists in the MCP server."""
available_tools = mcp_client.list_tools()
available_tool_names = [t.name for t in available_tools]
if function_name not in available_tool_names:
raise ValueError(
f"{function_name} is not available in MCP server {server_name}. " f"Please check your `~/.letta/mcp_config.json` file."
)
class SandboxToolExecutor(ToolExecutor):
"""Executor for sandboxed tools."""
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
# Store original memory state
orig_memory_str = agent.agent_state.memory.compile()
try:
# Prepare function arguments
function_args = self._prepare_function_args(function_args, tool, function_name)
# Create agent state copy for sandbox
agent_state_copy = self._create_agent_state_copy(agent)
# Execute in sandbox
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, agent.user, tool_object=tool).run(
agent_state=agent_state_copy
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
# Verify memory integrity
assert orig_memory_str == agent.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
# Update agent memory if needed
if updated_agent_state is not None:
agent.update_memory_if_changed(updated_agent_state.memory)
return function_response, sandbox_run_result
except Exception as e:
return self._handle_execution_error(e, function_name)
def _prepare_function_args(self, function_args: dict, tool: Tool, function_name: str) -> dict:
"""Prepare function arguments with proper type coercion."""
try:
# Parse the source code to extract function annotations
annotations = get_function_annotations_from_source(tool.source_code, function_name)
# Coerce the function arguments to the correct types based on the annotations
return coerce_dict_args_by_annotations(function_args, annotations)
except ValueError:
# Just log the error and continue with original args
# This is defensive programming - we try to coerce but fall back if it fails
return function_args
def _create_agent_state_copy(self, agent: "Agent"):
"""Create a copy of agent state for sandbox execution."""
agent_state_copy = agent.agent_state.__deepcopy__()
# Remove tools from copy to prevent nested tool execution
agent_state_copy.tools = []
agent_state_copy.tool_rules = []
return agent_state_copy
def _handle_execution_error(self, exception: Exception, function_name: str) -> Tuple[str, SandboxRunResult]:
"""Handle tool execution errors."""
error_message = get_friendly_error_msg(
function_name=function_name, exception_name=type(exception).__name__, exception_message=str(exception)
)
return error_message, SandboxRunResult(status="error")
class ToolExecutorFactory:
"""Factory for creating appropriate tool executors based on tool type."""
_executor_map: Dict[ToolType, Type[ToolExecutor]] = {
ToolType.LETTA_CORE: LettaCoreToolExecutor,
ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor,
ToolType.LETTA_MEMORY_CORE: LettaMemoryToolExecutor,
ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor,
ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor,
}
@classmethod
def get_executor(cls, tool_type: ToolType) -> ToolExecutor:
"""Get the appropriate executor for the given tool type."""
executor_class = cls._executor_map.get(tool_type)
if executor_class:
return executor_class()
# Default to sandbox executor for unknown types
return SandboxToolExecutor()
class ToolExecutionManager:
"""Manager class for tool execution operations."""
def __init__(self, agent: "Agent"):
self.agent = agent
self.logger = agent.logger
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
"""
Execute a tool and persist any state changes.
Args:
function_name: Name of the function to execute
function_args: Arguments to pass to the function
tool: Tool object containing metadata about the tool
Returns:
Tuple containing the function response and sandbox run result (if applicable)
"""
try:
# Get the appropriate executor for this tool type
executor = ToolExecutorFactory.get_executor(tool.tool_type)
# Execute the tool
return executor.execute(function_name, function_args, self.agent, tool)
except Exception as e:
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
error_message = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
return error_message, SandboxRunResult(status="error")

View File

@@ -10,6 +10,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool import ToolCreate
from letta.server.rest_api.app import app
from letta.server.server import SyncServer
from letta.services.tool_executor.tool_executor import ToolExecutionManager
logger = get_logger(__name__)
@@ -67,8 +68,12 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_
actor=default_user,
)
agent = server.load_agent(agent_state.id, actor=default_user)
response = agent.execute_tool_and_persist_state(composio_gmail_get_profile_tool.name, {}, composio_gmail_get_profile_tool)
assert response[0]["response_data"]["emailAddress"] == "sarah@letta.com"
function_response, sandbox_run_result = ToolExecutionManager(agent).execute_tool(
function_name=composio_gmail_get_profile_tool.name, function_args={}, tool=composio_gmail_get_profile_tool
)
assert function_response["response_data"]["emailAddress"] == "sarah@letta.com"
# Add agent variable changing the entity ID
agent_state = server.agent_manager.update_agent(
@@ -77,5 +82,7 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_
actor=default_user,
)
agent = server.load_agent(agent_state.id, actor=default_user)
response = agent.execute_tool_and_persist_state(composio_gmail_get_profile_tool.name, {}, composio_gmail_get_profile_tool)
assert response[0]["response_data"]["emailAddress"] == "matt@letta.com"
function_response, sandbox_run_result = ToolExecutionManager(agent).execute_tool(
function_name=composio_gmail_get_profile_tool.name, function_args={}, tool=composio_gmail_get_profile_tool
)
assert function_response["response_data"]["emailAddress"] == "matt@letta.com"