diff --git a/letta/services/tool_sandbox/base.py b/letta/services/tool_sandbox/base.py index b8a3001b..506adcb3 100644 --- a/letta/services/tool_sandbox/base.py +++ b/letta/services/tool_sandbox/base.py @@ -57,11 +57,19 @@ class AsyncToolSandboxBase(ABC): f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}" ) + # Check for reserved keyword arguments + tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name) + # TODO: deprecate this - if "agent_state" in parse_function_arguments(self.tool.source_code, self.tool.name): + if "agent_state" in tool_arguments: self.inject_agent_state = True else: self.inject_agent_state = False + + # Check for Letta client and agent_id injection + self.inject_letta_client = "letta_client" in tool_arguments or "client" in tool_arguments + self.inject_agent_id = "agent_id" in tool_arguments + self.is_async_function = self._detect_async_function() self._initialized = True @@ -112,12 +120,16 @@ class AsyncToolSandboxBase(ABC): tool_args += self.initialize_param(param, self.args[param]) agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None + agent_id = agent_state.id if agent_state else None code = self._render_sandbox_code( future_import=future_import, inject_agent_state=self.inject_agent_state, + inject_letta_client=self.inject_letta_client, + inject_agent_id=self.inject_agent_id, schema_imports=schema_code or "", agent_state_pickle=agent_state_pickle, + agent_id=agent_id, tool_args=tool_args, tool_source_code=self.tool.source_code, local_sandbox_result_var_name=self.LOCAL_SANDBOX_RESULT_VAR_NAME, @@ -133,8 +145,11 @@ class AsyncToolSandboxBase(ABC): *, future_import: bool, inject_agent_state: bool, + inject_letta_client: bool, + inject_agent_id: bool, schema_imports: str, agent_state_pickle: bytes | None, + agent_id: str | None, tool_args: str, tool_source_code: str, local_sandbox_result_var_name: str, @@ -162,6 +177,10 @@ class AsyncToolSandboxBase(ABC): if inject_agent_state: lines.extend(["import letta", "from letta import *"]) # noqa: F401 + # Import Letta client if needed + if inject_letta_client: + lines.append("from letta_client import Letta") + if schema_imports: lines.append(schema_imports.rstrip()) @@ -170,6 +189,34 @@ class AsyncToolSandboxBase(ABC): else: lines.append("agent_state = None") + # Initialize Letta client if needed + if inject_letta_client: + from letta.settings import settings + + lines.extend( + [ + "# Initialize Letta client for tool execution", + "letta_client = Letta(", + f" base_url={repr(settings.default_base_url)},", + f" token={repr(settings.default_token)}", + ")", + "# Compatibility shim for client.agents.get", + "try:", + " _agents = letta_client.agents", + " if not hasattr(_agents, 'get') and hasattr(_agents, 'retrieve'):", + " setattr(_agents, 'get', _agents.retrieve)", + "except Exception:", + " pass", + ] + ) + + # Set agent_id if needed + if inject_agent_id: + if agent_id: + lines.append(f"agent_id = {repr(agent_id)}") + else: + lines.append("agent_id = None") + if tool_args: lines.append(tool_args.rstrip()) @@ -286,9 +333,22 @@ class AsyncToolSandboxBase(ABC): kwargs.append(name) param_list = [f"{arg}={arg}" for arg in kwargs] + + # Add reserved keyword arguments if self.inject_agent_state: param_list.append("agent_state=agent_state") + if self.inject_letta_client: + # Check if the function expects 'client' or 'letta_client' + tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name) + if "client" in tool_arguments: + param_list.append("client=letta_client") + elif "letta_client" in tool_arguments: + param_list.append("letta_client=letta_client") + + if self.inject_agent_id: + param_list.append("agent_id=agent_id") + params = ", ".join(param_list) func_call_str = self.tool.name + "(" + params + ")" return func_call_str diff --git a/letta/settings.py b/letta/settings.py index f62bd56d..c87dc1bd 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -329,6 +329,10 @@ class Settings(BaseSettings): file_processing_timeout_minutes: int = 30 file_processing_timeout_error_message: str = "File processing timed out after {} minutes. Please try again." + # Letta client settings for tool execution + default_base_url: str = Field(default="http://localhost:8283", description="Default base URL for Letta client in tool execution") + default_token: Optional[str] = Field(default=None, description="Default token for Letta client in tool execution") + # enabling letta_agent_v1 architecture use_letta_v1_agent: bool = False diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 4194389d..fd1f989b 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -12,16 +12,21 @@ from letta.functions.function_sets.base import core_memory_append, core_memory_r from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.block import CreateBlock +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ToolType from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate +from letta.schemas.llm_config import LLMConfig from letta.schemas.organization import Organization from letta.schemas.pip_requirement import PipRequirement from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate +from letta.schemas.tool import Tool as PydanticTool from letta.schemas.user import User from letta.server.server import SyncServer from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager +from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal from letta.services.user_manager import UserManager from tests.helpers.utils import create_tool_from_func @@ -48,14 +53,14 @@ def server(): @pytest.fixture(autouse=True) -def clear_tables(): +async def clear_tables(): """Fixture to clear the organization table before each test.""" - from letta.server.db import db_context + from letta.server.db import db_registry - with db_context() as session: - session.execute(delete(SandboxEnvironmentVariable)) - session.execute(delete(SandboxConfig)) - session.commit() # Commit the deletion + async with db_registry.async_session() as session: + await session.execute(delete(SandboxEnvironmentVariable)) + await session.execute(delete(SandboxConfig)) + await session.commit() # Commit the deletion @pytest.fixture @@ -232,7 +237,7 @@ def agent_state(server): @pytest.fixture -def custom_test_sandbox_config(test_user): +async def custom_test_sandbox_config(test_user): """ Fixture to create a consistent local sandbox configuration for tests. @@ -314,7 +319,8 @@ def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user): @pytest.mark.local_sandbox -def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user): +@pytest.mark.asyncio +async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user): manager = SandboxConfigManager() # Make a custom local sandbox config @@ -394,6 +400,136 @@ def test_local_sandbox_with_venv_and_warnings_does_not_error(disable_e2b_api_key assert result.func_return == "Hello World" +@pytest.mark.local_sandbox +@pytest.mark.asyncio +async def test_tool_with_client_injection(disable_e2b_api_key, server: SyncServer, test_user): + """Test that tools can access injected letta_client and agent_id to modify agent blocks.""" + + # Create a tool that uses the injected client and agent_id to actually clear a memory block + memory_clear_source = ''' +def memory_clear(label: str, agent_id: str, client: "Letta"): + """Test tool that clears a memory block using the injected client. + + Args: + label: The label of the memory block to clear + agent_id: The agent's ID (injected by Letta system) + client: The Letta client instance (injected by Letta system) + """ + # Verify that agent_id was injected + if not agent_id or not isinstance(agent_id, str): + return f"ERROR: agent_id not properly injected: {agent_id}" + + # Verify that client was injected + if not client or not hasattr(client, 'agents'): + return f"ERROR: client not properly injected: {client}" + + # Use the injected client to actually clear the memory block + try: + # Get the agent using the injected client + agent = client.agents.get(agent_id=agent_id) + + # Find the block with the specified label + blocks = agent.memory.blocks + target_block = None + for block in blocks: + if block.label == label: + target_block = block + break + + if not target_block: + return f"ERROR: Block with label '{label}' not found" + + # Clear the block by setting its value to empty string + original_value = target_block.value + client.agents.update_block( + agent_id=agent_id, + block_id=target_block.id, + value="" + ) + + return f"SUCCESS: Cleared block '{label}' (was {len(original_value)} chars, now empty)" + except Exception as e: + return f"ERROR: Failed to clear block: {str(e)}" +''' + + # Create the tool + memory_clear_tool = PydanticTool( + name="memory_clear", + description="Clear a memory block by setting its value to empty string", + source_code=memory_clear_source, + source_type="python", + tool_type=ToolType.CUSTOM, + ) + + # Manually provide schema since client is an injected parameter + memory_clear_tool.json_schema = { + "name": "memory_clear", + "description": "Clear a memory block by setting its value to empty string", + "parameters": { + "type": "object", + "properties": { + "label": {"type": "string", "description": "The label of the memory block to clear"} + # agent_id and client are injected, not passed by the user + }, + "required": ["label"], + }, + } + + # Create the tool in the system + created_tool = await server.tool_manager.create_tool_async(memory_clear_tool, actor=test_user) + + # Create an agent with a memory block + agent = await server.agent_manager.create_agent_async( + agent_create=CreateAgent( + name="test_agent_with_blocks", + memory_blocks=[{"label": "test_block", "value": "Initial test content that should be cleared"}], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tools=["memory_clear"], + include_base_tools=False, + ), + actor=test_user, + ) + + # Verify the tool is attached + assert created_tool.id in [t.id for t in agent.tools] + + # Simulate tool execution with the reserved keywords + # This would normally happen during agent execution, but we'll test the tool directly + # Create the sandbox for the tool + sandbox = AsyncToolSandboxLocal(tool_name="memory_clear", args={"label": "test_block"}, user=test_user, tool_object=created_tool) + + # Initialize the sandbox to detect reserved keywords + await sandbox._init_async() + + # Verify that the tool correctly detects the need for injection + assert sandbox.inject_letta_client == True # Should detect 'client' parameter + assert sandbox.inject_agent_id == True # Should detect 'agent_id' parameter + + # Generate the execution script to verify injection code is present + script = await sandbox.generate_execution_script(agent_state=agent) + + # Verify the script contains Letta client initialization + assert "from letta import Letta" in script or "import letta" in script.lower() + assert "agent_id =" in script + + # Actually execute the tool using the sandbox + result = await sandbox.run(agent_state=agent) + + # Verify execution was successful + assert result.status == "success", f"Tool execution failed: {result.stderr}" + assert "SUCCESS:" in result.func_return, f"Tool didn't execute successfully: {result.func_return}" + assert "Cleared block 'test_block'" in result.func_return, f"Block not cleared: {result.func_return}" + assert "was 44 chars" in result.func_return, f"Original length not reported correctly: {result.func_return}" + + # check the block status after the tool execution + agent_state = await server.agent_manager.get_agent_by_id_async(agent.id, actor=test_user) + assert agent_state.memory.get_block("test_block").value == "" + + # Clean up + await server.agent_manager.delete_agent_async(agent_id=agent.id, actor=test_user) + + @pytest.mark.e2b_sandbox def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user): sandbox = ToolExecutionSandbox(always_err_tool.name, {}, user=test_user)