feat: Add async support for local/e2b sandbox (#2981)

This commit is contained in:
Matthew Zhou
2025-06-23 19:47:19 -07:00
committed by GitHub
parent 54562d88d7
commit 4996447326
5 changed files with 496 additions and 2 deletions

View File

@@ -28,7 +28,8 @@ def parse_function_arguments(source_code: str, tool_name: str):
tree = ast.parse(source_code)
args = []
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name == tool_name:
# Handle both sync and async functions
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == tool_name:
for arg in node.args.args:
args.append(arg.arg)
return args

View File

@@ -52,6 +52,9 @@ class AsyncToolSandboxBase(ABC):
else:
self.inject_agent_state = False
# Detect if the tool function is async
self.is_async_function = self._detect_async_function()
# Lazily initialize the manager only when needed
@property
def sandbox_config_manager(self):
@@ -78,7 +81,8 @@ class AsyncToolSandboxBase(ABC):
"""
from letta.templates.template_helper import render_template
TEMPLATE_NAME = "sandbox_code_file.py.j2"
# Select the appropriate template based on whether the function is async
TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2"
future_import = False
schema_code = None
@@ -114,6 +118,7 @@ class AsyncToolSandboxBase(ABC):
invoke_function_call=self.invoke_function_call(),
wrap_print_with_markers=wrap_print_with_markers,
start_marker=self.LOCAL_SANDBOX_RESULT_START_MARKER,
use_top_level_await=self.use_top_level_await(),
)
def initialize_param(self, name: str, raw_value: JsonValue) -> str:
@@ -150,5 +155,38 @@ class AsyncToolSandboxBase(ABC):
func_call_str = self.tool.name + "(" + params + ")"
return func_call_str
def _detect_async_function(self) -> bool:
"""
Detect if the tool function is an async function by examining its source code.
Uses AST parsing to reliably detect 'async def' declarations.
"""
import ast
try:
# Parse the source code to AST
tree = ast.parse(self.tool.source_code)
# Look for function definitions
for node in ast.walk(tree):
if isinstance(node, ast.AsyncFunctionDef) and node.name == self.tool.name:
return True
elif isinstance(node, ast.FunctionDef) and node.name == self.tool.name:
return False
# If we couldn't find the function definition, fall back to string matching
return "async def " + self.tool.name in self.tool.source_code
except SyntaxError:
# If source code can't be parsed, fall back to string matching
return "async def " + self.tool.name in self.tool.source_code
def use_top_level_await(self) -> bool:
"""
Determine if this sandbox environment supports top-level await.
Should be overridden by subclasses to return True for environments
with running event loops (like E2B), False for local execution.
"""
return False # Default to False for local execution
def _update_env_vars(self):
pass # TODO

View File

@@ -250,6 +250,13 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
return sbx
def use_top_level_await(self) -> bool:
"""
E2B sandboxes run in a Jupyter-like environment with an active event loop,
so they support top-level await.
"""
return True
@staticmethod
async def list_running_e2b_sandboxes():
# List running sandboxes and access metadata.

View File

@@ -0,0 +1,59 @@
{{ 'from __future__ import annotations' if future_import else '' }}
from typing import *
import pickle
import sys
import base64
import struct
import hashlib
import asyncio
{# Additional imports to support agent state #}
{% if inject_agent_state %}
import letta
from letta import *
{% endif %}
{# Add schema code if available #}
{{ schema_imports or '' }}
{# Load agent state #}
agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickle else 'None' }}
{{ tool_args }}
{# The tool's source code #}
{{ tool_source_code }}
{# Async wrapper to handle the function call and store the result #}
async def _async_wrapper():
result = await {{ invoke_function_call }}
return {
"results": result,
"agent_state": agent_state
}
{# Run the async function - method depends on environment #}
{% if use_top_level_await %}
{# Environment with running event loop (like E2B) - use top-level await #}
{{ local_sandbox_result_var_name }} = await _async_wrapper()
{% else %}
{# Local execution environment - use asyncio.run #}
{{ local_sandbox_result_var_name }} = asyncio.run(_async_wrapper())
{% endif %}
{{ local_sandbox_result_var_name }}_pkl = pickle.dumps({{ local_sandbox_result_var_name }})
{% if wrap_print_with_markers %}
{# Combine everything to flush and write at once. #}
data_checksum = hashlib.md5({{ local_sandbox_result_var_name }}_pkl).hexdigest().encode('ascii')
{{ local_sandbox_result_var_name }}_msg = (
{{ start_marker }} +
struct.pack('>I', len({{ local_sandbox_result_var_name }}_pkl)) +
data_checksum +
{{ local_sandbox_result_var_name }}_pkl
)
sys.stdout.buffer.write({{ local_sandbox_result_var_name }}_msg)
sys.stdout.buffer.flush()
{% else %}
base64.b64encode({{ local_sandbox_result_var_name }}_pkl).decode('utf-8')
{% endif %}

View File

@@ -336,6 +336,141 @@ def core_memory_tools(test_user):
yield tools
@pytest.fixture
def async_add_integers_tool(test_user):
async def async_add(x: int, y: int) -> int:
"""
Async function that adds two integers.
Parameters:
x (int): The first integer to add.
y (int): The second integer to add.
Returns:
int: The result of adding x and y.
"""
import asyncio
# Add a small delay to simulate async work
await asyncio.sleep(0.1)
return x + y
tool = create_tool_from_func(async_add)
tool = ToolManager().create_or_update_tool(tool, test_user)
yield tool
@pytest.fixture
def async_get_env_tool(test_user):
async def async_get_env() -> str:
"""
Async function that returns the secret word env variable.
Returns:
str: The secret word
"""
import asyncio
import os
# Add a small delay to simulate async work
await asyncio.sleep(0.1)
secret_word = os.getenv("secret_word")
print(secret_word)
return secret_word
tool = create_tool_from_func(async_get_env)
tool = ToolManager().create_or_update_tool(tool, test_user)
yield tool
@pytest.fixture
def async_stateful_tool(test_user):
async def async_clear_memory(agent_state: "AgentState"):
"""Async function that clears the core memory"""
import asyncio
# Add a small delay to simulate async work
await asyncio.sleep(0.1)
agent_state.memory.get_block("human").value = ""
agent_state.memory.get_block("persona").value = ""
tool = create_tool_from_func(async_clear_memory)
tool = ToolManager().create_or_update_tool(tool, test_user)
yield tool
@pytest.fixture
def async_error_tool(test_user):
async def async_error() -> str:
"""
Async function that errors
Returns:
str: not important
"""
import asyncio
# Add some async work before erroring
await asyncio.sleep(0.1)
print("Going to error now")
raise ValueError("This is an intentional async error!")
tool = create_tool_from_func(async_error)
tool = ToolManager().create_or_update_tool(tool, test_user)
yield tool
@pytest.fixture
def async_list_tool(test_user):
async def async_create_list() -> list:
"""Async function that returns a list"""
import asyncio
await asyncio.sleep(0.05)
return [1, 2, 3, 4, 5]
tool = create_tool_from_func(async_create_list)
tool = ToolManager().create_or_update_tool(tool, test_user)
yield tool
@pytest.fixture
def async_complex_tool(test_user):
async def async_complex_computation(iterations: int = 3) -> dict:
"""
Async function that performs complex computation with multiple awaits.
Parameters:
iterations (int): Number of iterations to perform.
Returns:
dict: Results of the computation.
"""
import asyncio
import time
results = []
start_time = time.time()
for i in range(iterations):
# Simulate async I/O
await asyncio.sleep(0.1)
results.append(i * 2)
end_time = time.time()
return {
"results": results,
"duration": end_time - start_time,
"iterations": iterations,
"average": sum(results) / len(results) if results else 0,
}
tool = create_tool_from_func(async_complex_computation)
tool = ToolManager().create_or_update_tool(tool, test_user)
yield tool
@pytest.fixture(scope="session")
def event_loop(request):
"""Create an instance of the default event loop for each test case."""
@@ -719,3 +854,257 @@ async def test_e2b_sandbox_with_broken_tool_pip_requirements_error_handling(
# Should mention one of the problematic packages
assert "numpy==1.24.0" in error_message or "nonexistent-package-12345" in error_message
# Async function tests
def test_async_function_detection(add_integers_tool, async_add_integers_tool, test_user):
"""Test that async function detection works correctly"""
# Test sync function detection
sync_sandbox = AsyncToolSandboxE2B(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool)
assert not sync_sandbox.is_async_function
# Test async function detection
async_sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool)
assert async_sandbox.is_async_function
def test_async_template_selection(add_integers_tool, async_add_integers_tool, test_user):
"""Test that correct templates are selected for sync vs async functions"""
# Test sync function uses regular template
sync_sandbox = AsyncToolSandboxE2B(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool)
sync_script = sync_sandbox.generate_execution_script(agent_state=None)
print("=== SYNC SCRIPT ===")
print(sync_script)
print("=== END SYNC SCRIPT ===")
assert "import asyncio" not in sync_script
assert "asyncio.run" not in sync_script
# Test async function uses async template
async_sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool)
async_script = async_sandbox.generate_execution_script(agent_state=None)
print("=== ASYNC SCRIPT ===")
print(async_script)
print("=== END ASYNC SCRIPT ===")
assert "import asyncio" in async_script
assert "await _async_wrapper()" in async_script # E2B uses top-level await
assert "_async_wrapper" in async_script
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_async_function_execution(disable_e2b_api_key, async_add_integers_tool, test_user, event_loop):
"""Test that async functions execute correctly in local sandbox"""
args = {"x": 15, "y": 25}
sandbox = AsyncToolSandboxLocal(async_add_integers_tool.name, args, user=test_user)
result = await sandbox.run()
assert result.func_return == args["x"] + args["y"]
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_async_function_execution(check_e2b_key_is_set, async_add_integers_tool, test_user, event_loop):
"""Test that async functions execute correctly in E2B sandbox"""
args = {"x": 20, "y": 30}
sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, args, user=test_user)
result = await sandbox.run()
assert int(result.func_return) == args["x"] + args["y"]
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_async_complex_computation(disable_e2b_api_key, async_complex_tool, test_user, event_loop):
"""Test complex async computation with multiple awaits in local sandbox"""
args = {"iterations": 2}
sandbox = AsyncToolSandboxLocal(async_complex_tool.name, args, user=test_user)
result = await sandbox.run()
assert isinstance(result.func_return, dict)
assert result.func_return["results"] == [0, 2]
assert result.func_return["iterations"] == 2
assert result.func_return["average"] == 1.0
assert result.func_return["duration"] > 0.15 # Should take at least 0.2s due to sleep
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_async_complex_computation(check_e2b_key_is_set, async_complex_tool, test_user, event_loop):
"""Test complex async computation with multiple awaits in E2B sandbox"""
args = {"iterations": 2}
sandbox = AsyncToolSandboxE2B(async_complex_tool.name, args, user=test_user)
result = await sandbox.run()
func_return = result.func_return
assert isinstance(func_return, dict)
assert func_return["results"] == [0, 2]
assert func_return["iterations"] == 2
assert func_return["average"] == 1.0
assert func_return["duration"] > 0.15
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_async_list_return(disable_e2b_api_key, async_list_tool, test_user, event_loop):
"""Test async function returning list in local sandbox"""
sandbox = AsyncToolSandboxLocal(async_list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert result.func_return == [1, 2, 3, 4, 5]
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_async_list_return(check_e2b_key_is_set, async_list_tool, test_user, event_loop):
"""Test async function returning list in E2B sandbox"""
sandbox = AsyncToolSandboxE2B(async_list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert result.func_return == [1, 2, 3, 4, 5]
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_async_with_env_vars(disable_e2b_api_key, async_get_env_tool, test_user, event_loop):
"""Test async function with environment variables in local sandbox"""
manager = SandboxConfigManager()
# Create custom local sandbox config
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
# Create environment variable
key = "secret_word"
test_value = "async_local_test_value_789"
manager.create_sandbox_env_var(
SandboxEnvironmentVariableCreate(key=key, value=test_value), sandbox_config_id=config.id, actor=test_user
)
sandbox = AsyncToolSandboxLocal(async_get_env_tool.name, {}, user=test_user)
result = await sandbox.run()
assert test_value in result.func_return
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_async_with_env_vars(check_e2b_key_is_set, async_get_env_tool, test_user, event_loop):
"""Test async function with environment variables in E2B sandbox"""
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
# Create environment variable
key = "secret_word"
test_value = "async_e2b_test_value_456"
manager.create_sandbox_env_var(
SandboxEnvironmentVariableCreate(key=key, value=test_value), sandbox_config_id=config.id, actor=test_user
)
sandbox = AsyncToolSandboxE2B(async_get_env_tool.name, {}, user=test_user)
result = await sandbox.run()
assert test_value in result.func_return
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_async_with_agent_state(disable_e2b_api_key, async_stateful_tool, test_user, agent_state, event_loop):
"""Test async function with agent state in local sandbox"""
sandbox = AsyncToolSandboxLocal(async_stateful_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert result.agent_state is not None
assert result.agent_state.memory.get_block("human").value == ""
assert result.agent_state.memory.get_block("persona").value == ""
assert result.func_return is None
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_async_with_agent_state(check_e2b_key_is_set, async_stateful_tool, test_user, agent_state, event_loop):
"""Test async function with agent state in E2B sandbox"""
sandbox = AsyncToolSandboxE2B(async_stateful_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert result.agent_state.memory.get_block("human").value == ""
assert result.agent_state.memory.get_block("persona").value == ""
assert result.func_return is None
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_async_error_handling(disable_e2b_api_key, async_error_tool, test_user, event_loop):
"""Test async function error handling in local sandbox"""
sandbox = AsyncToolSandboxLocal(async_error_tool.name, {}, user=test_user)
result = await sandbox.run()
# Check that error was captured
assert len(result.stdout) != 0, "stdout not empty"
assert "error" in result.stdout[0], "stdout contains printed string"
assert len(result.stderr) != 0, "stderr not empty"
assert "ValueError: This is an intentional async error!" in result.stderr[0], "stderr contains expected error"
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_async_error_handling(check_e2b_key_is_set, async_error_tool, test_user, event_loop):
"""Test async function error handling in E2B sandbox"""
sandbox = AsyncToolSandboxE2B(async_error_tool.name, {}, user=test_user)
result = await sandbox.run()
# Check that error was captured
assert len(result.stdout) != 0, "stdout not empty"
assert "error" in result.stdout[0], "stdout contains printed string"
assert len(result.stderr) != 0, "stderr not empty"
assert "ValueError: This is an intentional async error!" in result.stderr[0], "stderr contains expected error"
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_async_per_agent_env(disable_e2b_api_key, async_get_env_tool, agent_state, test_user, event_loop):
"""Test async function with per-agent environment variables in local sandbox"""
manager = SandboxConfigManager()
key = "secret_word"
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
wrong_val = "wrong_async_local_value"
manager.create_sandbox_env_var(SandboxEnvironmentVariableCreate(key=key, value=wrong_val), sandbox_config_id=config.id, actor=test_user)
correct_val = "correct_async_local_value"
agent_state.tool_exec_environment_variables = [AgentEnvironmentVariable(key=key, value=correct_val, agent_id=agent_state.id)]
sandbox = AsyncToolSandboxLocal(async_get_env_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert wrong_val not in result.func_return
assert correct_val in result.func_return
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_env_tool, agent_state, test_user, event_loop):
"""Test async function with per-agent environment variables in E2B sandbox"""
manager = SandboxConfigManager()
key = "secret_word"
wrong_val = "wrong_async_e2b_value"
correct_val = "correct_async_e2b_value"
config_create = SandboxConfigCreate(config=LocalSandboxConfig().model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
manager.create_sandbox_env_var(
SandboxEnvironmentVariableCreate(key=key, value=wrong_val),
sandbox_config_id=config.id,
actor=test_user,
)
agent_state.tool_exec_environment_variables = [AgentEnvironmentVariable(key=key, value=correct_val, agent_id=agent_state.id)]
sandbox = AsyncToolSandboxE2B(async_get_env_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert wrong_val not in result.func_return
assert correct_val in result.func_return