Files
letta-server/letta/services/tool_sandbox/base.py

418 lines
17 KiB
Python

import os
import pickle
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from letta.functions.helpers import generate_model_from_args_json_schema
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args
from letta.services.helpers.tool_parser_helper import convert_param_to_str_value, parse_function_arguments
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.types import JsonDict, JsonValue
class AsyncToolSandboxBase(ABC):
    """Abstract base class for asynchronous tool sandboxes.

    Stores the invocation context of a single tool call (tool name, arguments,
    acting user, optional agent/project identifiers) and provides the shared
    machinery to

    * resolve the tool and inspect its signature (``_init_async``),
    * render the Python script that is executed inside the sandbox
      (``generate_execution_script`` / ``_render_sandbox_code``), and
    * assemble the environment variables handed to the sandbox
      (``_gather_env_vars``).

    Concrete sandboxes implement :meth:`run`.
    """

    NAMESPACE = uuid.NAMESPACE_DNS
    # Deterministic byte marker that precedes the framed result on stdout when
    # the script is rendered with ``wrap_print_with_markers=True``.
    LOCAL_SANDBOX_RESULT_START_MARKER = uuid.uuid5(NAMESPACE, "local-sandbox-result-start-marker").bytes
    # Name of the variable holding the result dict inside the generated script.
    LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"

    def __init__(
        self,
        tool_name: str,
        args: JsonDict,
        user,
        tool_id: str,
        agent_id: Optional[str] = None,
        project_id: Optional[str] = None,
        tool_object: Optional[Tool] = None,
        sandbox_config: Optional[SandboxConfig] = None,
        sandbox_env_vars: Optional[Dict[str, Any]] = None,
    ):
        """Record the invocation context; no I/O happens here.

        Args:
            tool_name: Name used to look the tool up when ``tool_object`` is not given.
            args: Arguments the tool is called with, keyed by parameter name.
            user: Actor on whose behalf tool and sandbox lookups are performed.
            tool_id: Identifier exported to the sandbox as ``LETTA_TOOL_ID``.
            agent_id: Optional identifier exported as ``LETTA_AGENT_ID``.
            project_id: Optional identifier exported as ``LETTA_PROJECT_ID``.
            tool_object: Already-loaded tool; skips the lookup in ``_init_async``.
            sandbox_config: Pre-fetched sandbox configuration, stored for subclasses.
            sandbox_env_vars: Pre-fetched sandbox environment variables; when
                non-empty they are used instead of querying the config manager.
        """
        self.tool_name = tool_name
        self.args = args
        self.user = user
        self.agent_id = agent_id
        self.project_id = project_id
        self.tool_id = tool_id
        self.tool = tool_object

        # Store provided values or create manager to fetch them later
        self.provided_sandbox_config = sandbox_config
        self.provided_sandbox_env_vars = sandbox_env_vars

        # Only create the manager if we need to (lazy initialization)
        self._sandbox_config_manager = None
        self._initialized = False

    async def _init_async(self):
        """Must be called inside the run method before the sandbox can be used.

        Resolves the tool (if it was not passed to the constructor) and derives
        the ``inject_*`` / ``is_async_function`` flags from its source code.
        Subsequent calls are no-ops.

        Raises:
            ValueError: If the tool cannot be found for the acting user.
        """
        if self._initialized:
            return

        if not self.tool:
            self.tool = await ToolManager().get_tool_by_name_async(tool_name=self.tool_name, actor=self.user)

        # missing tool
        if self.tool is None:
            raise ValueError(
                f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
            )

        # Check for reserved keyword arguments
        tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name)

        # TODO: deprecate this
        self.inject_agent_state = "agent_state" in tool_arguments

        # Always inject Letta client (available as `client` variable in sandbox)
        self.inject_letta_client = True
        self.inject_agent_id = "agent_id" in tool_arguments

        self.is_async_function = self._detect_async_function()
        self._initialized = True

    # Lazily initialize the manager only when needed
    @property
    def sandbox_config_manager(self):
        """Return the ``SandboxConfigManager``, creating it on first access."""
        if self._sandbox_config_manager is None:
            self._sandbox_config_manager = SandboxConfigManager()
        return self._sandbox_config_manager

    @abstractmethod
    async def run(
        self,
        agent_state: Optional[AgentState] = None,
        additional_env_vars: Optional[Dict] = None,
    ) -> ToolExecutionResult:
        """
        Run the tool in a sandbox environment asynchronously.
        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @trace_method
    async def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str:
        """
        Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool,
        then base64-encode/pickle the result. Constructs the python file.

        Args:
            agent_state: Agent state to pickle into the script; only embedded
                when the tool declares an ``agent_state`` parameter.
            wrap_print_with_markers: When True the script writes a framed binary
                message to stdout instead of ending with a base64 expression.

        Returns:
            The complete source text of the script.
        """
        await self._init_async()
        future_import = False
        schema_code = None

        if self.tool.args_json_schema:
            # Add schema code if available
            schema_code = add_imports_and_pydantic_schemas_for_args(self.tool.args_json_schema)
            if "from __future__ import annotations" in schema_code:
                # A __future__ import must be the first statement of the script,
                # so strip it here and let the renderer emit it at the very top.
                schema_code = schema_code.replace("from __future__ import annotations", "").lstrip()
                future_import = True

            # Initialize arguments
            args_schema = generate_model_from_args_json_schema(self.tool.args_json_schema)
            arg_lines = [f"args_object = {args_schema.__name__}(**{self.args})\n"]
            arg_lines.extend(f"{param} = args_object.{param}\n" for param in self.args)
            tool_args = "".join(arg_lines)
        else:
            tool_args = "".join(self.initialize_param(param, self.args[param]) for param in self.args)

        agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None
        agent_id = agent_state.id if agent_state else None

        return self._render_sandbox_code(
            future_import=future_import,
            inject_agent_state=self.inject_agent_state,
            inject_letta_client=self.inject_letta_client,
            inject_agent_id=self.inject_agent_id,
            schema_imports=schema_code or "",
            agent_state_pickle=agent_state_pickle,
            agent_id=agent_id,
            tool_args=tool_args,
            tool_source_code=self.tool.source_code,
            local_sandbox_result_var_name=self.LOCAL_SANDBOX_RESULT_VAR_NAME,
            invoke_function_call=self.invoke_function_call(),
            wrap_print_with_markers=wrap_print_with_markers,
            start_marker=self.LOCAL_SANDBOX_RESULT_START_MARKER,
            use_top_level_await=self.use_top_level_await(),
        )

    def _render_sandbox_code(
        self,
        *,
        future_import: bool,
        inject_agent_state: bool,
        inject_letta_client: bool,
        inject_agent_id: bool,
        schema_imports: str,
        agent_state_pickle: bytes | None,
        agent_id: str | None,
        tool_args: str,
        tool_source_code: str,
        local_sandbox_result_var_name: str,
        invoke_function_call: str,
        wrap_print_with_markers: bool,
        start_marker: bytes,
        use_top_level_await: bool,
    ) -> str:
        """Assemble the sandbox script line by line and return it as one string.

        The script consists of, in order: optional ``__future__`` import,
        standard imports, optional Letta imports, schema code, the
        ``agent_state`` binding, optional Letta client setup, optional
        ``agent_id`` binding, argument initialisation, the tool source, the tool
        invocation with result serialisation, and finally the result emission.

        Every nested statement in the emitted code is indented relative to its
        enclosing block; the generated text must itself be valid Python.

        Reads ``self.is_async_function`` to choose between the synchronous and
        the asynchronous invocation variant.
        """
        lines: list[str] = []

        if future_import:
            lines.append("from __future__ import annotations")

        lines.extend(
            [
                "from typing import *",
                "import pickle",
                "import sys",
                "import base64",
                "import struct",
                "import hashlib",
            ]
        )

        if self.is_async_function:
            lines.append("import asyncio")

        if inject_agent_state:
            lines.extend(["import letta", "from letta import *"])  # noqa: F401

        # Import Letta client if available (wrapped in try/except for sandboxes without letta_client installed)
        if inject_letta_client:
            lines.extend(
                [
                    "try:",
                    "    from letta_client import Letta",
                    "except ImportError:",
                    "    Letta = None",
                ]
            )

        if schema_imports:
            lines.append(schema_imports.rstrip())

        if agent_state_pickle is not None:
            # The bytes literal embedded here is produced by pickle.dumps() in
            # generate_execution_script, i.e. by this process, not by tool code.
            lines.append(f"agent_state = pickle.loads({repr(agent_state_pickle)})")
        else:
            lines.append("agent_state = None")

        # Initialize Letta client if needed (client is always available as a variable, may be None)
        if inject_letta_client:
            lines.extend(
                [
                    "# Initialize Letta client for tool execution",
                    "import os",
                    "client = None",
                    "if Letta is not None and os.getenv('LETTA_API_KEY'):",
                    "    # Check letta_client version to use correct parameter name",
                    "    from packaging import version as pkg_version",
                    "    import letta_client as lc_module",
                    "    lc_version = pkg_version.parse(lc_module.__version__)",
                    "    if lc_version < pkg_version.parse('1.0.0'):",
                    "        client = Letta(",
                    "            token=os.getenv('LETTA_API_KEY')",
                    "        )",
                    "    else:",
                    "        client = Letta(",
                    "            api_key=os.getenv('LETTA_API_KEY')",
                    "        )",
                ]
            )

        # Set agent_id if needed
        if inject_agent_id:
            if agent_id:
                lines.append(f"agent_id = {repr(agent_id)}")
            else:
                lines.append("agent_id = None")

        if tool_args:
            lines.append(tool_args.rstrip())

        if tool_source_code:
            lines.append(tool_source_code.rstrip())

        if not self.is_async_function:
            # sync variant
            lines.append(f"_function_result = {invoke_function_call}")
            lines.extend(
                [
                    "try:",
                    "    from pydantic import BaseModel, ConfigDict",
                    "    from typing import Any",
                    "",
                    "    class _TempResultWrapper(BaseModel):",
                    "        model_config = ConfigDict(arbitrary_types_allowed=True)",
                    "        result: Any",
                    "",
                    "    _wrapped = _TempResultWrapper(result=_function_result)",
                    "    _serialized_result = _wrapped.model_dump()['result']",
                    "except ImportError:",
                    '    print("Pydantic not available in sandbox environment, falling back to string conversion")',
                    "    _serialized_result = str(_function_result)",
                    "except Exception as e:",
                    '    print(f"Failed to serialize result with Pydantic wrapper: {e}")',
                    "    _serialized_result = str(_function_result)",
                    "",
                    f"{local_sandbox_result_var_name} = {{",
                    '    "results": _serialized_result,',
                    '    "agent_state": agent_state',
                    "}",
                    f"{local_sandbox_result_var_name}_pkl = pickle.dumps({local_sandbox_result_var_name})",
                ]
            )
        else:
            # async variant
            lines.extend(
                [
                    "async def _async_wrapper():",
                    f"    _function_result = await {invoke_function_call}",
                    "    try:",
                    "        from pydantic import BaseModel, ConfigDict",
                    "        from typing import Any",
                    "",
                    "        class _TempResultWrapper(BaseModel):",
                    "            model_config = ConfigDict(arbitrary_types_allowed=True)",
                    "            result: Any",
                    "",
                    "        _wrapped = _TempResultWrapper(result=_function_result)",
                    "        _serialized_result = _wrapped.model_dump()['result']",
                    "    except ImportError:",
                    '        print("Pydantic not available in sandbox environment, falling back to string conversion")',
                    "        _serialized_result = str(_function_result)",
                    "    except Exception as e:",
                    '        print(f"Failed to serialize result with Pydantic wrapper: {e}")',
                    "        _serialized_result = str(_function_result)",
                    "",
                    "    return {",
                    '        "results": _serialized_result,',
                    '        "agent_state": agent_state',
                    "    }",
                ]
            )
            if use_top_level_await:
                lines.append(f"{local_sandbox_result_var_name} = await _async_wrapper()")
            else:
                lines.append(f"{local_sandbox_result_var_name} = asyncio.run(_async_wrapper())")
            lines.append(f"{local_sandbox_result_var_name}_pkl = pickle.dumps({local_sandbox_result_var_name})")

        if wrap_print_with_markers:
            # Framed stdout message: start marker, 4-byte big-endian payload
            # length, ASCII md5 hex digest of the payload, then the pickle itself.
            # NOTE(review): whoever decodes this payload unpickles data produced
            # by tool code - confirm the consumer treats it accordingly.
            lines.extend(
                [
                    f"data_checksum = hashlib.md5({local_sandbox_result_var_name}_pkl).hexdigest().encode('ascii')",
                    f"{local_sandbox_result_var_name}_msg = (",
                    f"    {repr(start_marker)} +",
                    f"    struct.pack('>I', len({local_sandbox_result_var_name}_pkl)) +",
                    "    data_checksum +",
                    f"    {local_sandbox_result_var_name}_pkl",
                    ")",
                    f"sys.stdout.buffer.write({local_sandbox_result_var_name}_msg)",
                    "sys.stdout.buffer.flush()",
                ]
            )
        else:
            # Emitted as a bare trailing expression; presumably the sandbox
            # runtime captures the value of the last expression - TODO confirm.
            lines.append(f"base64.b64encode({local_sandbox_result_var_name}_pkl).decode('utf-8')")

        return "\n".join(lines) + "\n"

    def initialize_param(self, name: str, raw_value: JsonValue) -> str:
        """
        Produce code for initializing a single parameter in the generated script.

        Returns an empty string when ``name`` is not declared in the tool's JSON
        schema, otherwise a ``"<name> = <value>\\n"`` assignment line.
        """
        params = self.tool.json_schema["parameters"]["properties"]
        spec = params.get(name)
        if spec is None:
            # Possibly an extra param like 'self' that we ignore
            return ""

        param_type = spec.get("type")
        if param_type is None and spec.get("parameters"):
            # Fall back to the type declared on a nested "parameters" object.
            param_type = spec["parameters"].get("type")

        value = convert_param_to_str_value(param_type, raw_value)
        return f"{name} = {value}\n"

    def invoke_function_call(self) -> str:
        """
        Generate the function call code string with the appropriate arguments.

        Only arguments declared in the tool's JSON schema are passed, each as
        ``name=name`` so the call picks up the variables initialised earlier in
        the script. Requires ``_init_async`` to have set the ``inject_*`` flags.
        """
        properties = self.tool.json_schema["parameters"]["properties"]
        param_list = [f"{name}={name}" for name in self.args if name in properties]

        # Add reserved keyword arguments
        if self.inject_agent_state:
            param_list.append("agent_state=agent_state")

        # Note: client is always available as a variable in the sandbox scope
        # Tools should access it directly rather than receiving it as a parameter
        if self.inject_agent_id:
            param_list.append("agent_id=agent_id")

        params = ", ".join(param_list)
        return self.tool.name + "(" + params + ")"

    def _detect_async_function(self) -> bool:
        """
        Detect if the tool function is an async function by examining its source code.
        Uses AST parsing to reliably detect 'async def' declarations.

        Returns False when the source cannot be parsed or inspected.
        """
        import ast

        try:
            tree = ast.parse(self.tool.source_code)
            return any(isinstance(node, ast.AsyncFunctionDef) and node.name == self.tool.name for node in ast.walk(tree))
        except Exception:
            # Best-effort detection: unparsable or missing source means "not
            # async". Catching Exception rather than using a bare except keeps
            # KeyboardInterrupt / SystemExit propagating.
            return False

    def use_top_level_await(self) -> bool:
        """
        Determine if this sandbox environment supports top-level await.
        Should be overridden by subclasses to return True for environments
        with running event loops (like E2B), False for local execution.
        """
        return False  # Default to False for local execution

    async def _gather_env_vars(
        self,
        agent_state: AgentState | None,
        additional_env_vars: dict[str, str] | None,
        sbx_id: str,
        is_local: bool,
    ) -> dict[str, Any]:
        """Build the environment passed to the sandboxed process.

        Sources are applied in this order, later ones overriding earlier ones:

        1. a copy of ``os.environ`` (only when ``is_local`` is true),
        2. the sandbox env vars given to the constructor, or - when none were
           given - those fetched from the sandbox config manager for ``sbx_id``,
        3. the agent's env vars (when ``agent_state`` is given),
        4. ``additional_env_vars``,
        5. ``LETTA_AGENT_ID`` / ``LETTA_PROJECT_ID`` (when set) and ``LETTA_TOOL_ID``.

        Entries whose value is ``None`` are dropped from the result.
        """
        env = os.environ.copy() if is_local else {}

        # NOTE(review): truthiness check - an explicitly provided *empty* dict
        # also falls through to the database lookup; confirm this is intended.
        if self.provided_sandbox_env_vars:
            env.update(self.provided_sandbox_env_vars)
        else:
            env_vars = await self.sandbox_config_manager.get_sandbox_env_vars_as_dict_async(
                sandbox_config_id=sbx_id, actor=self.user, limit=None
            )
            env.update(env_vars)

        if agent_state:
            env.update(agent_state.get_agent_env_vars_as_dict())

        if additional_env_vars:
            env.update(additional_env_vars)

        # Inject agent, project, and tool IDs as environment variables
        if self.agent_id:
            env["LETTA_AGENT_ID"] = self.agent_id
        if self.project_id:
            env["LETTA_PROJECT_ID"] = self.project_id
        env["LETTA_TOOL_ID"] = self.tool_id

        # Filter out None values to prevent subprocess errors
        return {k: v for k, v in env.items() if v is not None}