diff --git a/letta/services/helpers/tool_parser_helper.py b/letta/services/helpers/tool_parser_helper.py
index b0142848..f38de929 100644
--- a/letta/services/helpers/tool_parser_helper.py
+++ b/letta/services/helpers/tool_parser_helper.py
@@ -28,7 +28,8 @@ def parse_function_arguments(source_code: str, tool_name: str):
     tree = ast.parse(source_code)
     args = []
     for node in ast.walk(tree):
-        if isinstance(node, ast.FunctionDef) and node.name == tool_name:
+        # Handle both sync and async functions
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == tool_name:
             for arg in node.args.args:
                 args.append(arg.arg)
     return args
diff --git a/letta/services/tool_sandbox/base.py b/letta/services/tool_sandbox/base.py
index e21d4743..03083981 100644
--- a/letta/services/tool_sandbox/base.py
+++ b/letta/services/tool_sandbox/base.py
@@ -52,6 +52,9 @@ class AsyncToolSandboxBase(ABC):
         else:
             self.inject_agent_state = False
 
+        # Detect if the tool function is async
+        self.is_async_function = self._detect_async_function()
+
     # Lazily initialize the manager only when needed
     @property
     def sandbox_config_manager(self):
@@ -78,7 +81,8 @@ class AsyncToolSandboxBase(ABC):
         """
         from letta.templates.template_helper import render_template
 
-        TEMPLATE_NAME = "sandbox_code_file.py.j2"
+        # Select the appropriate template based on whether the function is async
+        TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2"
 
         future_import = False
         schema_code = None
@@ -114,6 +118,7 @@ class AsyncToolSandboxBase(ABC):
             invoke_function_call=self.invoke_function_call(),
             wrap_print_with_markers=wrap_print_with_markers,
             start_marker=self.LOCAL_SANDBOX_RESULT_START_MARKER,
+            use_top_level_await=self.use_top_level_await(),
         )
 
     def initialize_param(self, name: str, raw_value: JsonValue) -> str:
@@ -150,5 +155,47 @@ class AsyncToolSandboxBase(ABC):
         func_call_str = self.tool.name + "(" + params + ")"
         return func_call_str
 
+    def _detect_async_function(self) -> bool:
+        """
+        Detect whether the tool function is declared with ``async def``.
+
+        The source is parsed with ``ast``; the first definition named ``self.tool.name`` that
+        ``ast.walk`` yields decides the result. If the source cannot be parsed, or holds no
+        definition with that name, fall back to a whole-word regex search so a tool named
+        ``foo`` is not mistaken for an unrelated ``async def foo_bar``.
+
+        Returns:
+            bool: True if the tool function is async, False otherwise.
+        """
+        import ast
+        import re
+
+        tool_name = self.tool.name
+        source_code = self.tool.source_code or ""
+
+        try:
+            tree = ast.parse(source_code)
+        except (SyntaxError, ValueError):
+            # Unparseable source; ValueError is raised for null bytes on some Python versions
+            tree = None
+
+        if tree is not None:
+            for node in ast.walk(tree):
+                # Accept the same node types as parse_function_arguments
+                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == tool_name:
+                    return isinstance(node, ast.AsyncFunctionDef)
+
+        # Textual fallback, anchored on word boundaries to avoid prefix false positives
+        pattern = r"\basync\s+def\s+" + re.escape(tool_name) + r"\b"
+        return re.search(pattern, source_code) is not None
+
+    def use_top_level_await(self) -> bool:
+        """
+        Determine if this sandbox environment supports top-level await.
+        Should be overridden by subclasses to return True for environments
+        with running event loops (like E2B), False for local execution.
+        """
+        return False  # Default to False for local execution
+
     def _update_env_vars(self):
         pass  # TODO
diff --git a/letta/services/tool_sandbox/e2b_sandbox.py b/letta/services/tool_sandbox/e2b_sandbox.py
index 1e232168..9b069b01 100644
--- a/letta/services/tool_sandbox/e2b_sandbox.py
+++ b/letta/services/tool_sandbox/e2b_sandbox.py
@@ -250,6 +250,13 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
 
         return sbx
 
+    def use_top_level_await(self) -> bool:
+        """
+        E2B sandboxes run in a Jupyter-like environment with an active event loop,
+        so they support top-level await.
+        """
+        return True
+
     @staticmethod
     async def list_running_e2b_sandboxes():
         # List running sandboxes and access metadata.
diff --git a/letta/templates/sandbox_code_file_async.py.j2 b/letta/templates/sandbox_code_file_async.py.j2
new file mode 100644
index 00000000..6ed9cdbe
--- /dev/null
+++ b/letta/templates/sandbox_code_file_async.py.j2
@@ -0,0 +1,59 @@
+{{ 'from __future__ import annotations' if future_import else '' }}
+from typing import *
+import pickle
+import sys
+import base64
+import struct
+import hashlib
+import asyncio
+
+{# Additional imports to support agent state #}
+{% if inject_agent_state %}
+import letta
+from letta import *
+{% endif %}
+
+{# Add schema code if available #}
+{{ schema_imports or '' }}
+
+{# Load agent state #}
+agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickle else 'None' }}
+
+{{ tool_args }}
+
+{# The tool's source code #}
+{{ tool_source_code }}
+
+{# Async wrapper to handle the function call and store the result #}
+async def _async_wrapper():
+    result = await {{ invoke_function_call }}
+    return {
+        "results": result,
+        "agent_state": agent_state
+    }
+
+{# Run the async function - method depends on environment #}
+{% if use_top_level_await %}
+{# Environment with running event loop (like E2B) - use top-level await #}
+{{ local_sandbox_result_var_name }} = await _async_wrapper()
+{% else %}
+{# Local execution environment - use asyncio.run #}
+{{ local_sandbox_result_var_name }} = asyncio.run(_async_wrapper())
+{% endif %}
+
+{{ local_sandbox_result_var_name }}_pkl = pickle.dumps({{ local_sandbox_result_var_name }})
+
+{% if wrap_print_with_markers %}
+{# Combine everything to flush and write at once. #}
+data_checksum = hashlib.md5({{ local_sandbox_result_var_name }}_pkl).hexdigest().encode('ascii')
+{{ local_sandbox_result_var_name }}_msg = (
+    {{ start_marker }} +
+    struct.pack('>I', len({{ local_sandbox_result_var_name }}_pkl)) +
+    data_checksum +
+    {{ local_sandbox_result_var_name }}_pkl
+)
+sys.stdout.buffer.write({{ local_sandbox_result_var_name }}_msg)
+sys.stdout.buffer.flush()
+{% else %}
+base64.b64encode({{ local_sandbox_result_var_name }}_pkl).decode('utf-8')
+{% endif %}
diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py
index 4390fdac..d3528597 100644
--- a/tests/integration_test_async_tool_sandbox.py
+++ b/tests/integration_test_async_tool_sandbox.py
@@ -336,6 +336,141 @@ def core_memory_tools(test_user):
     yield tools
 
 
+@pytest.fixture
+def async_add_integers_tool(test_user):
+    async def async_add(x: int, y: int) -> int:
+        """
+        Async function that adds two integers.
+
+        Parameters:
+            x (int): The first integer to add.
+            y (int): The second integer to add.
+
+        Returns:
+            int: The result of adding x and y.
+        """
+        import asyncio
+
+        # Add a small delay to simulate async work
+        await asyncio.sleep(0.1)
+        return x + y
+
+    tool = create_tool_from_func(async_add)
+    tool = ToolManager().create_or_update_tool(tool, test_user)
+    yield tool
+
+
+@pytest.fixture
+def async_get_env_tool(test_user):
+    async def async_get_env() -> str:
+        """
+        Async function that returns the secret word env variable.
+
+        Returns:
+            str: The secret word
+        """
+        import asyncio
+        import os
+
+        # Add a small delay to simulate async work
+        await asyncio.sleep(0.1)
+        secret_word = os.getenv("secret_word")
+        print(secret_word)
+        return secret_word
+
+    tool = create_tool_from_func(async_get_env)
+    tool = ToolManager().create_or_update_tool(tool, test_user)
+    yield tool
+
+
+@pytest.fixture
+def async_stateful_tool(test_user):
+    async def async_clear_memory(agent_state: "AgentState"):
+        """Async function that clears the core memory"""
+        import asyncio
+
+        # Add a small delay to simulate async work
+        await asyncio.sleep(0.1)
+        agent_state.memory.get_block("human").value = ""
+        agent_state.memory.get_block("persona").value = ""
+
+    tool = create_tool_from_func(async_clear_memory)
+    tool = ToolManager().create_or_update_tool(tool, test_user)
+    yield tool
+
+
+@pytest.fixture
+def async_error_tool(test_user):
+    async def async_error() -> str:
+        """
+        Async function that errors
+
+        Returns:
+            str: not important
+        """
+        import asyncio
+
+        # Add some async work before erroring
+        await asyncio.sleep(0.1)
+        print("Going to error now")
+        raise ValueError("This is an intentional async error!")
+
+    tool = create_tool_from_func(async_error)
+    tool = ToolManager().create_or_update_tool(tool, test_user)
+    yield tool
+
+
+@pytest.fixture
+def async_list_tool(test_user):
+    async def async_create_list() -> list:
+        """Async function that returns a list"""
+        import asyncio
+
+        await asyncio.sleep(0.05)
+        return [1, 2, 3, 4, 5]
+
+    tool = create_tool_from_func(async_create_list)
+    tool = ToolManager().create_or_update_tool(tool, test_user)
+    yield tool
+
+
+@pytest.fixture
+def async_complex_tool(test_user):
+    async def async_complex_computation(iterations: int = 3) -> dict:
+        """
+        Async function that performs complex computation with multiple awaits.
+
+        Parameters:
+            iterations (int): Number of iterations to perform.
+
+        Returns:
+            dict: Results of the computation.
+        """
+        import asyncio
+        import time
+
+        results = []
+        start_time = time.time()
+
+        for i in range(iterations):
+            # Simulate async I/O
+            await asyncio.sleep(0.1)
+            results.append(i * 2)
+
+        end_time = time.time()
+
+        return {
+            "results": results,
+            "duration": end_time - start_time,
+            "iterations": iterations,
+            "average": sum(results) / len(results) if results else 0,
+        }
+
+    tool = create_tool_from_func(async_complex_computation)
+    tool = ToolManager().create_or_update_tool(tool, test_user)
+    yield tool
+
+
 @pytest.fixture(scope="session")
 def event_loop(request):
     """Create an instance of the default event loop for each test case."""
@@ -719,3 +854,257 @@ async def test_e2b_sandbox_with_broken_tool_pip_requirements_error_handling(
 
     # Should mention one of the problematic packages
     assert "numpy==1.24.0" in error_message or "nonexistent-package-12345" in error_message
+
+
+# Async function tests
+
+
+def test_async_function_detection(add_integers_tool, async_add_integers_tool, test_user):
+    """Test that async function detection works correctly"""
+    # Test sync function detection
+    sync_sandbox = AsyncToolSandboxE2B(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool)
+    assert not sync_sandbox.is_async_function
+
+    # Test async function detection
+    async_sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool)
+    assert async_sandbox.is_async_function
+
+
+def test_async_template_selection(add_integers_tool, async_add_integers_tool, test_user):
+    """Test that correct templates are selected for sync vs async functions"""
+    # Test sync function uses regular template
+    sync_sandbox = AsyncToolSandboxE2B(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool)
+    sync_script = sync_sandbox.generate_execution_script(agent_state=None)
+    print("=== SYNC SCRIPT ===")
+    print(sync_script)
+    print("=== END SYNC SCRIPT ===")
+    assert "import asyncio" not in sync_script
+    assert "asyncio.run" not in sync_script
+
+    # Test async function uses async template
+    async_sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool)
+    async_script = async_sandbox.generate_execution_script(agent_state=None)
+    print("=== ASYNC SCRIPT ===")
+    print(async_script)
+    print("=== END ASYNC SCRIPT ===")
+    assert "import asyncio" in async_script
+    assert "await _async_wrapper()" in async_script  # E2B uses top-level await
+    assert "_async_wrapper" in async_script
+
+
+@pytest.mark.asyncio
+@pytest.mark.local_sandbox
+async def test_local_sandbox_async_function_execution(disable_e2b_api_key, async_add_integers_tool, test_user, event_loop):
+    """Test that async functions execute correctly in local sandbox"""
+    args = {"x": 15, "y": 25}
+
+    sandbox = AsyncToolSandboxLocal(async_add_integers_tool.name, args, user=test_user)
+    result = await sandbox.run()
+    assert result.func_return == args["x"] + args["y"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.e2b_sandbox
+async def test_e2b_sandbox_async_function_execution(check_e2b_key_is_set, async_add_integers_tool, test_user, event_loop):
+    """Test that async functions execute correctly in E2B sandbox"""
+    args = {"x": 20, "y": 30}
+
+    sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, args, user=test_user)
+    result = await sandbox.run()
+    assert int(result.func_return) == args["x"] + args["y"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.local_sandbox
+async def test_local_sandbox_async_complex_computation(disable_e2b_api_key, async_complex_tool, test_user, event_loop):
+    """Test complex async computation with multiple awaits in local sandbox"""
+    args = {"iterations": 2}
+
+    sandbox = AsyncToolSandboxLocal(async_complex_tool.name, args, user=test_user)
+    result = await sandbox.run()
+
+    assert isinstance(result.func_return, dict)
+    assert result.func_return["results"] == [0, 2]
+    assert result.func_return["iterations"] == 2
+    assert result.func_return["average"] == 1.0
+    assert result.func_return["duration"] > 0.15  # 2 iterations x 0.1s sleep is ~0.2s; 0.15 leaves slack for timer jitter
+
+
+@pytest.mark.asyncio
+@pytest.mark.e2b_sandbox
+async def test_e2b_sandbox_async_complex_computation(check_e2b_key_is_set, async_complex_tool, test_user, event_loop):
+    """Test complex async computation with multiple awaits in E2B sandbox"""
+    args = {"iterations": 2}
+
+    sandbox = AsyncToolSandboxE2B(async_complex_tool.name, args, user=test_user)
+    result = await sandbox.run()
+
+    func_return = result.func_return
+    assert isinstance(func_return, dict)
+    assert func_return["results"] == [0, 2]
+    assert func_return["iterations"] == 2
+    assert func_return["average"] == 1.0
+    assert func_return["duration"] > 0.15
+
+
+@pytest.mark.asyncio
+@pytest.mark.local_sandbox
+async def test_local_sandbox_async_list_return(disable_e2b_api_key, async_list_tool, test_user, event_loop):
+    """Test async function returning list in local sandbox"""
+    sandbox = AsyncToolSandboxLocal(async_list_tool.name, {}, user=test_user)
+    result = await sandbox.run()
+    assert result.func_return == [1, 2, 3, 4, 5]
+
+
+@pytest.mark.asyncio
+@pytest.mark.e2b_sandbox
+async def test_e2b_sandbox_async_list_return(check_e2b_key_is_set, async_list_tool, test_user, event_loop):
+    """Test async function returning list in E2B sandbox"""
+    sandbox = AsyncToolSandboxE2B(async_list_tool.name, {}, user=test_user)
+    result = await sandbox.run()
+    assert result.func_return == [1, 2, 3, 4, 5]
+
+
+@pytest.mark.asyncio
+@pytest.mark.local_sandbox
+async def test_local_sandbox_async_with_env_vars(disable_e2b_api_key, async_get_env_tool, test_user, event_loop):
+    """Test async function with environment variables in local sandbox"""
+    manager = SandboxConfigManager()
+
+    # Create custom local sandbox config
+    sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
+    config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump())
+    config = manager.create_or_update_sandbox_config(config_create, test_user)
+
+    # Create environment variable
+    key = "secret_word"
+    test_value = "async_local_test_value_789"
+    manager.create_sandbox_env_var(
+        SandboxEnvironmentVariableCreate(key=key, value=test_value), sandbox_config_id=config.id, actor=test_user
+    )
+
+    sandbox = AsyncToolSandboxLocal(async_get_env_tool.name, {}, user=test_user)
+    result = await sandbox.run()
+
+    assert test_value in result.func_return
+
+
+@pytest.mark.asyncio
+@pytest.mark.e2b_sandbox
+async def test_e2b_sandbox_async_with_env_vars(check_e2b_key_is_set, async_get_env_tool, test_user, event_loop):
+    """Test async function with environment variables in E2B sandbox"""
+    manager = SandboxConfigManager()
+    config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
+    config = manager.create_or_update_sandbox_config(config_create, test_user)
+
+    # Create environment variable
+    key = "secret_word"
+    test_value = "async_e2b_test_value_456"
+    manager.create_sandbox_env_var(
+        SandboxEnvironmentVariableCreate(key=key, value=test_value), sandbox_config_id=config.id, actor=test_user
+    )
+
+    sandbox = AsyncToolSandboxE2B(async_get_env_tool.name, {}, user=test_user)
+    result = await sandbox.run()
+
+    assert test_value in result.func_return
+
+
+@pytest.mark.asyncio
+@pytest.mark.local_sandbox
+async def test_local_sandbox_async_with_agent_state(disable_e2b_api_key, async_stateful_tool, test_user, agent_state, event_loop):
+    """Test async function with agent state in local sandbox"""
+    sandbox = AsyncToolSandboxLocal(async_stateful_tool.name, {}, user=test_user)
+    result = await sandbox.run(agent_state=agent_state)
+
+    assert result.agent_state is not None
+    assert result.agent_state.memory.get_block("human").value == ""
+    assert result.agent_state.memory.get_block("persona").value == ""
+    assert result.func_return is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.e2b_sandbox
+async def test_e2b_sandbox_async_with_agent_state(check_e2b_key_is_set, async_stateful_tool, test_user, agent_state, event_loop):
+    """Test async function with agent state in E2B sandbox"""
+    sandbox = AsyncToolSandboxE2B(async_stateful_tool.name, {}, user=test_user)
+    result = await sandbox.run(agent_state=agent_state)
+
+    assert result.agent_state.memory.get_block("human").value == ""
+    assert result.agent_state.memory.get_block("persona").value == ""
+    assert result.func_return is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.local_sandbox
+async def test_local_sandbox_async_error_handling(disable_e2b_api_key, async_error_tool, test_user, event_loop):
+    """Test async function error handling in local sandbox"""
+    sandbox = AsyncToolSandboxLocal(async_error_tool.name, {}, user=test_user)
+    result = await sandbox.run()
+
+    # Check that error was captured
+    assert len(result.stdout) != 0, "stdout should not be empty"
+    assert "error" in result.stdout[0], "stdout should contain the printed string"
+    assert len(result.stderr) != 0, "stderr should not be empty"
+    assert "ValueError: This is an intentional async error!" in result.stderr[0], "stderr should contain the expected error"
+
+
+@pytest.mark.asyncio
+@pytest.mark.e2b_sandbox
+async def test_e2b_sandbox_async_error_handling(check_e2b_key_is_set, async_error_tool, test_user, event_loop):
+    """Test async function error handling in E2B sandbox"""
+    sandbox = AsyncToolSandboxE2B(async_error_tool.name, {}, user=test_user)
+    result = await sandbox.run()
+
+    # Check that error was captured
+    assert len(result.stdout) != 0, "stdout should not be empty"
+    assert "error" in result.stdout[0], "stdout should contain the printed string"
+    assert len(result.stderr) != 0, "stderr should not be empty"
+    assert "ValueError: This is an intentional async error!" in result.stderr[0], "stderr should contain the expected error"
+
+
+@pytest.mark.asyncio
+@pytest.mark.local_sandbox
+async def test_local_sandbox_async_per_agent_env(disable_e2b_api_key, async_get_env_tool, agent_state, test_user, event_loop):
+    """Test async function with per-agent environment variables in local sandbox"""
+    manager = SandboxConfigManager()
+    key = "secret_word"
+    sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
+    config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump())
+    config = manager.create_or_update_sandbox_config(config_create, test_user)
+
+    wrong_val = "wrong_async_local_value"
+    manager.create_sandbox_env_var(SandboxEnvironmentVariableCreate(key=key, value=wrong_val), sandbox_config_id=config.id, actor=test_user)
+
+    correct_val = "correct_async_local_value"
+    agent_state.tool_exec_environment_variables = [AgentEnvironmentVariable(key=key, value=correct_val, agent_id=agent_state.id)]
+
+    sandbox = AsyncToolSandboxLocal(async_get_env_tool.name, {}, user=test_user)
+    result = await sandbox.run(agent_state=agent_state)
+    assert wrong_val not in result.func_return
+    assert correct_val in result.func_return
+
+
+@pytest.mark.asyncio
+@pytest.mark.e2b_sandbox
+async def test_e2b_sandbox_async_per_agent_env(check_e2b_key_is_set, async_get_env_tool, agent_state, test_user, event_loop):
+    """Test async function with per-agent environment variables in E2B sandbox"""
+    manager = SandboxConfigManager()
+    key = "secret_word"
+    wrong_val = "wrong_async_e2b_value"
+    correct_val = "correct_async_e2b_value"
+
+    config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
+    config = manager.create_or_update_sandbox_config(config_create, test_user)
+    manager.create_sandbox_env_var(
+        SandboxEnvironmentVariableCreate(key=key, value=wrong_val),
+        sandbox_config_id=config.id,
+        actor=test_user,
+    )
+
+    agent_state.tool_exec_environment_variables = [AgentEnvironmentVariable(key=key, value=correct_val, agent_id=agent_state.id)]
+
+    sandbox = AsyncToolSandboxE2B(async_get_env_tool.name, {}, user=test_user)
+    result = await sandbox.run(agent_state=agent_state)
+    assert wrong_val not in result.func_return
+    assert correct_val in result.func_return