feat: support client injection for E2B (#6360)
This commit is contained in:
committed by
Caren Thomas
parent
1d55a0f4c5
commit
f7ade17c4a
@@ -191,22 +191,26 @@ class AsyncToolSandboxBase(ABC):
|
|||||||
|
|
||||||
# Initialize Letta client if needed
|
# Initialize Letta client if needed
|
||||||
if inject_letta_client:
|
if inject_letta_client:
|
||||||
from letta.settings import settings
|
|
||||||
|
|
||||||
lines.extend(
|
lines.extend(
|
||||||
[
|
[
|
||||||
"# Initialize Letta client for tool execution",
|
"# Initialize Letta client for tool execution",
|
||||||
"letta_client = Letta(",
|
"import os",
|
||||||
f" base_url={repr(settings.default_base_url)},",
|
"letta_client = None",
|
||||||
f" token={repr(settings.default_token)}",
|
"if os.getenv('LETTA_API_KEY'):",
|
||||||
")",
|
" # Check letta_client version to use correct parameter name",
|
||||||
"# Compatibility shim for client.agents.get",
|
" from packaging import version as pkg_version",
|
||||||
"try:",
|
" import letta_client as lc_module",
|
||||||
" _agents = letta_client.agents",
|
" lc_version = pkg_version.parse(lc_module.__version__)",
|
||||||
" if not hasattr(_agents, 'get') and hasattr(_agents, 'retrieve'):",
|
" if lc_version < pkg_version.parse('1.0.0'):",
|
||||||
" setattr(_agents, 'get', _agents.retrieve)",
|
" letta_client = Letta(",
|
||||||
"except Exception:",
|
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
|
||||||
" pass",
|
" token=os.getenv('LETTA_API_KEY')",
|
||||||
|
" )",
|
||||||
|
" else:",
|
||||||
|
" letta_client = Letta(",
|
||||||
|
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
|
||||||
|
" api_key=os.getenv('LETTA_API_KEY')",
|
||||||
|
" )",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -394,4 +398,7 @@ class AsyncToolSandboxBase(ABC):
|
|||||||
if additional_env_vars:
|
if additional_env_vars:
|
||||||
env.update(additional_env_vars)
|
env.update(additional_env_vars)
|
||||||
|
|
||||||
|
# Filter out None values to prevent subprocess errors
|
||||||
|
env = {k: v for k, v in env.items() if v is not None}
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -80,6 +80,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
|||||||
if additional_env_vars:
|
if additional_env_vars:
|
||||||
env.update(additional_env_vars)
|
env.update(additional_env_vars)
|
||||||
|
|
||||||
|
# Filter out None values to prevent subprocess errors
|
||||||
|
env = {k: v for k, v in env.items() if v is not None}
|
||||||
|
|
||||||
# Make sure sandbox directory exists
|
# Make sure sandbox directory exists
|
||||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
|
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
|
||||||
if not await asyncio.to_thread(lambda: os.path.exists(sandbox_dir) and os.path.isdir(sandbox_dir)):
|
if not await asyncio.to_thread(lambda: os.path.exists(sandbox_dir) and os.path.isdir(sandbox_dir)):
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from letta_client import Letta
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
|
|
||||||
from letta.config import LettaConfig
|
from letta.config import LettaConfig
|
||||||
@@ -19,7 +24,6 @@ from letta.schemas.pip_requirement import PipRequirement
|
|||||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate
|
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.server.server import SyncServer
|
|
||||||
from letta.services.organization_manager import OrganizationManager
|
from letta.services.organization_manager import OrganizationManager
|
||||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||||
from letta.services.tool_manager import ToolManager
|
from letta.services.tool_manager import ToolManager
|
||||||
@@ -63,19 +67,49 @@ def disable_db_pooling_for_tests():
|
|||||||
|
|
||||||
# Fixtures
|
# Fixtures
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server_url() -> str:
|
||||||
"""
|
"""
|
||||||
Creates a SyncServer instance for testing.
|
Provides the URL for the Letta server.
|
||||||
|
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||||
Loads and saves config to ensure proper initialization.
|
and polls until it's accepting connections.
|
||||||
"""
|
"""
|
||||||
config = LettaConfig.load()
|
|
||||||
|
|
||||||
config.save()
|
def _run_server() -> None:
|
||||||
|
load_dotenv()
|
||||||
|
from letta.server.rest_api.app import start_server
|
||||||
|
|
||||||
server = SyncServer(init_with_default_org_and_user=True)
|
start_server(debug=True)
|
||||||
# create user/org
|
|
||||||
yield server
|
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||||
|
|
||||||
|
if not os.getenv("LETTA_SERVER_URL"):
|
||||||
|
thread = threading.Thread(target=_run_server, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# Poll until the server is up (or timeout)
|
||||||
|
timeout_seconds = 30
|
||||||
|
deadline = time.time() + timeout_seconds
|
||||||
|
while time.time() < deadline:
|
||||||
|
try:
|
||||||
|
resp = requests.get(url + "/v1/health")
|
||||||
|
if resp.status_code < 500:
|
||||||
|
break
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
pass
|
||||||
|
time.sleep(0.1)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||||
|
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def client(server_url: str) -> Letta:
|
||||||
|
"""
|
||||||
|
Creates and returns a synchronous Letta REST client for testing.
|
||||||
|
"""
|
||||||
|
client_instance = Letta(base_url=server_url)
|
||||||
|
yield client_instance
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -237,10 +271,22 @@ async def external_codebase_tool(test_user):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def agent_state(server: SyncServer):
|
async def agent_state(server_url: str):
|
||||||
|
"""
|
||||||
|
Creates and returns an agent state for testing with a pre-configured agent.
|
||||||
|
|
||||||
|
Note: This fixture uses the server's internal async API instead of the client API
|
||||||
|
because the sandbox tests need the full server-side AgentState object with all
|
||||||
|
its methods (like get_agent_env_vars_as_dict()), not the simplified DTO returned
|
||||||
|
by the REST API.
|
||||||
|
"""
|
||||||
|
from letta.server.server import SyncServer
|
||||||
|
|
||||||
|
# Import here to ensure server is running first
|
||||||
|
server = SyncServer()
|
||||||
await server.init_async(init_with_default_org_and_user=True)
|
await server.init_async(init_with_default_org_and_user=True)
|
||||||
actor = await server.user_manager.create_default_actor_async()
|
actor = await server.user_manager.create_default_actor_async()
|
||||||
agent_state = await server.create_agent_async(
|
agent_state_instance = await server.create_agent_async(
|
||||||
CreateAgent(
|
CreateAgent(
|
||||||
memory_blocks=[
|
memory_blocks=[
|
||||||
CreateBlock(
|
CreateBlock(
|
||||||
@@ -259,8 +305,7 @@ async def agent_state(server: SyncServer):
|
|||||||
),
|
),
|
||||||
actor=actor,
|
actor=actor,
|
||||||
)
|
)
|
||||||
agent_state.tool_rules = []
|
yield agent_state_instance
|
||||||
yield agent_state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -1141,3 +1186,165 @@ async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_e
|
|||||||
result = await sandbox.run(agent_state=agent_state)
|
result = await sandbox.run(agent_state=agent_state)
|
||||||
assert wrong_val not in result.func_return
|
assert wrong_val not in result.func_return
|
||||||
assert correct_val in result.func_return
|
assert correct_val in result.func_return
|
||||||
|
|
||||||
|
|
||||||
|
# Client injection tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def list_tools_with_client_tool(test_user):
|
||||||
|
"""Tool that uses injected client to list tools.
|
||||||
|
|
||||||
|
Note: This fixture uses ToolManager directly instead of the client API
|
||||||
|
because it needs to create a tool with a custom schema that excludes
|
||||||
|
the 'client' parameter (which is injected by the sandbox, not passed by the LLM).
|
||||||
|
"""
|
||||||
|
from letta.schemas.enums import ToolType
|
||||||
|
from letta.schemas.tool import Tool as PydanticTool
|
||||||
|
|
||||||
|
source_code = '''
|
||||||
|
def list_tools_via_client(client: "Letta") -> str:
|
||||||
|
"""
|
||||||
|
List available tools using the injected Letta client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: Letta client instance (injected by system)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Comma-separated list of tool names
|
||||||
|
"""
|
||||||
|
if not client:
|
||||||
|
return "ERROR: client not injected"
|
||||||
|
|
||||||
|
try:
|
||||||
|
tools = client.tools.list()
|
||||||
|
tool_names = [tool.name for tool in tools]
|
||||||
|
return f"Found {len(tool_names)} tools: {', '.join(tool_names)}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"ERROR: {str(e)}"
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Create the tool with proper schema (client is injected, not in schema)
|
||||||
|
tool = PydanticTool(
|
||||||
|
name="list_tools_via_client",
|
||||||
|
description="List tools using injected client",
|
||||||
|
source_code=source_code,
|
||||||
|
source_type="python",
|
||||||
|
tool_type=ToolType.CUSTOM,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manually set schema without 'client' parameter since it's injected
|
||||||
|
tool.json_schema = {
|
||||||
|
"name": "list_tools_via_client",
|
||||||
|
"description": "List tools using injected client",
|
||||||
|
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use ToolManager directly for this special case
|
||||||
|
created_tool = await ToolManager().create_or_update_tool_async(tool, test_user)
|
||||||
|
yield created_tool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.local_sandbox
|
||||||
|
async def test_local_sandbox_with_client_injection(disable_e2b_api_key, list_tools_with_client_tool, test_user, server_url):
|
||||||
|
"""Test that local sandbox can inject Letta client for tools that need it."""
|
||||||
|
|
||||||
|
# Add LETTA_API_KEY to sandbox environment
|
||||||
|
api_key = os.getenv("LETTA_API_KEY") or "test-key"
|
||||||
|
base_url = server_url # Use the server_url fixture
|
||||||
|
|
||||||
|
# Pass environment variables directly to avoid encryption issues
|
||||||
|
sandbox_env_vars = {
|
||||||
|
"LETTA_API_KEY": api_key,
|
||||||
|
"LETTA_BASE_URL": base_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create the sandbox and verify client injection is detected
|
||||||
|
sandbox = AsyncToolSandboxLocal(
|
||||||
|
tool_name=list_tools_with_client_tool.name,
|
||||||
|
args={},
|
||||||
|
user=test_user,
|
||||||
|
tool_object=list_tools_with_client_tool,
|
||||||
|
sandbox_env_vars=sandbox_env_vars,
|
||||||
|
)
|
||||||
|
|
||||||
|
await sandbox._init_async()
|
||||||
|
|
||||||
|
# Verify that client injection was detected
|
||||||
|
assert sandbox.inject_letta_client is True, "Tool should be detected as needing client injection"
|
||||||
|
|
||||||
|
# Generate the execution script to verify client initialization code is present
|
||||||
|
script = await sandbox.generate_execution_script(agent_state=None)
|
||||||
|
|
||||||
|
# Debug: print the script
|
||||||
|
print("=" * 80)
|
||||||
|
print("GENERATED SCRIPT:")
|
||||||
|
print("=" * 80)
|
||||||
|
print(script)
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Verify the script contains Letta client initialization
|
||||||
|
assert "from letta_client import Letta" in script, "Script should import Letta client"
|
||||||
|
assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY"
|
||||||
|
assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client"
|
||||||
|
|
||||||
|
# Run the tool and verify it works
|
||||||
|
result = await sandbox.run(agent_state=None)
|
||||||
|
|
||||||
|
# The result should either list tools or indicate client wasn't available
|
||||||
|
assert result.status == "success" or "ERROR" in str(result.func_return), f"Tool execution failed: {result.stderr}"
|
||||||
|
|
||||||
|
print("RESULT --------------------------------")
|
||||||
|
print(result)
|
||||||
|
assert "Found" in str(result.func_return), f"Tool should list tools when client is available: {result.func_return}"
|
||||||
|
|
||||||
|
# Verify client was injected successfully (connection may fail if no server is running)
|
||||||
|
assert "ERROR: client not injected" not in str(result.func_return), "Client should be injected when LETTA_API_KEY is set"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.e2b_sandbox
|
||||||
|
async def test_e2b_sandbox_with_client_injection(check_e2b_key_is_set, list_tools_with_client_tool, test_user, server_url):
|
||||||
|
"""Test that E2B sandbox can inject Letta client for tools that need it."""
|
||||||
|
|
||||||
|
# Add LETTA_API_KEY to sandbox environment
|
||||||
|
api_key = os.getenv("LETTA_API_KEY") or "test-key"
|
||||||
|
base_url = server_url # Use the server_url fixture
|
||||||
|
|
||||||
|
# Pass environment variables directly to avoid encryption issues
|
||||||
|
sandbox_env_vars = {
|
||||||
|
"LETTA_API_KEY": api_key,
|
||||||
|
"LETTA_BASE_URL": base_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create the sandbox and verify client injection is detected
|
||||||
|
sandbox = AsyncToolSandboxE2B(
|
||||||
|
tool_name=list_tools_with_client_tool.name,
|
||||||
|
args={},
|
||||||
|
user=test_user,
|
||||||
|
tool_object=list_tools_with_client_tool,
|
||||||
|
sandbox_env_vars=sandbox_env_vars,
|
||||||
|
)
|
||||||
|
|
||||||
|
await sandbox._init_async()
|
||||||
|
|
||||||
|
# Verify that client injection was detected
|
||||||
|
assert sandbox.inject_letta_client is True, "Tool should be detected as needing client injection"
|
||||||
|
|
||||||
|
# Generate the execution script to verify client initialization code is present
|
||||||
|
script = await sandbox.generate_execution_script(agent_state=None)
|
||||||
|
|
||||||
|
# Debug: print the script
|
||||||
|
print("=" * 80)
|
||||||
|
print("GENERATED SCRIPT:")
|
||||||
|
print("=" * 80)
|
||||||
|
print(script)
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Verify the script contains Letta client initialization
|
||||||
|
assert "from letta_client import Letta" in script, "Script should import Letta client"
|
||||||
|
assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY"
|
||||||
|
assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client"
|
||||||
|
|
||||||
|
# Cannot run the tool since E2B is remote
|
||||||
|
|||||||
Reference in New Issue
Block a user