feat: inject letta_client and agent_id into local sandbox (#5192)

This commit is contained in:
Sarah Wooders
2025-10-17 14:40:42 -07:00
committed by Caren Thomas
parent 5a475fd1a5
commit 305bb8c8f7
3 changed files with 209 additions and 9 deletions

View File

@@ -57,11 +57,19 @@ class AsyncToolSandboxBase(ABC):
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
)
# Check for reserved keyword arguments
tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name)
# TODO: deprecate this
if "agent_state" in parse_function_arguments(self.tool.source_code, self.tool.name):
if "agent_state" in tool_arguments:
self.inject_agent_state = True
else:
self.inject_agent_state = False
# Check for Letta client and agent_id injection
self.inject_letta_client = "letta_client" in tool_arguments or "client" in tool_arguments
self.inject_agent_id = "agent_id" in tool_arguments
self.is_async_function = self._detect_async_function()
self._initialized = True
@@ -112,12 +120,16 @@ class AsyncToolSandboxBase(ABC):
tool_args += self.initialize_param(param, self.args[param])
agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None
agent_id = agent_state.id if agent_state else None
code = self._render_sandbox_code(
future_import=future_import,
inject_agent_state=self.inject_agent_state,
inject_letta_client=self.inject_letta_client,
inject_agent_id=self.inject_agent_id,
schema_imports=schema_code or "",
agent_state_pickle=agent_state_pickle,
agent_id=agent_id,
tool_args=tool_args,
tool_source_code=self.tool.source_code,
local_sandbox_result_var_name=self.LOCAL_SANDBOX_RESULT_VAR_NAME,
@@ -133,8 +145,11 @@ class AsyncToolSandboxBase(ABC):
*,
future_import: bool,
inject_agent_state: bool,
inject_letta_client: bool,
inject_agent_id: bool,
schema_imports: str,
agent_state_pickle: bytes | None,
agent_id: str | None,
tool_args: str,
tool_source_code: str,
local_sandbox_result_var_name: str,
@@ -162,6 +177,10 @@ class AsyncToolSandboxBase(ABC):
if inject_agent_state:
lines.extend(["import letta", "from letta import *"]) # noqa: F401
# Import Letta client if needed
if inject_letta_client:
lines.append("from letta_client import Letta")
if schema_imports:
lines.append(schema_imports.rstrip())
@@ -170,6 +189,34 @@ class AsyncToolSandboxBase(ABC):
else:
lines.append("agent_state = None")
# Initialize Letta client if needed
if inject_letta_client:
from letta.settings import settings
lines.extend(
[
"# Initialize Letta client for tool execution",
"letta_client = Letta(",
f" base_url={repr(settings.default_base_url)},",
f" token={repr(settings.default_token)}",
")",
"# Compatibility shim for client.agents.get",
"try:",
" _agents = letta_client.agents",
" if not hasattr(_agents, 'get') and hasattr(_agents, 'retrieve'):",
" setattr(_agents, 'get', _agents.retrieve)",
"except Exception:",
" pass",
]
)
# Set agent_id if needed
if inject_agent_id:
if agent_id:
lines.append(f"agent_id = {repr(agent_id)}")
else:
lines.append("agent_id = None")
if tool_args:
lines.append(tool_args.rstrip())
@@ -286,9 +333,22 @@ class AsyncToolSandboxBase(ABC):
kwargs.append(name)
param_list = [f"{arg}={arg}" for arg in kwargs]
# Add reserved keyword arguments
if self.inject_agent_state:
param_list.append("agent_state=agent_state")
if self.inject_letta_client:
# Check if the function expects 'client' or 'letta_client'
tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name)
if "client" in tool_arguments:
param_list.append("client=letta_client")
elif "letta_client" in tool_arguments:
param_list.append("letta_client=letta_client")
if self.inject_agent_id:
param_list.append("agent_id=agent_id")
params = ", ".join(param_list)
func_call_str = self.tool.name + "(" + params + ")"
return func_call_str

