feat: inject letta_client and agent_id into local sandbox (#5192)
This commit is contained in:
committed by
Caren Thomas
parent
5a475fd1a5
commit
305bb8c8f7
@@ -57,11 +57,19 @@ class AsyncToolSandboxBase(ABC):
|
||||
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
|
||||
)
|
||||
|
||||
# Check for reserved keyword arguments
|
||||
tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name)
|
||||
|
||||
# TODO: deprecate this
|
||||
if "agent_state" in parse_function_arguments(self.tool.source_code, self.tool.name):
|
||||
if "agent_state" in tool_arguments:
|
||||
self.inject_agent_state = True
|
||||
else:
|
||||
self.inject_agent_state = False
|
||||
|
||||
# Check for Letta client and agent_id injection
|
||||
self.inject_letta_client = "letta_client" in tool_arguments or "client" in tool_arguments
|
||||
self.inject_agent_id = "agent_id" in tool_arguments
|
||||
|
||||
self.is_async_function = self._detect_async_function()
|
||||
self._initialized = True
|
||||
|
||||
@@ -112,12 +120,16 @@ class AsyncToolSandboxBase(ABC):
|
||||
tool_args += self.initialize_param(param, self.args[param])
|
||||
|
||||
agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None
|
||||
agent_id = agent_state.id if agent_state else None
|
||||
|
||||
code = self._render_sandbox_code(
|
||||
future_import=future_import,
|
||||
inject_agent_state=self.inject_agent_state,
|
||||
inject_letta_client=self.inject_letta_client,
|
||||
inject_agent_id=self.inject_agent_id,
|
||||
schema_imports=schema_code or "",
|
||||
agent_state_pickle=agent_state_pickle,
|
||||
agent_id=agent_id,
|
||||
tool_args=tool_args,
|
||||
tool_source_code=self.tool.source_code,
|
||||
local_sandbox_result_var_name=self.LOCAL_SANDBOX_RESULT_VAR_NAME,
|
||||
@@ -133,8 +145,11 @@ class AsyncToolSandboxBase(ABC):
|
||||
*,
|
||||
future_import: bool,
|
||||
inject_agent_state: bool,
|
||||
inject_letta_client: bool,
|
||||
inject_agent_id: bool,
|
||||
schema_imports: str,
|
||||
agent_state_pickle: bytes | None,
|
||||
agent_id: str | None,
|
||||
tool_args: str,
|
||||
tool_source_code: str,
|
||||
local_sandbox_result_var_name: str,
|
||||
@@ -162,6 +177,10 @@ class AsyncToolSandboxBase(ABC):
|
||||
if inject_agent_state:
|
||||
lines.extend(["import letta", "from letta import *"]) # noqa: F401
|
||||
|
||||
# Import Letta client if needed
|
||||
if inject_letta_client:
|
||||
lines.append("from letta_client import Letta")
|
||||
|
||||
if schema_imports:
|
||||
lines.append(schema_imports.rstrip())
|
||||
|
||||
@@ -170,6 +189,34 @@ class AsyncToolSandboxBase(ABC):
|
||||
else:
|
||||
lines.append("agent_state = None")
|
||||
|
||||
# Initialize Letta client if needed
|
||||
if inject_letta_client:
|
||||
from letta.settings import settings
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"# Initialize Letta client for tool execution",
|
||||
"letta_client = Letta(",
|
||||
f" base_url={repr(settings.default_base_url)},",
|
||||
f" token={repr(settings.default_token)}",
|
||||
")",
|
||||
"# Compatibility shim for client.agents.get",
|
||||
"try:",
|
||||
" _agents = letta_client.agents",
|
||||
" if not hasattr(_agents, 'get') and hasattr(_agents, 'retrieve'):",
|
||||
" setattr(_agents, 'get', _agents.retrieve)",
|
||||
"except Exception:",
|
||||
" pass",
|
||||
]
|
||||
)
|
||||
|
||||
# Set agent_id if needed
|
||||
if inject_agent_id:
|
||||
if agent_id:
|
||||
lines.append(f"agent_id = {repr(agent_id)}")
|
||||
else:
|
||||
lines.append("agent_id = None")
|
||||
|
||||
if tool_args:
|
||||
lines.append(tool_args.rstrip())
|
||||
|
||||
@@ -286,9 +333,22 @@ class AsyncToolSandboxBase(ABC):
|
||||
kwargs.append(name)
|
||||
|
||||
param_list = [f"{arg}={arg}" for arg in kwargs]
|
||||
|
||||
# Add reserved keyword arguments
|
||||
if self.inject_agent_state:
|
||||
param_list.append("agent_state=agent_state")
|
||||
|
||||
if self.inject_letta_client:
|
||||
# Check if the function expects 'client' or 'letta_client'
|
||||
tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name)
|
||||
if "client" in tool_arguments:
|
||||
param_list.append("client=letta_client")
|
||||
elif "letta_client" in tool_arguments:
|
||||
param_list.append("letta_client=letta_client")
|
||||
|
||||
if self.inject_agent_id:
|
||||
param_list.append("agent_id=agent_id")
|
||||
|
||||
params = ", ".join(param_list)
|
||||
func_call_str = self.tool.name + "(" + params + ")"
|
||||
return func_call_str
|
||||
|
||||
@@ -329,6 +329,10 @@ class Settings(BaseSettings):
|
||||
file_processing_timeout_minutes: int = 30
|
||||
file_processing_timeout_error_message: str = "File processing timed out after {} minutes. Please try again."
|
||||
|
||||
# Letta client settings for tool execution
|
||||
default_base_url: str = Field(default="http://localhost:8283", description="Default base URL for Letta client in tool execution")
|
||||
default_token: Optional[str] = Field(default=None, description="Default token for Letta client in tool execution")
|
||||
|
||||
# enabling letta_agent_v1 architecture
|
||||
use_letta_v1_agent: bool = False
|
||||
|
||||
|
||||
@@ -12,16 +12,21 @@ from letta.functions.function_sets.base import core_memory_append, core_memory_r
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolType
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.user import User
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal
|
||||
from letta.services.user_manager import UserManager
|
||||
from tests.helpers.utils import create_tool_from_func
|
||||
|
||||
@@ -48,14 +53,14 @@ def server():
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_tables():
|
||||
async def clear_tables():
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
from letta.server.db import db_context
|
||||
from letta.server.db import db_registry
|
||||
|
||||
with db_context() as session:
|
||||
session.execute(delete(SandboxEnvironmentVariable))
|
||||
session.execute(delete(SandboxConfig))
|
||||
session.commit() # Commit the deletion
|
||||
async with db_registry.async_session() as session:
|
||||
await session.execute(delete(SandboxEnvironmentVariable))
|
||||
await session.execute(delete(SandboxConfig))
|
||||
await session.commit() # Commit the deletion
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -232,7 +237,7 @@ def agent_state(server):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_test_sandbox_config(test_user):
|
||||
async def custom_test_sandbox_config(test_user):
|
||||
"""
|
||||
Fixture to create a consistent local sandbox configuration for tests.
|
||||
|
||||
@@ -314,7 +319,8 @@ def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user):
|
||||
|
||||
|
||||
@pytest.mark.local_sandbox
|
||||
def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user):
|
||||
manager = SandboxConfigManager()
|
||||
|
||||
# Make a custom local sandbox config
|
||||
@@ -394,6 +400,136 @@ def test_local_sandbox_with_venv_and_warnings_does_not_error(disable_e2b_api_key
|
||||
assert result.func_return == "Hello World"
|
||||
|
||||
|
||||
@pytest.mark.local_sandbox
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_with_client_injection(disable_e2b_api_key, server: SyncServer, test_user):
|
||||
"""Test that tools can access injected letta_client and agent_id to modify agent blocks."""
|
||||
|
||||
# Create a tool that uses the injected client and agent_id to actually clear a memory block
|
||||
memory_clear_source = '''
|
||||
def memory_clear(label: str, agent_id: str, client: "Letta"):
|
||||
"""Test tool that clears a memory block using the injected client.
|
||||
|
||||
Args:
|
||||
label: The label of the memory block to clear
|
||||
agent_id: The agent's ID (injected by Letta system)
|
||||
client: The Letta client instance (injected by Letta system)
|
||||
"""
|
||||
# Verify that agent_id was injected
|
||||
if not agent_id or not isinstance(agent_id, str):
|
||||
return f"ERROR: agent_id not properly injected: {agent_id}"
|
||||
|
||||
# Verify that client was injected
|
||||
if not client or not hasattr(client, 'agents'):
|
||||
return f"ERROR: client not properly injected: {client}"
|
||||
|
||||
# Use the injected client to actually clear the memory block
|
||||
try:
|
||||
# Get the agent using the injected client
|
||||
agent = client.agents.get(agent_id=agent_id)
|
||||
|
||||
# Find the block with the specified label
|
||||
blocks = agent.memory.blocks
|
||||
target_block = None
|
||||
for block in blocks:
|
||||
if block.label == label:
|
||||
target_block = block
|
||||
break
|
||||
|
||||
if not target_block:
|
||||
return f"ERROR: Block with label '{label}' not found"
|
||||
|
||||
# Clear the block by setting its value to empty string
|
||||
original_value = target_block.value
|
||||
client.agents.update_block(
|
||||
agent_id=agent_id,
|
||||
block_id=target_block.id,
|
||||
value=""
|
||||
)
|
||||
|
||||
return f"SUCCESS: Cleared block '{label}' (was {len(original_value)} chars, now empty)"
|
||||
except Exception as e:
|
||||
return f"ERROR: Failed to clear block: {str(e)}"
|
||||
'''
|
||||
|
||||
# Create the tool
|
||||
memory_clear_tool = PydanticTool(
|
||||
name="memory_clear",
|
||||
description="Clear a memory block by setting its value to empty string",
|
||||
source_code=memory_clear_source,
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
|
||||
# Manually provide schema since client is an injected parameter
|
||||
memory_clear_tool.json_schema = {
|
||||
"name": "memory_clear",
|
||||
"description": "Clear a memory block by setting its value to empty string",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {"type": "string", "description": "The label of the memory block to clear"}
|
||||
# agent_id and client are injected, not passed by the user
|
||||
},
|
||||
"required": ["label"],
|
||||
},
|
||||
}
|
||||
|
||||
# Create the tool in the system
|
||||
created_tool = await server.tool_manager.create_tool_async(memory_clear_tool, actor=test_user)
|
||||
|
||||
# Create an agent with a memory block
|
||||
agent = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="test_agent_with_blocks",
|
||||
memory_blocks=[{"label": "test_block", "value": "Initial test content that should be cleared"}],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tools=["memory_clear"],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=test_user,
|
||||
)
|
||||
|
||||
# Verify the tool is attached
|
||||
assert created_tool.id in [t.id for t in agent.tools]
|
||||
|
||||
# Simulate tool execution with the reserved keywords
|
||||
# This would normally happen during agent execution, but we'll test the tool directly
|
||||
# Create the sandbox for the tool
|
||||
sandbox = AsyncToolSandboxLocal(tool_name="memory_clear", args={"label": "test_block"}, user=test_user, tool_object=created_tool)
|
||||
|
||||
# Initialize the sandbox to detect reserved keywords
|
||||
await sandbox._init_async()
|
||||
|
||||
# Verify that the tool correctly detects the need for injection
|
||||
assert sandbox.inject_letta_client == True # Should detect 'client' parameter
|
||||
assert sandbox.inject_agent_id == True # Should detect 'agent_id' parameter
|
||||
|
||||
# Generate the execution script to verify injection code is present
|
||||
script = await sandbox.generate_execution_script(agent_state=agent)
|
||||
|
||||
# Verify the script contains Letta client initialization
|
||||
assert "from letta import Letta" in script or "import letta" in script.lower()
|
||||
assert "agent_id =" in script
|
||||
|
||||
# Actually execute the tool using the sandbox
|
||||
result = await sandbox.run(agent_state=agent)
|
||||
|
||||
# Verify execution was successful
|
||||
assert result.status == "success", f"Tool execution failed: {result.stderr}"
|
||||
assert "SUCCESS:" in result.func_return, f"Tool didn't execute successfully: {result.func_return}"
|
||||
assert "Cleared block 'test_block'" in result.func_return, f"Block not cleared: {result.func_return}"
|
||||
assert "was 44 chars" in result.func_return, f"Original length not reported correctly: {result.func_return}"
|
||||
|
||||
# check the block status after the tool execution
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent.id, actor=test_user)
|
||||
assert agent_state.memory.get_block("test_block").value == ""
|
||||
|
||||
# Clean up
|
||||
await server.agent_manager.delete_agent_async(agent_id=agent.id, actor=test_user)
|
||||
|
||||
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user):
|
||||
sandbox = ToolExecutionSandbox(always_err_tool.name, {}, user=test_user)
|
||||
|
||||
Reference in New Issue
Block a user