diff --git a/letta/agent.py b/letta/agent.py index 6d47b2eb..97a7c90f 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -3,28 +3,21 @@ import time import traceback import warnings from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from openai.types.beta.function_tool import FunctionTool as OpenAITool from letta.constants import ( CLI_WARNING_PREFIX, - COMPOSIO_ENTITY_ENV_VAR_KEY, ERROR_MESSAGE_PREFIX, FIRST_MESSAGE_ATTEMPTS, FUNC_FAILED_HEARTBEAT_MESSAGE, - LETTA_CORE_TOOL_MODULE_NAME, - LETTA_MULTI_AGENT_TOOL_MODULE_NAME, LLM_MAX_TOKENS, REQ_HEARTBEAT_MESSAGE, ) from letta.errors import ContextWindowExceededError -from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source -from letta.functions.functions import get_function_from_module -from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name from letta.functions.mcp_client.base_client import BaseMCPClient from letta.helpers import ToolRulesSolver -from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads from letta.interface import AgentInterface @@ -35,7 +28,6 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes from letta.log import get_logger from letta.memory import summarize_messages from letta.orm import User -from letta.orm.enums import ToolType from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig @@ -46,8 +38,6 @@ from letta.schemas.message import Message, ToolReturn from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage from letta.schemas.openai.chat_completion_response import UsageStatistics -from letta.schemas.sandbox_config import SandboxRunResult -from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics from letta.services.agent_manager import AgentManager @@ -58,7 +48,7 @@ from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.provider_manager import ProviderManager from letta.services.step_manager import StepManager -from letta.services.tool_execution_sandbox import ToolExecutionSandbox +from letta.services.tool_executor.tool_executor import ToolExecutionManager from letta.services.tool_manager import ToolManager from letta.settings import summarizer_settings from letta.streaming_interface import StreamingRefreshCLIInterface @@ -210,107 +200,6 @@ class Agent(BaseAgent): return True return False - # TODO: Refactor into separate class v.s. large if/elses here - def execute_tool_and_persist_state( - self, function_name: str, function_args: dict, target_letta_tool: Tool - ) -> tuple[Any, Optional[SandboxRunResult]]: - """ - Execute tool modifications and persist the state of the agent. - Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data - """ - # TODO: add agent manager here - orig_memory_str = self.agent_state.memory.compile() - - # TODO: need to have an AgentState object that actually has full access to the block data - # this is because the sandbox tools need to be able to access block.value to edit this data - try: - if target_letta_tool.tool_type == ToolType.LETTA_CORE: - # base tools are allowed to access the `Agent` object and run on the database - callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) - function_args["self"] = self # need to attach self to arg since it's dynamically linked - function_response = callable_func(**function_args) - elif target_letta_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE: - callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name) - function_args["self"] = self # need to attach self to arg since it's dynamically linked - function_response = callable_func(**function_args) - elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE: - callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) - agent_state_copy = self.agent_state.__deepcopy__() - function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked - function_response = callable_func(**function_args) - self.update_memory_if_changed(agent_state_copy.memory) - elif target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO: - action_name = generate_composio_action_from_func_name(target_letta_tool.name) - # Get entity ID from the agent_state - entity_id = None - for env_var in self.agent_state.tool_exec_environment_variables: - if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY: - entity_id = env_var.value - # Get composio_api_key - composio_api_key = get_composio_api_key(actor=self.user, logger=self.logger) - function_response = execute_composio_action( - action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id - ) - elif target_letta_tool.tool_type == ToolType.EXTERNAL_MCP: - # Get the server name from the tool tag - # TODO make a property instead? - server_name = target_letta_tool.tags[0].split(":")[1] - - # Get the MCPClient from the server's handle - # TODO these don't get raised properly - if not self.mcp_clients: - raise ValueError(f"No MCP client available to use") - if server_name not in self.mcp_clients: - raise ValueError(f"Unknown MCP server name: {server_name}") - mcp_client = self.mcp_clients[server_name] - if not isinstance(mcp_client, BaseMCPClient): - raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}") - - # Check that tool exists - available_tools = mcp_client.list_tools() - available_tool_names = [t.name for t in available_tools] - if function_name not in available_tool_names: - raise ValueError( - f"{function_name} is not available in MCP server {server_name}. Please check your `~/.letta/mcp_config.json` file." - ) - - function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args) - sandbox_run_result = SandboxRunResult(status="error" if is_error else "success") - return function_response, sandbox_run_result - else: - try: - # Parse the source code to extract function annotations - annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name) - # Coerce the function arguments to the correct types based on the annotations - function_args = coerce_dict_args_by_annotations(function_args, annotations) - except ValueError as e: - self.logger.debug(f"Error coercing function arguments: {e}") - - # execute tool in a sandbox - # TODO: allow agent_state to specify which sandbox to execute tools in - # TODO: This is only temporary, can remove after we publish a pip package with this object - agent_state_copy = self.agent_state.__deepcopy__() - agent_state_copy.tools = [] - agent_state_copy.tool_rules = [] - - sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run( - agent_state=agent_state_copy - ) - function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state - assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" - if updated_agent_state is not None: - self.update_memory_if_changed(updated_agent_state.memory) - return function_response, sandbox_run_result - except Exception as e: - # Need to catch error here, or else trunction wont happen - # TODO: modify to function execution error - function_response = get_friendly_error_msg( - function_name=function_name, exception_name=type(e).__name__, exception_message=str(e) - ) - return function_response, SandboxRunResult(status="error") - - return function_response, None - def _handle_function_error_response( self, error_msg: str, @@ -613,7 +502,12 @@ class Agent(BaseAgent): }, ) - function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) + # TODO: Make this at the __init__ level + # TODO: Add refresh agent_state logic to ToolExecutionManager, either by passing in or retreiving from db + tool_execution_manager = ToolExecutionManager(self) + function_response, sandbox_run_result = tool_execution_manager.execute_tool( + function_name=function_name, function_args=function_args, tool=target_letta_tool + ) log_event( "tool_call_ended", diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 0ec88c96..89521c8c 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -14,7 +14,8 @@ def send_message(self: "Agent", message: str) -> Optional[str]: Optional[str]: None is always returned as this function does not produce a response. """ # FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference - self.interface.assistant_message(message) # , msg_obj=self._messages[-1]) + if self.interface: + self.interface.assistant_message(message) # , msg_obj=self._messages[-1]) return None diff --git a/letta/services/tool_executor/__init__.py b/letta/services/tool_executor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py new file mode 100644 index 00000000..aba556c6 --- /dev/null +++ b/letta/services/tool_executor/tool_executor.py @@ -0,0 +1,245 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Tuple, Type + +from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, LETTA_CORE_TOOL_MODULE_NAME, LETTA_MULTI_AGENT_TOOL_MODULE_NAME +from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source +from letta.functions.functions import get_function_from_module +from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name +from letta.functions.mcp_client.base_client import BaseMCPClient +from letta.helpers.composio_helpers import get_composio_api_key +from letta.orm.enums import ToolType +from letta.schemas.sandbox_config import SandboxRunResult +from letta.schemas.tool import Tool +from letta.services.tool_execution_sandbox import ToolExecutionSandbox +from letta.utils import get_friendly_error_msg + + +class ToolExecutor(ABC): + """Abstract base class for tool executors.""" + + @abstractmethod + def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + """Execute the tool and return the result.""" + + +class LettaCoreToolExecutor(ToolExecutor): + """Executor for LETTA core tools.""" + + def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) + function_args["self"] = agent # need to attach self to arg since it's dynamically linked + function_response = callable_func(**function_args) + return function_response, None + + +class LettaMultiAgentToolExecutor(ToolExecutor): + """Executor for LETTA multi-agent core tools.""" + + def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name) + function_args["self"] = agent # need to attach self to arg since it's dynamically linked + function_response = callable_func(**function_args) + return function_response, None + + +class LettaMemoryToolExecutor(ToolExecutor): + """Executor for LETTA memory core tools.""" + + def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) + agent_state_copy = agent.agent_state.__deepcopy__() + function_args["agent_state"] = agent_state_copy + function_response = callable_func(**function_args) + agent.update_memory_if_changed(agent_state_copy.memory) + return function_response, None + + +class ExternalComposioToolExecutor(ToolExecutor): + """Executor for external Composio tools.""" + + def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + action_name = generate_composio_action_from_func_name(tool.name) + + # Get entity ID from the agent_state + entity_id = self._get_entity_id(agent) + + # Get composio_api_key + composio_api_key = get_composio_api_key(actor=agent.user, logger=agent.logger) + + # TODO (matt): Roll in execute_composio_action into this class + function_response = execute_composio_action( + action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id + ) + + return function_response, None + + def _get_entity_id(self, agent: "Agent") -> Optional[str]: + """Extract the entity ID from environment variables.""" + for env_var in agent.agent_state.tool_exec_environment_variables: + if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY: + return env_var.value + return None + + +class ExternalMCPToolExecutor(ToolExecutor): + """Executor for external MCP tools.""" + + def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + # Get the server name from the tool tag + server_name = self._extract_server_name(tool) + + # Get the MCPClient + mcp_client = self._get_mcp_client(agent, server_name) + + # Validate tool exists + self._validate_tool_exists(mcp_client, function_name, server_name) + + # Execute the tool + function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args) + + sandbox_run_result = SandboxRunResult(status="error" if is_error else "success") + return function_response, sandbox_run_result + + def _extract_server_name(self, tool: Tool) -> str: + """Extract server name from tool tags.""" + return tool.tags[0].split(":")[1] + + def _get_mcp_client(self, agent: "Agent", server_name: str): + """Get the MCP client for the given server name.""" + if not agent.mcp_clients: + raise ValueError("No MCP client available to use") + + if server_name not in agent.mcp_clients: + raise ValueError(f"Unknown MCP server name: {server_name}") + + mcp_client = agent.mcp_clients[server_name] + if not isinstance(mcp_client, BaseMCPClient): + raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}") + + return mcp_client + + def _validate_tool_exists(self, mcp_client, function_name: str, server_name: str): + """Validate that the tool exists in the MCP server.""" + available_tools = mcp_client.list_tools() + available_tool_names = [t.name for t in available_tools] + + if function_name not in available_tool_names: + raise ValueError( + f"{function_name} is not available in MCP server {server_name}. " f"Please check your `~/.letta/mcp_config.json` file." + ) + + +class SandboxToolExecutor(ToolExecutor): + """Executor for sandboxed tools.""" + + def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + # Store original memory state + orig_memory_str = agent.agent_state.memory.compile() + + try: + # Prepare function arguments + function_args = self._prepare_function_args(function_args, tool, function_name) + + # Create agent state copy for sandbox + agent_state_copy = self._create_agent_state_copy(agent) + + # Execute in sandbox + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, agent.user, tool_object=tool).run( + agent_state=agent_state_copy + ) + + function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state + + # Verify memory integrity + assert orig_memory_str == agent.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" + + # Update agent memory if needed + if updated_agent_state is not None: + agent.update_memory_if_changed(updated_agent_state.memory) + + return function_response, sandbox_run_result + + except Exception as e: + return self._handle_execution_error(e, function_name) + + def _prepare_function_args(self, function_args: dict, tool: Tool, function_name: str) -> dict: + """Prepare function arguments with proper type coercion.""" + try: + # Parse the source code to extract function annotations + annotations = get_function_annotations_from_source(tool.source_code, function_name) + # Coerce the function arguments to the correct types based on the annotations + return coerce_dict_args_by_annotations(function_args, annotations) + except ValueError: + # Just log the error and continue with original args + # This is defensive programming - we try to coerce but fall back if it fails + return function_args + + def _create_agent_state_copy(self, agent: "Agent"): + """Create a copy of agent state for sandbox execution.""" + agent_state_copy = agent.agent_state.__deepcopy__() + # Remove tools from copy to prevent nested tool execution + agent_state_copy.tools = [] + agent_state_copy.tool_rules = [] + return agent_state_copy + + def _handle_execution_error(self, exception: Exception, function_name: str) -> Tuple[str, SandboxRunResult]: + """Handle tool execution errors.""" + error_message = get_friendly_error_msg( + function_name=function_name, exception_name=type(exception).__name__, exception_message=str(exception) + ) + return error_message, SandboxRunResult(status="error") + + +class ToolExecutorFactory: + """Factory for creating appropriate tool executors based on tool type.""" + + _executor_map: Dict[ToolType, Type[ToolExecutor]] = { + ToolType.LETTA_CORE: LettaCoreToolExecutor, + ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor, + ToolType.LETTA_MEMORY_CORE: LettaMemoryToolExecutor, + ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor, + ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor, + } + + @classmethod + def get_executor(cls, tool_type: ToolType) -> ToolExecutor: + """Get the appropriate executor for the given tool type.""" + executor_class = cls._executor_map.get(tool_type) + + if executor_class: + return executor_class() + + # Default to sandbox executor for unknown types + return SandboxToolExecutor() + + +class ToolExecutionManager: + """Manager class for tool execution operations.""" + + def __init__(self, agent: "Agent"): + self.agent = agent + self.logger = agent.logger + + def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + """ + Execute a tool and persist any state changes. + + Args: + function_name: Name of the function to execute + function_args: Arguments to pass to the function + tool: Tool object containing metadata about the tool + + Returns: + Tuple containing the function response and sandbox run result (if applicable) + """ + try: + # Get the appropriate executor for this tool type + executor = ToolExecutorFactory.get_executor(tool.tool_type) + + # Execute the tool + return executor.execute(function_name, function_args, self.agent, tool) + + except Exception as e: + self.logger.error(f"Error executing tool {function_name}: {str(e)}") + error_message = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)) + return error_message, SandboxRunResult(status="error") diff --git a/tests/integration_test_composio.py b/tests/integration_test_composio.py index 98e30373..a1898df5 100644 --- a/tests/integration_test_composio.py +++ b/tests/integration_test_composio.py @@ -10,6 +10,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.tool import ToolCreate from letta.server.rest_api.app import app from letta.server.server import SyncServer +from letta.services.tool_executor.tool_executor import ToolExecutionManager logger = get_logger(__name__) @@ -67,8 +68,12 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_ actor=default_user, ) agent = server.load_agent(agent_state.id, actor=default_user) - response = agent.execute_tool_and_persist_state(composio_gmail_get_profile_tool.name, {}, composio_gmail_get_profile_tool) - assert response[0]["response_data"]["emailAddress"] == "sarah@letta.com" + + function_response, sandbox_run_result = ToolExecutionManager(agent).execute_tool( + function_name=composio_gmail_get_profile_tool.name, function_args={}, tool=composio_gmail_get_profile_tool + ) + + assert function_response["response_data"]["emailAddress"] == "sarah@letta.com" # Add agent variable changing the entity ID agent_state = server.agent_manager.update_agent( @@ -77,5 +82,7 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_ actor=default_user, ) agent = server.load_agent(agent_state.id, actor=default_user) - response = agent.execute_tool_and_persist_state(composio_gmail_get_profile_tool.name, {}, composio_gmail_get_profile_tool) - assert response[0]["response_data"]["emailAddress"] == "matt@letta.com" + function_response, sandbox_run_result = ToolExecutionManager(agent).execute_tool( + function_name=composio_gmail_get_profile_tool.name, function_args={}, tool=composio_gmail_get_profile_tool + ) + assert function_response["response_data"]["emailAddress"] == "matt@letta.com"