feat: support client injection for E2B (#6360)
This commit is contained in:
committed by
Caren Thomas
parent
1d55a0f4c5
commit
f7ade17c4a
@@ -191,22 +191,26 @@ class AsyncToolSandboxBase(ABC):
|
||||
|
||||
# Initialize Letta client if needed
|
||||
if inject_letta_client:
|
||||
from letta.settings import settings
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"# Initialize Letta client for tool execution",
|
||||
"letta_client = Letta(",
|
||||
f" base_url={repr(settings.default_base_url)},",
|
||||
f" token={repr(settings.default_token)}",
|
||||
")",
|
||||
"# Compatibility shim for client.agents.get",
|
||||
"try:",
|
||||
" _agents = letta_client.agents",
|
||||
" if not hasattr(_agents, 'get') and hasattr(_agents, 'retrieve'):",
|
||||
" setattr(_agents, 'get', _agents.retrieve)",
|
||||
"except Exception:",
|
||||
" pass",
|
||||
"import os",
|
||||
"letta_client = None",
|
||||
"if os.getenv('LETTA_API_KEY'):",
|
||||
" # Check letta_client version to use correct parameter name",
|
||||
" from packaging import version as pkg_version",
|
||||
" import letta_client as lc_module",
|
||||
" lc_version = pkg_version.parse(lc_module.__version__)",
|
||||
" if lc_version < pkg_version.parse('1.0.0'):",
|
||||
" letta_client = Letta(",
|
||||
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
|
||||
" token=os.getenv('LETTA_API_KEY')",
|
||||
" )",
|
||||
" else:",
|
||||
" letta_client = Letta(",
|
||||
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
|
||||
" api_key=os.getenv('LETTA_API_KEY')",
|
||||
" )",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -394,4 +398,7 @@ class AsyncToolSandboxBase(ABC):
|
||||
if additional_env_vars:
|
||||
env.update(additional_env_vars)
|
||||
|
||||
# Filter out None values to prevent subprocess errors
|
||||
env = {k: v for k, v in env.items() if v is not None}
|
||||
|
||||
return env
|
||||
|
||||
@@ -80,6 +80,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
if additional_env_vars:
|
||||
env.update(additional_env_vars)
|
||||
|
||||
# Filter out None values to prevent subprocess errors
|
||||
env = {k: v for k, v in env.items() if v is not None}
|
||||
|
||||
# Make sure sandbox directory exists
|
||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
|
||||
if not await asyncio.to_thread(lambda: os.path.exists(sandbox_dir) and os.path.isdir(sandbox_dir)):
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.config import LettaConfig
|
||||
@@ -19,7 +24,6 @@ from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.db import db_registry
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
@@ -63,19 +67,49 @@ def disable_db_pooling_for_tests():
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Creates a SyncServer instance for testing.
|
||||
|
||||
Loads and saves config to ensure proper initialization.
|
||||
Provides the URL for the Letta server.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it's accepting connections.
|
||||
"""
|
||||
config = LettaConfig.load()
|
||||
|
||||
config.save()
|
||||
def _run_server() -> None:
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
server = SyncServer(init_with_default_org_and_user=True)
|
||||
# create user/org
|
||||
yield server
|
||||
start_server(debug=True)
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
"""
|
||||
client_instance = Letta(base_url=server_url)
|
||||
yield client_instance
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -237,10 +271,22 @@ async def external_codebase_tool(test_user):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def agent_state(server: SyncServer):
|
||||
async def agent_state(server_url: str):
|
||||
"""
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
|
||||
Note: This fixture uses the server's internal async API instead of the client API
|
||||
because the sandbox tests need the full server-side AgentState object with all
|
||||
its methods (like get_agent_env_vars_as_dict()), not the simplified DTO returned
|
||||
by the REST API.
|
||||
"""
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
# Import here to ensure server is running first
|
||||
server = SyncServer()
|
||||
await server.init_async(init_with_default_org_and_user=True)
|
||||
actor = await server.user_manager.create_default_actor_async()
|
||||
agent_state = await server.create_agent_async(
|
||||
agent_state_instance = await server.create_agent_async(
|
||||
CreateAgent(
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
@@ -259,8 +305,7 @@ async def agent_state(server: SyncServer):
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
agent_state.tool_rules = []
|
||||
yield agent_state
|
||||
yield agent_state_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1141,3 +1186,165 @@ async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_e
|
||||
result = await sandbox.run(agent_state=agent_state)
|
||||
assert wrong_val not in result.func_return
|
||||
assert correct_val in result.func_return
|
||||
|
||||
|
||||
# Client injection tests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def list_tools_with_client_tool(test_user):
|
||||
"""Tool that uses injected client to list tools.
|
||||
|
||||
Note: This fixture uses ToolManager directly instead of the client API
|
||||
because it needs to create a tool with a custom schema that excludes
|
||||
the 'client' parameter (which is injected by the sandbox, not passed by the LLM).
|
||||
"""
|
||||
from letta.schemas.enums import ToolType
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
|
||||
source_code = '''
|
||||
def list_tools_via_client(client: "Letta") -> str:
|
||||
"""
|
||||
List available tools using the injected Letta client.
|
||||
|
||||
Args:
|
||||
client: Letta client instance (injected by system)
|
||||
|
||||
Returns:
|
||||
str: Comma-separated list of tool names
|
||||
"""
|
||||
if not client:
|
||||
return "ERROR: client not injected"
|
||||
|
||||
try:
|
||||
tools = client.tools.list()
|
||||
tool_names = [tool.name for tool in tools]
|
||||
return f"Found {len(tool_names)} tools: {', '.join(tool_names)}"
|
||||
except Exception as e:
|
||||
return f"ERROR: {str(e)}"
|
||||
'''
|
||||
|
||||
# Create the tool with proper schema (client is injected, not in schema)
|
||||
tool = PydanticTool(
|
||||
name="list_tools_via_client",
|
||||
description="List tools using injected client",
|
||||
source_code=source_code,
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
|
||||
# Manually set schema without 'client' parameter since it's injected
|
||||
tool.json_schema = {
|
||||
"name": "list_tools_via_client",
|
||||
"description": "List tools using injected client",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
|
||||
# Use ToolManager directly for this special case
|
||||
created_tool = await ToolManager().create_or_update_tool_async(tool, test_user)
|
||||
yield created_tool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.local_sandbox
|
||||
async def test_local_sandbox_with_client_injection(disable_e2b_api_key, list_tools_with_client_tool, test_user, server_url):
|
||||
"""Test that local sandbox can inject Letta client for tools that need it."""
|
||||
|
||||
# Add LETTA_API_KEY to sandbox environment
|
||||
api_key = os.getenv("LETTA_API_KEY") or "test-key"
|
||||
base_url = server_url # Use the server_url fixture
|
||||
|
||||
# Pass environment variables directly to avoid encryption issues
|
||||
sandbox_env_vars = {
|
||||
"LETTA_API_KEY": api_key,
|
||||
"LETTA_BASE_URL": base_url,
|
||||
}
|
||||
|
||||
# Create the sandbox and verify client injection is detected
|
||||
sandbox = AsyncToolSandboxLocal(
|
||||
tool_name=list_tools_with_client_tool.name,
|
||||
args={},
|
||||
user=test_user,
|
||||
tool_object=list_tools_with_client_tool,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
)
|
||||
|
||||
await sandbox._init_async()
|
||||
|
||||
# Verify that client injection was detected
|
||||
assert sandbox.inject_letta_client is True, "Tool should be detected as needing client injection"
|
||||
|
||||
# Generate the execution script to verify client initialization code is present
|
||||
script = await sandbox.generate_execution_script(agent_state=None)
|
||||
|
||||
# Debug: print the script
|
||||
print("=" * 80)
|
||||
print("GENERATED SCRIPT:")
|
||||
print("=" * 80)
|
||||
print(script)
|
||||
print("=" * 80)
|
||||
|
||||
# Verify the script contains Letta client initialization
|
||||
assert "from letta_client import Letta" in script, "Script should import Letta client"
|
||||
assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY"
|
||||
assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client"
|
||||
|
||||
# Run the tool and verify it works
|
||||
result = await sandbox.run(agent_state=None)
|
||||
|
||||
# The result should either list tools or indicate client wasn't available
|
||||
assert result.status == "success" or "ERROR" in str(result.func_return), f"Tool execution failed: {result.stderr}"
|
||||
|
||||
print("RESULT --------------------------------")
|
||||
print(result)
|
||||
assert "Found" in str(result.func_return), f"Tool should list tools when client is available: {result.func_return}"
|
||||
|
||||
# Verify client was injected successfully (connection may fail if no server is running)
|
||||
assert "ERROR: client not injected" not in str(result.func_return), "Client should be injected when LETTA_API_KEY is set"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.e2b_sandbox
|
||||
async def test_e2b_sandbox_with_client_injection(check_e2b_key_is_set, list_tools_with_client_tool, test_user, server_url):
|
||||
"""Test that E2B sandbox can inject Letta client for tools that need it."""
|
||||
|
||||
# Add LETTA_API_KEY to sandbox environment
|
||||
api_key = os.getenv("LETTA_API_KEY") or "test-key"
|
||||
base_url = server_url # Use the server_url fixture
|
||||
|
||||
# Pass environment variables directly to avoid encryption issues
|
||||
sandbox_env_vars = {
|
||||
"LETTA_API_KEY": api_key,
|
||||
"LETTA_BASE_URL": base_url,
|
||||
}
|
||||
|
||||
# Create the sandbox and verify client injection is detected
|
||||
sandbox = AsyncToolSandboxE2B(
|
||||
tool_name=list_tools_with_client_tool.name,
|
||||
args={},
|
||||
user=test_user,
|
||||
tool_object=list_tools_with_client_tool,
|
||||
sandbox_env_vars=sandbox_env_vars,
|
||||
)
|
||||
|
||||
await sandbox._init_async()
|
||||
|
||||
# Verify that client injection was detected
|
||||
assert sandbox.inject_letta_client is True, "Tool should be detected as needing client injection"
|
||||
|
||||
# Generate the execution script to verify client initialization code is present
|
||||
script = await sandbox.generate_execution_script(agent_state=None)
|
||||
|
||||
# Debug: print the script
|
||||
print("=" * 80)
|
||||
print("GENERATED SCRIPT:")
|
||||
print("=" * 80)
|
||||
print(script)
|
||||
print("=" * 80)
|
||||
|
||||
# Verify the script contains Letta client initialization
|
||||
assert "from letta_client import Letta" in script, "Script should import Letta client"
|
||||
assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY"
|
||||
assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client"
|
||||
|
||||
# Cannot run the tool since E2B is remote
|
||||
|
||||
Reference in New Issue
Block a user