diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 8066f9b2..32f6a8af 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -212,13 +212,9 @@ class ToolUpdate(LettaBase): # TODO: Remove this, and clean usage of ToolUpdate everywhere else -class ToolRun(LettaBase): - id: str = Field(..., description="The ID of the tool to run.") - args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).") - - class ToolRunFromSource(LettaBase): source_code: str = Field(..., description="The source code of the function.") - args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).") + args: Dict[str, str] = Field(..., description="The arguments to pass to the tool.") + env_vars: Dict[str, str] = Field(None, description="The environment variables to pass to the tool.") name: Optional[str] = Field(None, description="The name of the tool to run.") source_type: Optional[str] = Field(None, description="The type of the source code.") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index ffc2b212..8ea4d037 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -181,6 +181,7 @@ def run_tool_from_source( tool_source=request.source_code, tool_source_type=request.source_type, tool_args=request.args, + tool_env_vars=request.env_vars, tool_name=request.name, actor=actor, ) diff --git a/letta/server/server.py b/letta/server/server.py index 4b86544a..bbafd213 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1,11 +1,10 @@ # inspecting tools -import json import os import traceback import warnings from abc import abstractmethod from datetime import datetime -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union from composio.client import Composio from composio.client.collections import ActionModel, AppModel @@ -1117,22 +1116,17 @@ class SyncServer(Server): def run_tool_from_source( self, actor: User, - tool_args: str, + tool_args: Dict[str, str], tool_source: str, + tool_env_vars: Optional[Dict[str, str]] = None, tool_source_type: Optional[str] = None, tool_name: Optional[str] = None, ) -> ToolReturnMessage: """Run a tool from source code""" - - try: - tool_args_dict = json.loads(tool_args) - except json.JSONDecodeError: - raise ValueError("Invalid JSON string for tool_args") - if tool_source_type is not None and tool_source_type != "python": raise ValueError("Only Python source code is supported at this time") - # NOTE: we're creating a floating Tool object and NOT persiting to DB + # NOTE: we're creating a floating Tool object and NOT persisting to DB tool = Tool( name=tool_name, source_code=tool_source, @@ -1144,7 +1138,9 @@ class SyncServer(Server): # Next, attempt to run the tool with the sandbox try: - sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state) + sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run( + agent_state=agent_state, additional_env_vars=tool_env_vars + ) return ToolReturnMessage( id="null", tool_call_id="null", diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py index 93e3e265..dc1fcea3 100644 --- a/letta/services/tool_execution_sandbox.py +++ b/letta/services/tool_execution_sandbox.py @@ -59,22 +59,23 @@ class ToolExecutionSandbox: self.sandbox_config_manager = SandboxConfigManager(tool_settings) self.force_recreate = force_recreate - def run(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult: + def run(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult: """ Run the tool in a sandbox environment. Args: agent_state (Optional[AgentState]): The state of the agent invoking the tool + additional_env_vars (Optional[Dict]): Environment variables to inject into the sandbox Returns: Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state) """ if tool_settings.e2b_api_key: logger.debug(f"Using e2b sandbox to execute {self.tool_name}") - result = self.run_e2b_sandbox(agent_state=agent_state) + result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars) else: logger.debug(f"Using local sandbox to execute {self.tool_name}") - result = self.run_local_dir_sandbox(agent_state=agent_state) + result = self.run_local_dir_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars) # Log out any stdout/stderr from the tool run logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n") @@ -98,19 +99,25 @@ class ToolExecutionSandbox: os.environ.clear() os.environ.update(original_env) # Restore original environment variables - def run_local_dir_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult: + def run_local_dir_sandbox( + self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None + ) -> SandboxRunResult: sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user) local_configs = sbx_config.get_local_config() # Get environment variables for the sandbox - env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) env = os.environ.copy() + env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) env.update(env_vars) # Get environment variables for this agent specifically if agent_state: env.update(agent_state.get_agent_env_vars_as_dict()) + # Finally, get any that are passed explicitly into the `run` function call + if additional_env_vars: + env.update(additional_env_vars) + # Safety checks if not os.path.isdir(local_configs.sandbox_dir): raise FileNotFoundError(f"Sandbox directory does not exist: {local_configs.sandbox_dir}") @@ -277,7 +284,7 @@ class ToolExecutionSandbox: # e2b sandbox specific functions - def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult: + def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult: sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user) sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config) if not sbx or self.force_recreate: @@ -300,6 +307,10 @@ class ToolExecutionSandbox: if agent_state: env_vars.update(agent_state.get_agent_env_vars_as_dict()) + # Finally, get any that are passed explicitly into the `run` function call + if additional_env_vars: + env_vars.update(additional_env_vars) + code = self.generate_execution_script(agent_state=agent_state) execution = sbx.run_code(code, envs=env_vars) diff --git a/tests/test_server.py b/tests/test_server.py index fe0fcdc4..763400b6 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -687,6 +687,18 @@ def ingest(message: str): ''' +EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR = ''' +def ingest(): + """ + Ingest a message into the system. + + Returns: + str: The result of ingesting the message. + """ + import os + return os.getenv("secret") +''' + EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR = ''' def util_do_nothing(): @@ -721,7 +733,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python", - tool_args=json.dumps({"message": "Hello, world!"}), + tool_args={"message": "Hello, world!"}, # tool_name="ingest", ) print(result) @@ -730,11 +742,24 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): assert not result.stdout assert not result.stderr + result = server.run_tool_from_source( + actor=user, + tool_source=EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR, + tool_source_type="python", + tool_args={}, + tool_env_vars={"secret": "banana"}, + ) + print(result) + assert result.status == "success" + assert result.tool_return == "banana", result.tool_return + assert not result.stdout + assert not result.stderr + result = server.run_tool_from_source( actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python", - tool_args=json.dumps({"message": "Well well well"}), + tool_args={"message": "Well well well"}, # tool_name="ingest", ) print(result) @@ -747,7 +772,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python", - tool_args=json.dumps({"bad_arg": "oh no"}), + tool_args={"bad_arg": "oh no"}, # tool_name="ingest", ) print(result) @@ -763,7 +788,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", - tool_args=json.dumps({"message": "Well well well"}), + tool_args={"message": "Well well well"}, # tool_name="ingest", ) print(result) @@ -778,7 +803,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", - tool_args=json.dumps({"message": "Well well well"}), + tool_args={"message": "Well well well"}, tool_name="ingest", ) print(result) @@ -793,7 +818,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", - tool_args=json.dumps({}), + tool_args={}, tool_name="util_do_nothing", ) print(result)