View File

@@ -329,6 +329,10 @@ class Settings(BaseSettings):
file_processing_timeout_minutes: int = 30
file_processing_timeout_error_message: str = "File processing timed out after {} minutes. Please try again."
# Letta client settings for tool execution
default_base_url: str = Field(default="http://localhost:8283", description="Default base URL for Letta client in tool execution")
default_token: Optional[str] = Field(default=None, description="Default token for Letta client in tool execution")
# enabling letta_agent_v1 architecture
use_letta_v1_agent: bool = False

View File

@@ -12,16 +12,21 @@ from letta.functions.function_sets.base import core_memory_append, core_memory_r
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolType
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.organization import Organization
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal
from letta.services.user_manager import UserManager
from tests.helpers.utils import create_tool_from_func
@@ -48,14 +53,14 @@ def server():
@pytest.fixture(autouse=True)
def clear_tables():
async def clear_tables():
"""Fixture to clear the organization table before each test."""
from letta.server.db import db_context
from letta.server.db import db_registry
with db_context() as session:
session.execute(delete(SandboxEnvironmentVariable))
session.execute(delete(SandboxConfig))
session.commit() # Commit the deletion
async with db_registry.async_session() as session:
await session.execute(delete(SandboxEnvironmentVariable))
await session.execute(delete(SandboxConfig))
await session.commit() # Commit the deletion
@pytest.fixture
@@ -232,7 +237,7 @@ def agent_state(server):
@pytest.fixture
def custom_test_sandbox_config(test_user):
async def custom_test_sandbox_config(test_user):
"""
Fixture to create a consistent local sandbox configuration for tests.
@@ -314,7 +319,8 @@ def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user):
@pytest.mark.local_sandbox
def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user):
@pytest.mark.asyncio
async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user):
manager = SandboxConfigManager()
# Make a custom local sandbox config
@@ -394,6 +400,136 @@ def test_local_sandbox_with_venv_and_warnings_does_not_error(disable_e2b_api_key
assert result.func_return == "Hello World"
@pytest.mark.local_sandbox
@pytest.mark.asyncio
async def test_tool_with_client_injection(disable_e2b_api_key, server: SyncServer, test_user):
    """Check that the local sandbox injects ``client`` and ``agent_id`` into a tool.

    The tool under test declares the reserved parameters ``agent_id`` and
    ``client``. The test verifies three things:

    1. the sandbox detects both reserved parameters during initialization,
    2. the generated execution script imports and constructs the Letta client
       and assigns ``agent_id``,
    3. running the tool actually clears the targeted memory block.

    NOTE(review): step 3 executes the tool against the client built inside the
    sandbox script, so it presumably needs a Letta server reachable at the
    configured default base URL -- confirm for CI.
    """
    # Keep the seed value in one place so the length assertion below cannot
    # drift away from the literal used to create the block.
    initial_block_value = "Initial test content that should be cleared"

    # Tool source that uses the injected client and agent_id to clear a block.
    memory_clear_source = '''
def memory_clear(label: str, agent_id: str, client: "Letta"):
    """Test tool that clears a memory block using the injected client.

    Args:
        label: The label of the memory block to clear
        agent_id: The agent's ID (injected by Letta system)
        client: The Letta client instance (injected by Letta system)
    """
    # Verify that agent_id was injected
    if not agent_id or not isinstance(agent_id, str):
        return f"ERROR: agent_id not properly injected: {agent_id}"

    # Verify that client was injected
    if not client or not hasattr(client, 'agents'):
        return f"ERROR: client not properly injected: {client}"

    # Use the injected client to actually clear the memory block
    try:
        # Get the agent using the injected client
        agent = client.agents.get(agent_id=agent_id)

        # Find the block with the specified label
        blocks = agent.memory.blocks
        target_block = None
        for block in blocks:
            if block.label == label:
                target_block = block
                break

        if not target_block:
            return f"ERROR: Block with label '{label}' not found"

        # Clear the block by setting its value to empty string
        original_value = target_block.value
        client.agents.update_block(
            agent_id=agent_id,
            block_id=target_block.id,
            value=""
        )
        return f"SUCCESS: Cleared block '{label}' (was {len(original_value)} chars, now empty)"
    except Exception as e:
        return f"ERROR: Failed to clear block: {str(e)}"
'''

    # Create the tool
    memory_clear_tool = PydanticTool(
        name="memory_clear",
        description="Clear a memory block by setting its value to empty string",
        source_code=memory_clear_source,
        source_type="python",
        tool_type=ToolType.CUSTOM,
    )

    # Manually provide the schema: agent_id and client are injected by the
    # sandbox, so only ``label`` is exposed as a caller-supplied argument.
    memory_clear_tool.json_schema = {
        "name": "memory_clear",
        "description": "Clear a memory block by setting its value to empty string",
        "parameters": {
            "type": "object",
            "properties": {
                "label": {"type": "string", "description": "The label of the memory block to clear"}
                # agent_id and client are injected, not passed by the user
            },
            "required": ["label"],
        },
    }

    # Create the tool in the system
    created_tool = await server.tool_manager.create_tool_async(memory_clear_tool, actor=test_user)

    # Create an agent with a memory block
    agent = await server.agent_manager.create_agent_async(
        agent_create=CreateAgent(
            name="test_agent_with_blocks",
            memory_blocks=[{"label": "test_block", "value": initial_block_value}],
            llm_config=LLMConfig.default_config("gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
            tools=["memory_clear"],
            include_base_tools=False,
        ),
        actor=test_user,
    )

    # Everything after agent creation runs under try/finally so the agent is
    # removed even when an assertion fails.
    try:
        # Verify the tool is attached
        assert created_tool.id in [t.id for t in agent.tools]

        # Simulate tool execution with the reserved keywords. This would
        # normally happen during agent execution; here the tool is run directly.
        sandbox = AsyncToolSandboxLocal(
            tool_name="memory_clear",
            args={"label": "test_block"},
            user=test_user,
            tool_object=created_tool,
        )

        # Initialize the sandbox to detect reserved keywords
        await sandbox._init_async()

        # Verify that the tool correctly detects the need for injection
        assert sandbox.inject_letta_client is True  # Should detect 'client' parameter
        assert sandbox.inject_agent_id is True  # Should detect 'agent_id' parameter

        # Generate the execution script to verify injection code is present
        script = await sandbox.generate_execution_script(agent_state=agent)

        # The renderer emits exactly these lines when a client is injected;
        # assert on them rather than on a loose "import letta" substring.
        assert "from letta_client import Letta" in script
        assert "letta_client = Letta(" in script
        assert f"agent_id = {agent.id!r}" in script

        # Actually execute the tool using the sandbox
        result = await sandbox.run(agent_state=agent)

        # Verify execution was successful
        assert result.status == "success", f"Tool execution failed: {result.stderr}"
        assert "SUCCESS:" in result.func_return, f"Tool didn't execute successfully: {result.func_return}"
        assert "Cleared block 'test_block'" in result.func_return, f"Block not cleared: {result.func_return}"
        assert f"was {len(initial_block_value)} chars" in result.func_return, (
            f"Original length not reported correctly: {result.func_return}"
        )

        # Check the block status after the tool execution
        agent_state = await server.agent_manager.get_agent_by_id_async(agent.id, actor=test_user)
        assert agent_state.memory.get_block("test_block").value == ""
    finally:
        # Clean up
        await server.agent_manager.delete_agent_async(agent_id=agent.id, actor=test_user)
@pytest.mark.e2b_sandbox
def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user):
sandbox = ToolExecutionSandbox(always_err_tool.name, {}, user=test_user)