diff --git a/letta/agent.py b/letta/agent.py index 97a7c90f..60462fac 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -3,21 +3,28 @@ import time import traceback import warnings from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from openai.types.beta.function_tool import FunctionTool as OpenAITool from letta.constants import ( CLI_WARNING_PREFIX, + COMPOSIO_ENTITY_ENV_VAR_KEY, ERROR_MESSAGE_PREFIX, FIRST_MESSAGE_ATTEMPTS, FUNC_FAILED_HEARTBEAT_MESSAGE, + LETTA_CORE_TOOL_MODULE_NAME, + LETTA_MULTI_AGENT_TOOL_MODULE_NAME, LLM_MAX_TOKENS, REQ_HEARTBEAT_MESSAGE, ) from letta.errors import ContextWindowExceededError +from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source +from letta.functions.functions import get_function_from_module +from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name from letta.functions.mcp_client.base_client import BaseMCPClient from letta.helpers import ToolRulesSolver +from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads from letta.interface import AgentInterface @@ -28,6 +35,7 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes from letta.log import get_logger from letta.memory import summarize_messages from letta.orm import User +from letta.orm.enums import ToolType from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig @@ -38,6 +46,8 @@ from letta.schemas.message import Message, ToolReturn from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.sandbox_config import SandboxRunResult +from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics from letta.services.agent_manager import AgentManager @@ -48,7 +58,7 @@ from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.provider_manager import ProviderManager from letta.services.step_manager import StepManager -from letta.services.tool_executor.tool_executor import ToolExecutionManager +from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager from letta.settings import summarizer_settings from letta.streaming_interface import StreamingRefreshCLIInterface @@ -502,12 +512,7 @@ class Agent(BaseAgent): }, ) - # TODO: Make this at the __init__ level - # TODO: Add refresh agent_state logic to ToolExecutionManager, either by passing in or retreiving from db - tool_execution_manager = ToolExecutionManager(self) - function_response, sandbox_run_result = tool_execution_manager.execute_tool( - function_name=function_name, function_args=function_args, tool=target_letta_tool - ) + function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) log_event( "tool_call_ended", @@ -1199,6 +1204,107 @@ class Agent(BaseAgent): context_window_breakdown = self.get_context_window() return context_window_breakdown.context_window_size_current + # TODO: Refactor into separate class v.s. large if/elses here + def execute_tool_and_persist_state( + self, function_name: str, function_args: dict, target_letta_tool: Tool + ) -> tuple[Any, Optional[SandboxRunResult]]: + """ + Execute tool modifications and persist the state of the agent. + Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data + """ + # TODO: add agent manager here + orig_memory_str = self.agent_state.memory.compile() + + # TODO: need to have an AgentState object that actually has full access to the block data + # this is because the sandbox tools need to be able to access block.value to edit this data + try: + if target_letta_tool.tool_type == ToolType.LETTA_CORE: + # base tools are allowed to access the `Agent` object and run on the database + callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) + function_args["self"] = self # need to attach self to arg since it's dynamically linked + function_response = callable_func(**function_args) + elif target_letta_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE: + callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name) + function_args["self"] = self # need to attach self to arg since it's dynamically linked + function_response = callable_func(**function_args) + elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE: + callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) + agent_state_copy = self.agent_state.__deepcopy__() + function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked + function_response = callable_func(**function_args) + self.update_memory_if_changed(agent_state_copy.memory) + elif target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO: + action_name = generate_composio_action_from_func_name(target_letta_tool.name) + # Get entity ID from the agent_state + entity_id = None + for env_var in self.agent_state.tool_exec_environment_variables: + if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY: + entity_id = env_var.value + # Get composio_api_key + composio_api_key = get_composio_api_key(actor=self.user, logger=self.logger) + function_response = execute_composio_action( + action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id + ) + elif target_letta_tool.tool_type == ToolType.EXTERNAL_MCP: + # Get the server name from the tool tag + # TODO make a property instead? + server_name = target_letta_tool.tags[0].split(":")[1] + + # Get the MCPClient from the server's handle + # TODO these don't get raised properly + if not self.mcp_clients: + raise ValueError(f"No MCP client available to use") + if server_name not in self.mcp_clients: + raise ValueError(f"Unknown MCP server name: {server_name}") + mcp_client = self.mcp_clients[server_name] + if not isinstance(mcp_client, BaseMCPClient): + raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}") + + # Check that tool exists + available_tools = mcp_client.list_tools() + available_tool_names = [t.name for t in available_tools] + if function_name not in available_tool_names: + raise ValueError( + f"{function_name} is not available in MCP server {server_name}. Please check your `~/.letta/mcp_config.json` file." + ) + + function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args) + sandbox_run_result = SandboxRunResult(status="error" if is_error else "success") + return function_response, sandbox_run_result + else: + try: + # Parse the source code to extract function annotations + annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name) + # Coerce the function arguments to the correct types based on the annotations + function_args = coerce_dict_args_by_annotations(function_args, annotations) + except ValueError as e: + self.logger.debug(f"Error coercing function arguments: {e}") + + # execute tool in a sandbox + # TODO: allow agent_state to specify which sandbox to execute tools in + # TODO: This is only temporary, can remove after we publish a pip package with this object + agent_state_copy = self.agent_state.__deepcopy__() + agent_state_copy.tools = [] + agent_state_copy.tool_rules = [] + + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run( + agent_state=agent_state_copy + ) + function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state + assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" + if updated_agent_state is not None: + self.update_memory_if_changed(updated_agent_state.memory) + return function_response, sandbox_run_result + except Exception as e: + # Need to catch error here, or else trunction wont happen + # TODO: modify to function execution error + function_response = get_friendly_error_msg( + function_name=function_name, exception_name=type(e).__name__, exception_message=str(e) + ) + return function_response, SandboxRunResult(status="error") + + return function_response, None + def save_agent(agent: Agent): """Save agent to metadata store""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 9e12531b..e83fa848 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -744,6 +744,8 @@ class AgentManager: Update internal memory object and system prompt if there have been modifications. Args: + actor: + agent_id: new_memory (Memory): the new memory object to compare to the current memory object Returns: diff --git a/letta/services/tool_executor/tool_execution_manager.py b/letta/services/tool_executor/tool_execution_manager.py new file mode 100644 index 00000000..4f435be3 --- /dev/null +++ b/letta/services/tool_executor/tool_execution_manager.py @@ -0,0 +1,74 @@ +from typing import Any, Dict, Optional, Tuple, Type + +from letta.log import get_logger +from letta.orm.enums import ToolType +from letta.schemas.agent import AgentState +from letta.schemas.sandbox_config import SandboxRunResult +from letta.schemas.tool import Tool +from letta.schemas.user import User +from letta.services.tool_executor.tool_executor import ( + ExternalComposioToolExecutor, + ExternalMCPToolExecutor, + LettaCoreToolExecutor, + LettaMemoryToolExecutor, + LettaMultiAgentToolExecutor, + SandboxToolExecutor, + ToolExecutor, +) +from letta.utils import get_friendly_error_msg + + +class ToolExecutorFactory: + """Factory for creating appropriate tool executors based on tool type.""" + + _executor_map: Dict[ToolType, Type[ToolExecutor]] = { + ToolType.LETTA_CORE: LettaCoreToolExecutor, + ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor, + ToolType.LETTA_MEMORY_CORE: LettaMemoryToolExecutor, + ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor, + ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor, + } + + @classmethod + def get_executor(cls, tool_type: ToolType) -> ToolExecutor: + """Get the appropriate executor for the given tool type.""" + executor_class = cls._executor_map.get(tool_type) + + if executor_class: + return executor_class() + + # Default to sandbox executor for unknown types + return SandboxToolExecutor() + + +class ToolExecutionManager: + """Manager class for tool execution operations.""" + + def __init__(self, agent_state: AgentState, actor: User): + self.agent_state = agent_state + self.logger = get_logger(__name__) + self.actor = actor + + def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + """ + Execute a tool and persist any state changes. + + Args: + function_name: Name of the function to execute + function_args: Arguments to pass to the function + tool: Tool object containing metadata about the tool + + Returns: + Tuple containing the function response and sandbox run result (if applicable) + """ + try: + # Get the appropriate executor for this tool type + executor = ToolExecutorFactory.get_executor(tool.tool_type) + + # Execute the tool + return executor.execute(function_name, function_args, self.agent_state, tool, self.actor) + + except Exception as e: + self.logger.error(f"Error executing tool {function_name}: {str(e)}") + error_message = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)) + return error_message, SandboxRunResult(status="error") diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index aba556c6..95a3ef69 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -1,15 +1,19 @@ +import math from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Optional, Tuple -from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, LETTA_CORE_TOOL_MODULE_NAME, LETTA_MULTI_AGENT_TOOL_MODULE_NAME +from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source -from letta.functions.functions import get_function_from_module from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name -from letta.functions.mcp_client.base_client import BaseMCPClient from letta.helpers.composio_helpers import get_composio_api_key -from letta.orm.enums import ToolType +from letta.helpers.json_helpers import json_dumps +from letta.schemas.agent import AgentState from letta.schemas.sandbox_config import SandboxRunResult from letta.schemas.tool import Tool +from letta.schemas.user import User +from letta.services.agent_manager import AgentManager +from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.utils import get_friendly_error_msg @@ -18,53 +22,234 @@ class ToolExecutor(ABC): """Abstract base class for tool executors.""" @abstractmethod - def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + def execute( + self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User + ) -> Tuple[Any, Optional[SandboxRunResult]]: """Execute the tool and return the result.""" class LettaCoreToolExecutor(ToolExecutor): - """Executor for LETTA core tools.""" + """Executor for LETTA core tools with direct implementation of functions.""" - def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: - callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) - function_args["self"] = agent # need to attach self to arg since it's dynamically linked - function_response = callable_func(**function_args) + def execute( + self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User + ) -> Tuple[Any, Optional[SandboxRunResult]]: + # Map function names to method calls + function_map = { + "send_message": self.send_message, + "conversation_search": self.conversation_search, + "archival_memory_search": self.archival_memory_search, + } + + if function_name not in function_map: + raise ValueError(f"Unknown function: {function_name}") + + # Execute the appropriate function + function_args_copy = function_args.copy() # Make a copy to avoid modifying the original + function_response = function_map[function_name](agent_state, actor, **function_args_copy) return function_response, None + def send_message(self, agent_state: AgentState, actor: User, message: str) -> Optional[str]: + """ + Sends a message to the human user. + + Args: + message (str): Message contents. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + return None + + def conversation_search(self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0) -> Optional[str]: + """ + Search prior conversation history using case-insensitive string matching. + + Args: + query (str): String to search for. + page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). + + Returns: + str: Query result string + """ + if page is None or (isinstance(page, str) and page.lower().strip() == "none"): + page = 0 + try: + page = int(page) + except: + raise ValueError(f"'page' argument must be an integer") + + count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + messages = MessageManager().list_user_messages_for_agent( + agent_id=agent_state.id, + actor=actor, + query_text=query, + limit=count, + ) + + total = len(messages) + num_pages = math.ceil(total / count) - 1 # 0 index + + if len(messages) == 0: + results_str = f"No results found." + else: + results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):" + results_formatted = [message.content[0].text for message in messages] + results_str = f"{results_pref} {json_dumps(results_formatted)}" + + return results_str + + def archival_memory_search( + self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0, start: Optional[int] = 0 + ) -> Optional[str]: + """ + Search archival memory using semantic (embedding-based) search. + + Args: + query (str): String to search for. + page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). + start (Optional[int]): Starting index for the search results. Defaults to 0. + + Returns: + str: Query result string + """ + if page is None or (isinstance(page, str) and page.lower().strip() == "none"): + page = 0 + try: + page = int(page) + except: + raise ValueError(f"'page' argument must be an integer") + + count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + + try: + # Get results using passage manager + all_results = AgentManager().list_passages( + actor=actor, + agent_id=agent_state.id, + query_text=query, + limit=count + start, # Request enough results to handle offset + embedding_config=agent_state.embedding_config, + embed_query=True, + ) + + # Apply pagination + end = min(count + start, len(all_results)) + paged_results = all_results[start:end] + + # Format results to match previous implementation + formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results] + + return formatted_results, len(formatted_results) + + except Exception as e: + raise e + + def archival_memory_insert(self, agent_state: AgentState, actor: User, content: str) -> Optional[str]: + """ + Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. + + Args: + content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + PassageManager().insert_passage( + agent_state=agent_state, + agent_id=agent_state.id, + text=content, + actor=actor, + ) + AgentManager().rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True) + return None + class LettaMultiAgentToolExecutor(ToolExecutor): """Executor for LETTA multi-agent core tools.""" - def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: - callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name) - function_args["self"] = agent # need to attach self to arg since it's dynamically linked - function_response = callable_func(**function_args) - return function_response, None + # TODO: Implement + # def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[ + # Any, Optional[SandboxRunResult]]: + # callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name) + # function_args["self"] = agent # need to attach self to arg since it's dynamically linked + # function_response = callable_func(**function_args) + # return function_response, None class LettaMemoryToolExecutor(ToolExecutor): - """Executor for LETTA memory core tools.""" + """Executor for LETTA memory core tools with direct implementation.""" + + def execute( + self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User + ) -> Tuple[Any, Optional[SandboxRunResult]]: + # Map function names to method calls + function_map = { + "core_memory_append": self.core_memory_append, + "core_memory_replace": self.core_memory_replace, + } + + if function_name not in function_map: + raise ValueError(f"Unknown function: {function_name}") + + # Execute the appropriate function with the copied state + function_args_copy = function_args.copy() # Make a copy to avoid modifying the original + function_response = function_map[function_name](agent_state, **function_args_copy) + + # Update memory if changed + AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor) - def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: - callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) - agent_state_copy = agent.agent_state.__deepcopy__() - function_args["agent_state"] = agent_state_copy - function_response = callable_func(**function_args) - agent.update_memory_if_changed(agent_state_copy.memory) return function_response, None + def core_memory_append(self, agent_state: "AgentState", label: str, content: str) -> Optional[str]: + """ + Append to the contents of core memory. + + Args: + label (str): Section of the memory to be edited (persona or human). + content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + current_value = str(agent_state.memory.get_block(label).value) + new_value = current_value + "\n" + str(content) + agent_state.memory.update_block_value(label=label, value=new_value) + return None + + def core_memory_replace(self, agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]: + """ + Replace the contents of core memory. To delete memories, use an empty string for new_content. + + Args: + label (str): Section of the memory to be edited (persona or human). + old_content (str): String to replace. Must be an exact match. + new_content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + current_value = str(agent_state.memory.get_block(label).value) + if old_content not in current_value: + raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'") + new_value = current_value.replace(str(old_content), str(new_content)) + agent_state.memory.update_block_value(label=label, value=new_value) + return None + class ExternalComposioToolExecutor(ToolExecutor): """Executor for external Composio tools.""" - def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + def execute( + self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User + ) -> Tuple[Any, Optional[SandboxRunResult]]: action_name = generate_composio_action_from_func_name(tool.name) # Get entity ID from the agent_state - entity_id = self._get_entity_id(agent) + entity_id = self._get_entity_id(agent_state) # Get composio_api_key - composio_api_key = get_composio_api_key(actor=agent.user, logger=agent.logger) + composio_api_key = get_composio_api_key(actor=actor) # TODO (matt): Roll in execute_composio_action into this class function_response = execute_composio_action( @@ -73,9 +258,9 @@ class ExternalComposioToolExecutor(ToolExecutor): return function_response, None - def _get_entity_id(self, agent: "Agent") -> Optional[str]: + def _get_entity_id(self, agent_state: AgentState) -> Optional[str]: """Extract the entity ID from environment variables.""" - for env_var in agent.agent_state.tool_exec_environment_variables: + for env_var in agent_state.tool_exec_environment_variables: if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY: return env_var.value return None @@ -84,78 +269,83 @@ class ExternalComposioToolExecutor(ToolExecutor): class ExternalMCPToolExecutor(ToolExecutor): """Executor for external MCP tools.""" - def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: - # Get the server name from the tool tag - server_name = self._extract_server_name(tool) - - # Get the MCPClient - mcp_client = self._get_mcp_client(agent, server_name) - - # Validate tool exists - self._validate_tool_exists(mcp_client, function_name, server_name) - - # Execute the tool - function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args) - - sandbox_run_result = SandboxRunResult(status="error" if is_error else "success") - return function_response, sandbox_run_result - - def _extract_server_name(self, tool: Tool) -> str: - """Extract server name from tool tags.""" - return tool.tags[0].split(":")[1] - - def _get_mcp_client(self, agent: "Agent", server_name: str): - """Get the MCP client for the given server name.""" - if not agent.mcp_clients: - raise ValueError("No MCP client available to use") - - if server_name not in agent.mcp_clients: - raise ValueError(f"Unknown MCP server name: {server_name}") - - mcp_client = agent.mcp_clients[server_name] - if not isinstance(mcp_client, BaseMCPClient): - raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}") - - return mcp_client - - def _validate_tool_exists(self, mcp_client, function_name: str, server_name: str): - """Validate that the tool exists in the MCP server.""" - available_tools = mcp_client.list_tools() - available_tool_names = [t.name for t in available_tools] - - if function_name not in available_tool_names: - raise ValueError( - f"{function_name} is not available in MCP server {server_name}. " f"Please check your `~/.letta/mcp_config.json` file." - ) + # TODO: Implement + # + # def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> Tuple[ + # Any, Optional[SandboxRunResult]]: + # # Get the server name from the tool tag + # server_name = self._extract_server_name(tool) + # + # # Get the MCPClient + # mcp_client = self._get_mcp_client(agent, server_name) + # + # # Validate tool exists + # self._validate_tool_exists(mcp_client, function_name, server_name) + # + # # Execute the tool + # function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args) + # + # sandbox_run_result = SandboxRunResult(status="error" if is_error else "success") + # return function_response, sandbox_run_result + # + # def _extract_server_name(self, tool: Tool) -> str: + # """Extract server name from tool tags.""" + # return tool.tags[0].split(":")[1] + # + # def _get_mcp_client(self, agent: "Agent", server_name: str): + # """Get the MCP client for the given server name.""" + # if not agent.mcp_clients: + # raise ValueError("No MCP client available to use") + # + # if server_name not in agent.mcp_clients: + # raise ValueError(f"Unknown MCP server name: {server_name}") + # + # mcp_client = agent.mcp_clients[server_name] + # if not isinstance(mcp_client, BaseMCPClient): + # raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}") + # + # return mcp_client + # + # def _validate_tool_exists(self, mcp_client, function_name: str, server_name: str): + # """Validate that the tool exists in the MCP server.""" + # available_tools = mcp_client.list_tools() + # available_tool_names = [t.name for t in available_tools] + # + # if function_name not in available_tool_names: + # raise ValueError( + # f"{function_name} is not available in MCP server {server_name}. " f"Please check your `~/.letta/mcp_config.json` file." + # ) class SandboxToolExecutor(ToolExecutor): """Executor for sandboxed tools.""" - def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: + def execute( + self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User + ) -> Tuple[Any, Optional[SandboxRunResult]]: # Store original memory state - orig_memory_str = agent.agent_state.memory.compile() + orig_memory_str = agent_state.memory.compile() try: # Prepare function arguments function_args = self._prepare_function_args(function_args, tool, function_name) # Create agent state copy for sandbox - agent_state_copy = self._create_agent_state_copy(agent) + agent_state_copy = self._create_agent_state_copy(agent_state) # Execute in sandbox - sandbox_run_result = ToolExecutionSandbox(function_name, function_args, agent.user, tool_object=tool).run( + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor, tool_object=tool).run( agent_state=agent_state_copy ) function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state # Verify memory integrity - assert orig_memory_str == agent.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" + assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" # Update agent memory if needed if updated_agent_state is not None: - agent.update_memory_if_changed(updated_agent_state.memory) + AgentManager().update_memory_if_changed(agent_state.id, updated_agent_state.memory, actor) return function_response, sandbox_run_result @@ -174,9 +364,9 @@ class SandboxToolExecutor(ToolExecutor): # This is defensive programming - we try to coerce but fall back if it fails return function_args - def _create_agent_state_copy(self, agent: "Agent"): + def _create_agent_state_copy(self, agent_state: AgentState): """Create a copy of agent state for sandbox execution.""" - agent_state_copy = agent.agent_state.__deepcopy__() + agent_state_copy = agent_state.__deepcopy__() # Remove tools from copy to prevent nested tool execution agent_state_copy.tools = [] agent_state_copy.tool_rules = [] @@ -188,58 +378,3 @@ class SandboxToolExecutor(ToolExecutor): function_name=function_name, exception_name=type(exception).__name__, exception_message=str(exception) ) return error_message, SandboxRunResult(status="error") - - -class ToolExecutorFactory: - """Factory for creating appropriate tool executors based on tool type.""" - - _executor_map: Dict[ToolType, Type[ToolExecutor]] = { - ToolType.LETTA_CORE: LettaCoreToolExecutor, - ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor, - ToolType.LETTA_MEMORY_CORE: LettaMemoryToolExecutor, - ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor, - ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor, - } - - @classmethod - def get_executor(cls, tool_type: ToolType) -> ToolExecutor: - """Get the appropriate executor for the given tool type.""" - executor_class = cls._executor_map.get(tool_type) - - if executor_class: - return executor_class() - - # Default to sandbox executor for unknown types - return SandboxToolExecutor() - - -class ToolExecutionManager: - """Manager class for tool execution operations.""" - - def __init__(self, agent: "Agent"): - self.agent = agent - self.logger = agent.logger - - def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: - """ - Execute a tool and persist any state changes. - - Args: - function_name: Name of the function to execute - function_args: Arguments to pass to the function - tool: Tool object containing metadata about the tool - - Returns: - Tuple containing the function response and sandbox run result (if applicable) - """ - try: - # Get the appropriate executor for this tool type - executor = ToolExecutorFactory.get_executor(tool.tool_type) - - # Execute the tool - return executor.execute(function_name, function_args, self.agent, tool) - - except Exception as e: - self.logger.error(f"Error executing tool {function_name}: {str(e)}") - error_message = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)) - return error_message, SandboxRunResult(status="error") diff --git a/tests/integration_test_composio.py b/tests/integration_test_composio.py index a1898df5..d6894897 100644 --- a/tests/integration_test_composio.py +++ b/tests/integration_test_composio.py @@ -10,7 +10,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.tool import ToolCreate from letta.server.rest_api.app import app from letta.server.server import SyncServer -from letta.services.tool_executor.tool_executor import ToolExecutionManager +from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager logger = get_logger(__name__) @@ -67,9 +67,8 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_ ), actor=default_user, ) - agent = server.load_agent(agent_state.id, actor=default_user) - function_response, sandbox_run_result = ToolExecutionManager(agent).execute_tool( + function_response, sandbox_run_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool( function_name=composio_gmail_get_profile_tool.name, function_args={}, tool=composio_gmail_get_profile_tool ) @@ -81,8 +80,7 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_ agent_update=UpdateAgent(tool_exec_environment_variables={COMPOSIO_ENTITY_ENV_VAR_KEY: "matt"}), actor=default_user, ) - agent = server.load_agent(agent_state.id, actor=default_user) - function_response, sandbox_run_result = ToolExecutionManager(agent).execute_tool( + function_response, sandbox_run_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool( function_name=composio_gmail_get_profile_tool.name, function_args={}, tool=composio_gmail_get_profile_tool ) assert function_response["response_data"]["emailAddress"] == "matt@letta.com" diff --git a/tests/test_client.py b/tests/test_client.py index 5dcb7da6..f0da5930 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,3 @@ -import asyncio import os import threading import time @@ -577,36 +576,37 @@ def test_send_system_message(client: Letta, agent: AgentState): assert send_system_message_response, "Sending message failed" -@pytest.mark.asyncio -async def test_send_message_parallel(client: Letta, agent: AgentState, request): - """ - Test that sending two messages in parallel does not error. - """ - - # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls - async def send_message_task(message: str): - response = await asyncio.to_thread( - client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)] - ) - assert response, f"Sending message '{message}' failed" - return response - - # Prepare two tasks with different messages - messages = ["Test message 1", "Test message 2"] - tasks = [send_message_task(message) for message in messages] - - # Run the tasks concurrently - responses = await asyncio.gather(*tasks, return_exceptions=True) - - # Check for exceptions and validate responses - for i, response in enumerate(responses): - if isinstance(response, Exception): - pytest.fail(f"Task {i} failed with exception: {response}") - else: - assert response, f"Task {i} returned an invalid response: {response}" - - # Ensure both tasks completed - assert len(responses) == len(messages), "Not all messages were processed" +# TODO: Add back when new agent loop hits +# @pytest.mark.asyncio +# async def test_send_message_parallel(client: Letta, agent: AgentState, request): +# """ +# Test that sending two messages in parallel does not error. +# """ +# +# # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls +# async def send_message_task(message: str): +# response = await asyncio.to_thread( +# client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)] +# ) +# assert response, f"Sending message '{message}' failed" +# return response +# +# # Prepare two tasks with different messages +# messages = ["Test message 1", "Test message 2"] +# tasks = [send_message_task(message) for message in messages] +# +# # Run the tasks concurrently +# responses = await asyncio.gather(*tasks, return_exceptions=True) +# +# # Check for exceptions and validate responses +# for i, response in enumerate(responses): +# if isinstance(response, Exception): +# pytest.fail(f"Task {i} failed with exception: {response}") +# else: +# assert response, f"Task {i} returned an invalid response: {response}" +# +# # Ensure both tasks completed +# assert len(responses) == len(messages), "Not all messages were processed" # ---------------------------------------------------------------------------------------------------- diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index a85a0e09..2790a38f 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1,4 +1,3 @@ -import asyncio import os import threading import time @@ -444,43 +443,44 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState): assert "ZeroDivisionError" in response_message.tool_return -@pytest.mark.asyncio -async def test_send_message_parallel(client: LettaSDKClient, agent: AgentState): - """ - Test that sending two messages in parallel does not error. - """ - - # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls - async def send_message_task(message: str): - response = await asyncio.to_thread( - client.agents.messages.create, - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content=message, - ), - ], - ) - assert response, f"Sending message '{message}' failed" - return response - - # Prepare two tasks with different messages - messages = ["Test message 1", "Test message 2"] - tasks = [send_message_task(message) for message in messages] - - # Run the tasks concurrently - responses = await asyncio.gather(*tasks, return_exceptions=True) - - # Check for exceptions and validate responses - for i, response in enumerate(responses): - if isinstance(response, Exception): - pytest.fail(f"Task {i} failed with exception: {response}") - else: - assert response, f"Task {i} returned an invalid response: {response}" - - # Ensure both tasks completed - assert len(responses) == len(messages), "Not all messages were processed" +# TODO: Add back when the new agent loop hits +# @pytest.mark.asyncio +# async def test_send_message_parallel(client: LettaSDKClient, agent: AgentState): +# """ +# Test that sending two messages in parallel does not error. +# """ +# +# # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls +# async def send_message_task(message: str): +# response = await asyncio.to_thread( +# client.agents.messages.create, +# agent_id=agent.id, +# messages=[ +# MessageCreate( +# role="user", +# content=message, +# ), +# ], +# ) +# assert response, f"Sending message '{message}' failed" +# return response +# +# # Prepare two tasks with different messages +# messages = ["Test message 1", "Test message 2"] +# tasks = [send_message_task(message) for message in messages] +# +# # Run the tasks concurrently +# responses = await asyncio.gather(*tasks, return_exceptions=True) +# +# # Check for exceptions and validate responses +# for i, response in enumerate(responses): +# if isinstance(response, Exception): +# pytest.fail(f"Task {i} failed with exception: {response}") +# else: +# assert response, f"Task {i} returned an invalid response: {response}" +# +# # Ensure both tasks completed +# assert len(responses) == len(messages), "Not all messages were processed" def test_send_message_async(client: LettaSDKClient, agent: AgentState):