feat: support client injection for E2B (#6360)

This commit is contained in:
Sarah Wooders
2025-11-25 10:58:17 -08:00
committed by Caren Thomas
parent 1d55a0f4c5
commit f7ade17c4a
3 changed files with 244 additions and 27 deletions

View File

@@ -191,22 +191,26 @@ class AsyncToolSandboxBase(ABC):
# Initialize Letta client if needed
if inject_letta_client:
from letta.settings import settings
lines.extend(
[
"# Initialize Letta client for tool execution",
"letta_client = Letta(",
f" base_url={repr(settings.default_base_url)},",
f" token={repr(settings.default_token)}",
")",
"# Compatibility shim for client.agents.get",
"try:",
" _agents = letta_client.agents",
" if not hasattr(_agents, 'get') and hasattr(_agents, 'retrieve'):",
" setattr(_agents, 'get', _agents.retrieve)",
"except Exception:",
" pass",
"import os",
"letta_client = None",
"if os.getenv('LETTA_API_KEY'):",
" # Check letta_client version to use correct parameter name",
" from packaging import version as pkg_version",
" import letta_client as lc_module",
" lc_version = pkg_version.parse(lc_module.__version__)",
" if lc_version < pkg_version.parse('1.0.0'):",
" letta_client = Letta(",
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
" token=os.getenv('LETTA_API_KEY')",
" )",
" else:",
" letta_client = Letta(",
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
" api_key=os.getenv('LETTA_API_KEY')",
" )",
]
)
@@ -394,4 +398,7 @@ class AsyncToolSandboxBase(ABC):
if additional_env_vars:
env.update(additional_env_vars)
# Filter out None values to prevent subprocess errors
env = {k: v for k, v in env.items() if v is not None}
return env

View File

@@ -80,6 +80,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
if additional_env_vars:
env.update(additional_env_vars)
# Filter out None values to prevent subprocess errors
env = {k: v for k, v in env.items() if v is not None}
# Make sure sandbox directory exists
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
if not await asyncio.to_thread(lambda: os.path.exists(sandbox_dir) and os.path.isdir(sandbox_dir)):

View File

@@ -1,11 +1,16 @@
import os
import secrets
import string
import threading
import time
import uuid
from pathlib import Path
from unittest.mock import patch
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
from sqlalchemy import delete
from letta.config import LettaConfig
@@ -19,7 +24,6 @@ from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate
from letta.schemas.user import User
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
@@ -63,19 +67,49 @@ def disable_db_pooling_for_tests():
# Fixtures
@pytest.fixture(scope="module")
def server():
def server_url() -> str:
"""
Creates a SyncServer instance for testing.
Loads and saves config to ensure proper initialization.
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
config = LettaConfig.load()
config.save()
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
server = SyncServer(init_with_default_org_and_user=True)
# create user/org
yield server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """
    Module-scoped synchronous Letta REST client for the tests.

    Args:
        server_url: Base URL provided by the ``server_url`` fixture.

    Yields:
        A ``Letta`` client constructed with ``base_url=server_url``.
    """
    yield Letta(base_url=server_url)
@pytest.fixture(autouse=True)
@@ -237,10 +271,22 @@ async def external_codebase_tool(test_user):
@pytest.fixture
async def agent_state(server: SyncServer):
async def agent_state(server_url: str):
"""
Creates and returns an agent state for testing with a pre-configured agent.
Note: This fixture uses the server's internal async API instead of the client API
because the sandbox tests need the full server-side AgentState object with all
its methods (like get_agent_env_vars_as_dict()), not the simplified DTO returned
by the REST API.
"""
from letta.server.server import SyncServer
# Import here to ensure server is running first
server = SyncServer()
await server.init_async(init_with_default_org_and_user=True)
actor = await server.user_manager.create_default_actor_async()
agent_state = await server.create_agent_async(
agent_state_instance = await server.create_agent_async(
CreateAgent(
memory_blocks=[
CreateBlock(
@@ -259,8 +305,7 @@ async def agent_state(server: SyncServer):
),
actor=actor,
)
agent_state.tool_rules = []
yield agent_state
yield agent_state_instance
@pytest.fixture
@@ -1141,3 +1186,165 @@ async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_e
result = await sandbox.run(agent_state=agent_state)
assert wrong_val not in result.func_return
assert correct_val in result.func_return
# Client injection tests
@pytest.fixture
async def list_tools_with_client_tool(test_user):
    """Persist and yield a custom tool whose function takes an injected ``client``.

    The tool is stored through ``ToolManager`` rather than the client API
    because its JSON schema is overridden by hand so that it advertises no
    parameters: ``client`` is supplied by the sandbox, not by the LLM.

    Args:
        test_user: Actor passed to ``ToolManager.create_or_update_tool_async``.

    Yields:
        The tool object returned by ``create_or_update_tool_async``.
    """
    from letta.schemas.enums import ToolType
    from letta.schemas.tool import Tool as PydanticTool

    # Source of the tool as it will be executed inside the sandbox.
    tool_source = '''
def list_tools_via_client(client: "Letta") -> str:
    """
    List available tools using the injected Letta client.
    Args:
        client: Letta client instance (injected by system)
    Returns:
        str: Comma-separated list of tool names
    """
    if not client:
        return "ERROR: client not injected"
    try:
        tools = client.tools.list()
        tool_names = [tool.name for tool in tools]
        return f"Found {len(tool_names)} tools: {', '.join(tool_names)}"
    except Exception as e:
        return f"ERROR: {str(e)}"
'''

    # Schema with an empty parameter list: 'client' is deliberately absent
    # because it is injected rather than passed as an argument.
    schema_without_client = {
        "name": "list_tools_via_client",
        "description": "List tools using injected client",
        "parameters": {"type": "object", "properties": {}, "required": []},
    }

    pending_tool = PydanticTool(
        name="list_tools_via_client",
        description="List tools using injected client",
        source_code=tool_source,
        source_type="python",
        tool_type=ToolType.CUSTOM,
    )
    # Assigned after construction, exactly as before, so the hand-written
    # schema replaces whatever the model held after initialisation.
    pending_tool.json_schema = schema_without_client

    yield await ToolManager().create_or_update_tool_async(pending_tool, test_user)
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_with_client_injection(disable_e2b_api_key, list_tools_with_client_tool, test_user, server_url):
    """Check that the local sandbox injects a Letta client into a tool that asks for one.

    The test inspects the generated execution script for the client bootstrap
    code, then runs the tool and requires that it actually listed tools.
    """
    # Environment handed straight to the sandbox; a placeholder key is used
    # when LETTA_API_KEY is unset or empty in the test process.
    injected_env = {
        "LETTA_API_KEY": os.getenv("LETTA_API_KEY") or "test-key",
        "LETTA_BASE_URL": server_url,
    }

    local_sandbox = AsyncToolSandboxLocal(
        tool_name=list_tools_with_client_tool.name,
        args={},
        user=test_user,
        tool_object=list_tools_with_client_tool,
        sandbox_env_vars=injected_env,
    )
    await local_sandbox._init_async()

    assert local_sandbox.inject_letta_client is True, "Tool should be detected as needing client injection"

    generated_script = await local_sandbox.generate_execution_script(agent_state=None)

    # Dump the script so a failing run shows exactly what was generated.
    rule = "=" * 80
    for chunk in (rule, "GENERATED SCRIPT:", rule, generated_script, rule):
        print(chunk)

    assert "from letta_client import Letta" in generated_script, "Script should import Letta client"
    assert "LETTA_API_KEY" in generated_script, "Script should check for LETTA_API_KEY"
    assert any(
        marker in generated_script for marker in ("letta_client = Letta(", "letta_client = None")
    ), "Script should initialize Letta client"

    # Execute the tool inside the sandbox.
    run_result = await local_sandbox.run(agent_state=None)

    # Coarse check first: either the run succeeded or the tool reported an error string.
    run_ok = run_result.status == "success"
    assert run_ok or "ERROR" in str(run_result.func_return), f"Tool execution failed: {run_result.stderr}"

    print("RESULT --------------------------------")
    print(run_result)

    # Strict checks: the tool must have listed tools, and must not have
    # reported a missing client.
    returned_text = str(run_result.func_return)
    assert "Found" in returned_text, f"Tool should list tools when client is available: {run_result.func_return}"
    assert "ERROR: client not injected" not in returned_text, "Client should be injected when LETTA_API_KEY is set"
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_with_client_injection(check_e2b_key_is_set, list_tools_with_client_tool, test_user, server_url):
    """Check that the E2B sandbox generates Letta client bootstrap code for a tool that asks for one.

    Only the generated execution script is inspected; the tool itself is not
    run here.
    """
    # Environment handed straight to the sandbox; a placeholder key is used
    # when LETTA_API_KEY is unset or empty in the test process.
    injected_env = {
        "LETTA_API_KEY": os.getenv("LETTA_API_KEY") or "test-key",
        "LETTA_BASE_URL": server_url,
    }

    e2b_sandbox = AsyncToolSandboxE2B(
        tool_name=list_tools_with_client_tool.name,
        args={},
        user=test_user,
        tool_object=list_tools_with_client_tool,
        sandbox_env_vars=injected_env,
    )
    await e2b_sandbox._init_async()

    assert e2b_sandbox.inject_letta_client is True, "Tool should be detected as needing client injection"

    generated_script = await e2b_sandbox.generate_execution_script(agent_state=None)

    # Dump the script so a failing run shows exactly what was generated.
    rule = "=" * 80
    for chunk in (rule, "GENERATED SCRIPT:", rule, generated_script, rule):
        print(chunk)

    assert "from letta_client import Letta" in generated_script, "Script should import Letta client"
    assert "LETTA_API_KEY" in generated_script, "Script should check for LETTA_API_KEY"
    assert any(
        marker in generated_script for marker in ("letta_client = Letta(", "letta_client = None")
    ), "Script should initialize Letta client"
    # Cannot run the tool since E2B is remote