From e360620b027e7b27b32e148c7dbbc0a9a9b6a09e Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 28 May 2025 14:25:17 -0700 Subject: [PATCH] feat(asyncify): migrate run tools (#2496) --- letta/server/rest_api/routers/v1/tools.py | 6 ++-- letta/server/server.py | 20 ++++++++--- .../tool_executor/tool_execution_manager.py | 6 ++-- letta/services/tool_executor/tool_executor.py | 24 +++++++------ tests/test_server.py | 35 +++++++++++-------- 5 files changed, 58 insertions(+), 33 deletions(-) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 60f6fba2..ec01bbf6 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -193,7 +193,7 @@ async def upsert_base_tools( @router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source") -def run_tool_from_source( +async def run_tool_from_source( server: SyncServer = Depends(get_letta_server), request: ToolRunFromSource = Body(...), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -201,10 +201,10 @@ def run_tool_from_source( """ Attempt to build a tool from source, then run it on the provided arguments """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: - return server.run_tool_from_source( + return await server.run_tool_from_source( tool_source=request.source_code, tool_source_type=request.source_type, tool_args=request.args, diff --git a/letta/server/server.py b/letta/server/server.py index 02c9f196..9aa56b66 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -96,7 +96,7 @@ from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager from letta.services.step_manager import StepManager from letta.services.telemetry_manager import TelemetryManager -from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox +from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import model_settings, settings, tool_settings @@ -1865,7 +1865,7 @@ class SyncServer(Server): def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig: """Add a new embedding model""" - def run_tool_from_source( + async def run_tool_from_source( self, actor: User, tool_args: Dict[str, str], @@ -1898,8 +1898,20 @@ class SyncServer(Server): # Next, attempt to run the tool with the sandbox try: - tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run( - agent_state=agent_state, additional_env_vars=tool_env_vars + tool_execution_manager = ToolExecutionManager( + agent_state=agent_state, + message_manager=self.message_manager, + agent_manager=self.agent_manager, + block_manager=self.block_manager, + passage_manager=self.passage_manager, + actor=actor, + sandbox_env_vars=tool_env_vars, + ) + # TODO: Integrate sandbox result + tool_execution_result = await tool_execution_manager.execute_tool_async( + function_name=tool_name, + function_args=tool_args, + tool=tool, ) return ToolReturnMessage( id="null", diff --git a/letta/services/tool_executor/tool_execution_manager.py b/letta/services/tool_executor/tool_execution_manager.py index 93ee08fa..f8aa622e 100644 --- a/letta/services/tool_executor/tool_execution_manager.py +++ b/letta/services/tool_executor/tool_execution_manager.py @@ -69,8 +69,8 @@ class ToolExecutionManager: agent_manager: AgentManager, block_manager: BlockManager, passage_manager: PassageManager, - agent_state: AgentState, actor: User, + agent_state: Optional[AgentState] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ): @@ -98,7 +98,9 @@ class ToolExecutionManager: passage_manager=self.passage_manager, actor=self.actor, ) - result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor) + result = await executor.execute( + function_name, function_args, tool, self.actor, self.agent_state, self.sandbox_config, self.sandbox_env_vars + ) # trim result return_str = str(result.func_return) diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index b022aafd..e8a5fc39 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -68,9 +68,9 @@ class ToolExecutor(ABC): self, function_name: str, function_args: dict, - agent_state: AgentState, tool: Tool, actor: User, + agent_state: Optional[AgentState] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: @@ -84,13 +84,14 @@ class LettaCoreToolExecutor(ToolExecutor): self, function_name: str, function_args: dict, - agent_state: AgentState, tool: Tool, actor: User, + agent_state: Optional[AgentState] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: # Map function names to method calls + assert agent_state is not None, "Agent state is required for core tools" function_map = { "send_message": self.send_message, "conversation_search": self.conversation_search, @@ -537,12 +538,13 @@ class LettaMultiAgentToolExecutor(ToolExecutor): self, function_name: str, function_args: dict, - agent_state: AgentState, tool: Tool, actor: User, + agent_state: Optional[AgentState] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: + assert agent_state is not None, "Agent state is required for multi-agent tools" function_map = { "send_message_to_agent_and_wait_for_reply": self.send_message_to_agent_and_wait_for_reply, "send_message_to_agent_async": self.send_message_to_agent_async, @@ -644,12 +646,13 @@ class ExternalComposioToolExecutor(ToolExecutor): self, function_name: str, function_args: dict, - agent_state: AgentState, tool: Tool, actor: User, + agent_state: Optional[AgentState] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: + assert agent_state is not None, "Agent state is required for external Composio tools" action_name = generate_composio_action_from_func_name(tool.name) # Get entity ID from the agent_state @@ -684,9 +687,9 @@ class ExternalMCPToolExecutor(ToolExecutor): self, function_name: str, function_args: dict, - agent_state: AgentState, tool: Tool, actor: User, + agent_state: Optional[AgentState] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: @@ -718,21 +721,21 @@ class SandboxToolExecutor(ToolExecutor): self, function_name: str, function_args: JsonDict, - agent_state: AgentState, tool: Tool, actor: User, + agent_state: Optional[AgentState] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: # Store original memory state - orig_memory_str = agent_state.memory.compile() + orig_memory_str = agent_state.memory.compile() if agent_state else None try: # Prepare function arguments function_args = self._prepare_function_args(function_args, tool, function_name) - agent_state_copy = self._create_agent_state_copy(agent_state) + agent_state_copy = self._create_agent_state_copy(agent_state) if agent_state else None # Execute in sandbox depending on API key if tool_settings.e2b_api_key: @@ -747,7 +750,8 @@ class SandboxToolExecutor(ToolExecutor): tool_execution_result = await sandbox.run(agent_state=agent_state_copy) # Verify memory integrity - assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" + if agent_state: + assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" # Update agent memory if needed if tool_execution_result.agent_state is not None: @@ -805,9 +809,9 @@ class LettaBuiltinToolExecutor(ToolExecutor): self, function_name: str, function_args: dict, - agent_state: AgentState, tool: Tool, actor: User, + agent_state: Optional[AgentState] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: diff --git a/tests/test_server.py b/tests/test_server.py index e237e1ce..93112390 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -771,9 +771,10 @@ def ingest(message: str): import pytest -def test_tool_run_basic(server, disable_e2b_api_key, user): +@pytest.mark.asyncio +async def test_tool_run_basic(server, disable_e2b_api_key, user): """Test running a simple tool from source""" - result = server.run_tool_from_source( + result = await server.run_tool_from_source( actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python", @@ -785,9 +786,10 @@ def test_tool_run_basic(server, disable_e2b_api_key, user): assert not result.stderr -def test_tool_run_with_env_var(server, disable_e2b_api_key, user): +@pytest.mark.asyncio +async def test_tool_run_with_env_var(server, disable_e2b_api_key, user): """Test running a tool that uses an environment variable""" - result = server.run_tool_from_source( + result = await server.run_tool_from_source( actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR, tool_source_type="python", @@ -800,9 +802,10 @@ def test_tool_run_with_env_var(server, disable_e2b_api_key, user): assert not result.stderr -def test_tool_run_invalid_args(server, disable_e2b_api_key, user): +@pytest.mark.asyncio +async def test_tool_run_invalid_args(server, disable_e2b_api_key, user): """Test running a tool with incorrect arguments""" - result = server.run_tool_from_source( + result = await server.run_tool_from_source( actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python", @@ -816,9 +819,10 @@ def test_tool_run_invalid_args(server, disable_e2b_api_key, user): assert "missing 1 required positional argument" in result.stderr[0] -def test_tool_run_with_distractor(server, disable_e2b_api_key, user): +@pytest.mark.asyncio +async def test_tool_run_with_distractor(server, disable_e2b_api_key, user): """Test running a tool with a distractor function in the source""" - result = server.run_tool_from_source( + result = await server.run_tool_from_source( actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", @@ -831,9 +835,10 @@ def test_tool_run_with_distractor(server, disable_e2b_api_key, user): assert not result.stderr -def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user): +@pytest.mark.asyncio +async def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user): """Test selecting a tool by name when multiple tools exist in the source""" - result = server.run_tool_from_source( + result = await server.run_tool_from_source( actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", @@ -847,9 +852,10 @@ def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user): assert not result.stderr -def test_tool_run_util_function(server, disable_e2b_api_key, user): +@pytest.mark.asyncio +async def test_tool_run_util_function(server, disable_e2b_api_key, user): """Test selecting a utility function that does not return anything meaningful""" - result = server.run_tool_from_source( + result = await server.run_tool_from_source( actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", @@ -863,7 +869,8 @@ def test_tool_run_util_function(server, disable_e2b_api_key, user): assert not result.stderr -def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user): +@pytest.mark.asyncio +async def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user): """Test overriding the autogenerated JSON schema with an explicit one""" explicit_json_schema = { "name": "ingest", @@ -881,7 +888,7 @@ def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user): }, } - result = server.run_tool_from_source( + result = await server.run_tool_from_source( actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python",