Files
letta-server/letta/services/tool_executor/async_tool_execution_sandbox.py
2025-03-27 15:24:50 -07:00

398 lines
17 KiB
Python

import ast
import asyncio
import base64
import os
import pickle
import sys
import tempfile
import uuid
from typing import Any, Dict, Optional, Tuple
from letta.functions.helpers import generate_model_from_args_json_schema
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult, SandboxType
from letta.services.helpers.tool_execution_helper import (
add_imports_and_pydantic_schemas_for_args,
create_venv_for_local_sandbox,
find_python_executable,
install_pip_requirements_for_sandbox,
)
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.tracing import log_event, trace_method
from letta.utils import get_friendly_error_msg
class AsyncToolExecutionSandbox:
    """Execute a Letta tool in a local sandbox via an asyncio subprocess.

    A Python script is auto-generated around the tool's source code. The script
    pickles and base64-encodes the tool's return value (plus the agent state)
    and writes it to stdout between two unique marker strings, so the parent
    process can separate the structured result from any ordinary output the
    tool prints. Execution always happens in a subprocess, optionally inside a
    dedicated virtualenv.
    """

    METADATA_CONFIG_STATE_KEY = "config_state"
    REQUIREMENT_TXT_NAME = "requirements.txt"

    # Maximum wall-clock seconds the tool subprocess may run before it is killed.
    DEFAULT_TIMEOUT_SECONDS = 60

    # For generating long, random marker hashes
    NAMESPACE = uuid.NAMESPACE_DNS
    LOCAL_SANDBOX_RESULT_START_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-start-marker"))
    LOCAL_SANDBOX_RESULT_END_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-end-marker"))

    # This is the variable name in the auto-generated code that contains the function results.
    # We make this a long random string to avoid collisions with any variables in the user's code.
    LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"

    def __init__(self, tool_name: str, args: dict, user, force_recreate=True, force_recreate_venv=False, tool_object=None):
        """
        Args:
            tool_name: Name of the tool to execute.
            args: Keyword arguments to pass to the tool function.
            user: Actor on whose behalf the tool runs; provides organization scoping.
            force_recreate: Kept for interface compatibility with other sandboxes.
            force_recreate_venv: If True, rebuild the virtualenv even if it exists.
            tool_object: Optional pre-fetched tool; when given, skips the name lookup.

        Raises:
            ValueError: If no tool object is given and the name cannot be resolved
                for the user's organization.
        """
        self.tool_name = tool_name
        self.args = args
        self.user = user

        # Resolve the organization so we know whether privileged tools are allowed.
        self.organization = OrganizationManager().get_organization_by_id(self.user.organization_id)
        self.privileged_tools = self.organization.privileged_tools

        # If a tool object is provided, we use it directly, otherwise pull via name.
        if tool_object is not None:
            self.tool = tool_object
        else:
            # Get the tool via name
            self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
            if not self.tool:
                raise ValueError(
                    f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
                )

        self.sandbox_config_manager = SandboxConfigManager()
        self.force_recreate = force_recreate
        self.force_recreate_venv = force_recreate_venv

    async def run(
        self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None, inject_agent_state: bool = False
    ) -> SandboxRunResult:
        """
        Run the tool in a sandbox environment asynchronously,
        *always* using a subprocess for execution.
        """
        result = await self.run_local_dir_sandbox(
            agent_state=agent_state, additional_env_vars=additional_env_vars, inject_agent_state=inject_agent_state
        )

        # Simple console logging for demonstration
        for log_line in (result.stdout or []) + (result.stderr or []):
            print(f"Tool execution log: {log_line}")

        return result

    @trace_method
    async def run_local_dir_sandbox(
        self, agent_state: Optional[AgentState], additional_env_vars: Optional[Dict], inject_agent_state: bool
    ) -> SandboxRunResult:
        """
        Unified asynchronous method to run the tool in a local sandbox environment,
        always via subprocess for multi-core parallelism.
        """
        # Get sandbox configuration
        sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
        local_configs = sbx_config.get_local_config()
        use_venv = local_configs.use_venv

        # Prepare environment variables: OS env < sandbox env vars < agent env vars < caller overrides
        env = os.environ.copy()
        env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
        env.update(env_vars)

        if agent_state:
            env.update(agent_state.get_agent_env_vars_as_dict())

        if additional_env_vars:
            env.update(additional_env_vars)

        # Make sure sandbox directory exists (exist_ok avoids a TOCTOU race with other workers)
        sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
        os.makedirs(sandbox_dir, exist_ok=True)

        # If using a virtual environment, start preparing it concurrently with script generation
        venv_preparation_task = None
        if use_venv:
            venv_path = str(os.path.join(sandbox_dir, local_configs.venv_name))
            if self.force_recreate_venv or not os.path.isdir(venv_path):
                venv_preparation_task = asyncio.create_task(self._prepare_venv(local_configs, venv_path, env))

        # Generate and write execution script (always with markers, since we rely on stdout)
        with tempfile.NamedTemporaryFile(mode="w", dir=sandbox_dir, suffix=".py", delete=False) as temp_file:
            code = self.generate_execution_script(agent_state=agent_state, inject_agent_state=inject_agent_state)
            temp_file.write(code)
            temp_file.flush()
            temp_file_path = temp_file.name

        try:
            # If we started a venv preparation task, wait for it to complete
            if venv_preparation_task:
                await venv_preparation_task

            # Determine the python executable and environment for the subprocess
            exec_env = env.copy()
            if use_venv:
                venv_path = str(os.path.join(sandbox_dir, local_configs.venv_name))
                python_executable = find_python_executable(local_configs)
                exec_env["VIRTUAL_ENV"] = venv_path
                exec_env["PATH"] = os.path.join(venv_path, "bin") + ":" + exec_env["PATH"]
            else:
                # If not using venv, use whatever Python we are running on
                python_executable = sys.executable

            exec_env["PYTHONWARNINGS"] = "ignore"

            # Execute in subprocess
            return await self._execute_tool_subprocess(
                sbx_config=sbx_config,
                python_executable=python_executable,
                temp_file_path=temp_file_path,
                env=exec_env,
                cwd=sandbox_dir,
            )

        except Exception as e:
            print(f"Executing tool {self.tool_name} has an unexpected error: {e}")
            print(f"Auto-generated code for debugging:\n\n{code}")
            raise
        finally:
            # Clean up the temp file
            os.remove(temp_file_path)

    async def _prepare_venv(self, local_configs, venv_path: str, env: Dict[str, str]):
        """
        Prepare the virtual environment asynchronously (in a background thread),
        then install the sandbox's pip requirements into it.
        """
        sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
        log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path})
        await asyncio.to_thread(
            create_venv_for_local_sandbox,
            sandbox_dir_path=sandbox_dir,
            venv_path=venv_path,
            env=env,
            force_recreate=self.force_recreate_venv,
        )
        log_event(name="finish create_venv_for_local_sandbox")

        log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
        await asyncio.to_thread(install_pip_requirements_for_sandbox, local_configs, upgrade=True, user_install_if_no_venv=False, env=env)
        log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})

    @trace_method
    async def _execute_tool_subprocess(
        self, sbx_config, python_executable: str, temp_file_path: str, env: Dict[str, str], cwd: str
    ) -> SandboxRunResult:
        """
        Execute user code in a subprocess, always capturing stdout and stderr.
        We parse special markers to extract the pickled result string.

        Raises:
            TimeoutError: If the subprocess exceeds DEFAULT_TIMEOUT_SECONDS.
        """
        try:
            log_event(name="start subprocess")
            process = await asyncio.create_subprocess_exec(
                python_executable, temp_file_path, env=env, cwd=cwd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            )

            try:
                stdout_bytes, stderr_bytes = await asyncio.wait_for(process.communicate(), timeout=self.DEFAULT_TIMEOUT_SECONDS)
            except asyncio.TimeoutError:
                # Terminate the process on timeout; escalate to kill if it doesn't exit promptly
                if process.returncode is None:
                    process.terminate()
                    try:
                        await asyncio.wait_for(process.wait(), timeout=5)
                    except asyncio.TimeoutError:
                        process.kill()
                raise TimeoutError(f"Executing tool {self.tool_name} timed out after {self.DEFAULT_TIMEOUT_SECONDS} seconds.")

            stdout = stdout_bytes.decode("utf-8") if stdout_bytes else ""
            stderr = stderr_bytes.decode("utf-8") if stderr_bytes else ""
            log_event(name="finish subprocess")

            # Parse markers to isolate the function result
            func_result, stdout_text = self.parse_out_function_results_markers(stdout)
            func_return, agent_state = self.parse_best_effort(func_result)

            return SandboxRunResult(
                func_return=func_return,
                agent_state=agent_state,
                stdout=[stdout_text] if stdout_text else [],
                stderr=[stderr] if stderr else [],
                status="success" if process.returncode == 0 else "error",
                sandbox_config_fingerprint=sbx_config.fingerprint(),
            )

        except TimeoutError:
            # Timeouts propagate unchanged so callers can distinguish them
            raise
        except Exception as e:
            print(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}")
            func_return = get_friendly_error_msg(
                function_name=self.tool_name,
                exception_name=type(e).__name__,
                exception_message=str(e),
            )
            return SandboxRunResult(
                func_return=func_return,
                agent_state=None,
                stdout=[],
                stderr=[str(e)],
                status="error",
                sandbox_config_fingerprint=sbx_config.fingerprint(),
            )

    def parse_out_function_results_markers(self, text: str) -> Tuple[str, str]:
        """
        Parse the function results out of the stdout using special markers.
        Returns (function_result_str, stripped_stdout).

        If either marker is missing (e.g. the subprocess died before writing the
        result), the whole text is treated as plain stdout and "" is returned
        for the result.
        """
        start_marker = self.LOCAL_SANDBOX_RESULT_START_MARKER
        end_marker = self.LOCAL_SANDBOX_RESULT_END_MARKER

        start_idx = text.find(start_marker)
        if start_idx == -1:
            # No markers found, so nothing to parse
            return "", text

        payload_start = start_idx + len(start_marker)
        # Search for the end marker only AFTER the start marker
        end_idx = text.find(end_marker, payload_start)
        if end_idx == -1:
            # Start marker without a matching end marker: treat as plain output
            return "", text

        # The actual pickled base64 is between payload_start and end_idx
        results_str = text[payload_start:end_idx]
        # The rest of stdout (minus the markers), using each marker's own length
        remainder = text[:start_idx] + text[end_idx + len(end_marker) :]
        return results_str, remainder

    def parse_best_effort(self, text: str) -> Tuple[Any, Optional[AgentState]]:
        """
        Decode and unpickle the result from the function execution if possible.
        Returns (function_return_value, agent_state).

        NOTE: pickle.loads is only safe here because the payload is produced by
        our own auto-generated script, not by untrusted external input.
        """
        if not text:
            return None, None
        result = pickle.loads(base64.b64decode(text))
        return result["results"], result["agent_state"]

    def parse_function_arguments(self, source_code: str, tool_name: str) -> list:
        """
        Get argument names of the given function from its source code via AST.
        """
        tree = ast.parse(source_code)
        args = []
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef) and node.name == tool_name:
                for arg in node.args.args:
                    args.append(arg.arg)
        return args

    def generate_execution_script(self, agent_state: Optional[AgentState], inject_agent_state: bool) -> str:
        """
        Generate code to run inside of execution sandbox.
        Serialize the agent state and arguments, call the tool,
        then base64-encode/pickle the result.
        """
        code = "from typing import *\n"
        code += "import pickle\n"
        code += "import sys\n"
        code += "import base64\n"

        # Additional imports to support agent state
        if inject_agent_state:
            code += "import letta\n"
            code += "from letta import * \n"

        # Add schema code if available
        if self.tool.args_json_schema:
            schema_code = add_imports_and_pydantic_schemas_for_args(self.tool.args_json_schema)
            if "from __future__ import annotations" in schema_code:
                # __future__ imports must be the very first statement of the script
                schema_code = schema_code.replace("from __future__ import annotations", "").lstrip()
                code = "from __future__ import annotations\n\n" + code
            code += schema_code + "\n"

        # Load the agent state
        if inject_agent_state:
            agent_state_pickle = pickle.dumps(agent_state)
            code += f"agent_state = pickle.loads({agent_state_pickle})\n"
        else:
            code += "agent_state = None\n"

        # Initialize arguments
        if self.tool.args_json_schema:
            args_schema = generate_model_from_args_json_schema(self.tool.args_json_schema)
            code += f"args_object = {args_schema.__name__}(**{self.args})\n"
            for param in self.args:
                code += f"{param} = args_object.{param}\n"
        else:
            for param in self.args:
                code += self.initialize_param(param, self.args[param])

        # Insert the tool's source code
        code += "\n" + self.tool.source_code + "\n"

        # Invoke the function and store the result in a global variable
        code += (
            f"{self.LOCAL_SANDBOX_RESULT_VAR_NAME}"
            + ' = {"results": '
            + self.invoke_function_call(inject_agent_state=inject_agent_state)
            + ', "agent_state": agent_state}\n'
        )
        code += (
            f"{self.LOCAL_SANDBOX_RESULT_VAR_NAME} = base64.b64encode("
            f"pickle.dumps({self.LOCAL_SANDBOX_RESULT_VAR_NAME})"
            ").decode('utf-8')\n"
        )

        # If we're always in a subprocess, we must rely on markers to parse out the result
        code += f"sys.stdout.write('{self.LOCAL_SANDBOX_RESULT_START_MARKER}')\n"
        code += f"sys.stdout.write(str({self.LOCAL_SANDBOX_RESULT_VAR_NAME}))\n"
        code += f"sys.stdout.write('{self.LOCAL_SANDBOX_RESULT_END_MARKER}')\n"
        return code

    def _convert_param_to_value(self, param_type: str, raw_value: str) -> str:
        """
        Convert parameter to Python code representation based on JSON schema type.

        Raises:
            TypeError: If the schema type is not one we know how to emit.
        """
        if param_type == "string":
            # Safely inject a Python string via pickle (avoids quoting/escaping issues)
            value = "pickle.loads(" + str(pickle.dumps(raw_value)) + ")"
        elif param_type in ["integer", "boolean", "number", "array", "object"]:
            # This is simplistic. In real usage, ensure correct type-casting or sanitization.
            value = raw_value
        else:
            raise TypeError(f"Unsupported type: {param_type}, raw_value={raw_value}")
        return str(value)

    def initialize_param(self, name: str, raw_value: str) -> str:
        """
        Produce code for initializing a single parameter in the generated script.
        Returns "" for params not declared in the tool's JSON schema.
        """
        params = self.tool.json_schema["parameters"]["properties"]
        spec = params.get(name)
        if spec is None:
            # Possibly an extra param like 'self' that we ignore
            return ""

        param_type = spec.get("type")
        if param_type is None and spec.get("parameters"):
            param_type = spec["parameters"].get("type")

        value = self._convert_param_to_value(param_type, raw_value)
        return f"{name} = {value}\n"

    def invoke_function_call(self, inject_agent_state: bool) -> str:
        """
        Generate the function call code string with the appropriate arguments.
        Only args declared in the tool's JSON schema are forwarded.
        """
        kwargs = []
        for name in self.args:
            if name in self.tool.json_schema["parameters"]["properties"]:
                kwargs.append(name)

        param_list = [f"{arg}={arg}" for arg in kwargs]
        if inject_agent_state:
            param_list.append("agent_state=agent_state")

        params = ", ".join(param_list)
        func_call_str = self.tool.name + "(" + params + ")"
        return func_call_str