From 28c3624a88b6415578a18e7e061e6dcd3ccd35b6 Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Fri, 23 May 2025 18:02:15 -0700 Subject: [PATCH] fix: numerous tool execution bugs (#2371) --- letta/__main__.py | 3 + letta/functions/ast_parsers.py | 38 ++---- letta/functions/schema_generator.py | 14 +- letta/schemas/message.py | 6 + letta/schemas/tool.py | 4 +- letta/server/rest_api/interface.py | 8 +- .../services/helpers/tool_execution_helper.py | 3 + letta/services/helpers/tool_parser_helper.py | 63 +++++++++ .../tool_executor/tool_execution_sandbox.py | 38 +----- letta/services/tool_executor/tool_executor.py | 31 +++-- letta/services/tool_sandbox/base.py | 127 +++++------------- letta/services/tool_sandbox/e2b_sandbox.py | 27 ++-- letta/services/tool_sandbox/local_sandbox.py | 70 +++++++--- letta/templates/__init__.py | 0 letta/templates/sandbox_code_file.py.j2 | 47 +++++++ letta/templates/template_helper.py | 16 +++ letta/types/__init__.py | 7 + letta/utils.py | 14 ++ .../list_of_pydantic_example.json | 2 +- .../list_of_pydantic_example_so.json | 2 +- .../nested_pydantic_as_arg_example.json | 2 +- .../nested_pydantic_as_arg_example_so.json | 2 +- .../simple_d20.json | 2 +- .../simple_d20_so.json | 2 +- 24 files changed, 323 insertions(+), 205 deletions(-) create mode 100644 letta/__main__.py create mode 100644 letta/services/helpers/tool_parser_helper.py create mode 100644 letta/templates/__init__.py create mode 100644 letta/templates/sandbox_code_file.py.j2 create mode 100644 letta/templates/template_helper.py diff --git a/letta/__main__.py b/letta/__main__.py new file mode 100644 index 00000000..89f11424 --- /dev/null +++ b/letta/__main__.py @@ -0,0 +1,3 @@ +from .main import app + +app() diff --git a/letta/functions/ast_parsers.py b/letta/functions/ast_parsers.py index 3113cd96..620e8ba2 100644 --- a/letta/functions/ast_parsers.py +++ b/letta/functions/ast_parsers.py @@ -5,23 +5,13 @@ import typing from typing import 
Dict, Optional, Tuple from letta.errors import LettaToolCreateError - -# Registry of known types for annotation resolution -BUILTIN_TYPES = { - "int": int, - "float": float, - "str": str, - "dict": dict, - "list": list, - "set": set, - "tuple": tuple, - "bool": bool, -} +from letta.types import JsonDict def resolve_type(annotation: str): """ Resolve a type annotation string into a Python type. + Previously, primitive support for int, float, str, dict, list, set, tuple, bool. Args: annotation (str): The annotation string (e.g., 'int', 'list[int]', 'dict[str, int]'). @@ -32,24 +22,19 @@ def resolve_type(annotation: str): Raises: ValueError: If the annotation is unsupported or invalid. """ - if annotation in BUILTIN_TYPES: - return BUILTIN_TYPES[annotation] + python_types = {**vars(typing), **vars(builtins)} + + if annotation in python_types: + return python_types[annotation] try: # Allow use of typing and builtins in a safe eval context - namespace = { - **vars(typing), - **vars(builtins), - "list": list, - "dict": dict, - "tuple": tuple, - "set": set, - } - return eval(annotation, namespace) + return eval(annotation, python_types) except Exception: raise ValueError(f"Unsupported annotation: {annotation}") +# TODO :: THIS MUST BE EDITED TO HANDLE THINGS def get_function_annotations_from_source(source_code: str, function_name: str) -> Dict[str, str]: """ Parse the source code to extract annotations for a given function name. 
@@ -76,7 +61,8 @@ def get_function_annotations_from_source(source_code: str, function_name: str) - raise ValueError(f"Function '{function_name}' not found in the provided source code.") -def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, str]) -> dict: +# NOW json_loads -> ast.literal_eval -> typing.get_origin +def coerce_dict_args_by_annotations(function_args: JsonDict, annotations: Dict[str, str]) -> dict: coerced_args = dict(function_args) # Shallow copy for arg_name, value in coerced_args.items(): @@ -110,8 +96,8 @@ def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, return coerced_args -def get_function_name_and_description(source_code: str, name: Optional[str] = None) -> Tuple[str, str]: - """Gets the name and description for a given function source code by parsing the AST. +def get_function_name_and_docstring(source_code: str, name: Optional[str] = None) -> Tuple[str, str]: + """Gets the name and docstring for a given function source code by parsing the AST. 
Args: source_code: The source code to parse diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index d2650630..6f1c003c 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -143,7 +143,10 @@ def pydantic_model_to_open_ai(model: Type[BaseModel]) -> dict: parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v) if "description" not in schema: - if docstring.short_description: + # Support multiline docstrings for complex functions, TODO (cliandy): consider having this as a setting + if docstring.long_description: + schema["description"] = docstring.long_description + elif docstring.short_description: schema["description"] = docstring.short_description else: raise ValueError(f"No description found in docstring or description field (model: {model}, docstring: {docstring})") @@ -330,10 +333,17 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[ # Parse the docstring docstring = parse(function.__doc__) + if not description: + # Support multiline docstrings for complex functions, TODO (cliandy): consider having this as a setting + if docstring.long_description: + description = docstring.long_description + else: + description = docstring.short_description + # Prepare the schema dictionary schema = { "name": function.__name__ if name is None else name, - "description": docstring.short_description if description is None else description, + "description": description, "parameters": {"type": "object", "properties": {}, "required": []}, } diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 36c29ef6..1b684094 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -654,6 +654,8 @@ class Message(BaseMessage): parse_content_parts = False if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text + elif self.content and 
len(self.content) == 1 and isinstance(self.content[0], ToolReturnContent): + text_content = self.content[0].content # Otherwise, check if we have TextContent and multiple other parts elif self.content and len(self.content) > 1: text = [content for content in self.content if isinstance(content, TextContent)] @@ -866,6 +868,8 @@ class Message(BaseMessage): # role: str ('user' or 'model') if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text + elif self.content and len(self.content) == 1 and isinstance(self.content[0], ToolReturnContent): + text_content = self.content[0].content else: text_content = None @@ -1000,6 +1004,8 @@ # embedded function calls in multi-turn conversation become more clear if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text + elif self.content and len(self.content) == 1 and isinstance(self.content[0], ToolReturnContent): + text_content = self.content[0].content else: text_content = None if self.role == "system": diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 81e97aa2..af20bd2a 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -11,7 +11,7 @@ from letta.constants import ( LETTA_VOICE_TOOL_MODULE_NAME, MCP_TOOL_TAG_NAME_PREFIX, ) -from letta.functions.ast_parsers import get_function_name_and_description +from letta.functions.ast_parsers import get_function_name_and_docstring from letta.functions.composio_helpers import generate_composio_tool_wrapper from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module from letta.functions.mcp_client.types import MCPTool @@ -84,7 +84,7 @@ class Tool(BaseTool): # TODO: We skip this for Composio bc composio json schemas are derived differently if not (COMPOSIO_TOOL_TAG_NAME in self.tags): if self.args_json_schema is not None: - name, description = 
get_function_name_and_description(self.source_code, self.name) + name, description = get_function_name_and_docstring(self.source_code, self.name) args_schema = generate_model_from_args_json_schema(self.args_json_schema) self.json_schema = generate_schema_from_args_schema_v2( args_schema=args_schema, diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 33085408..b930f774 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -1338,8 +1338,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_return=msg, status=msg_obj.tool_returns[0].status if msg_obj.tool_returns else "success", tool_call_id=msg_obj.tool_call_id, - stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None, - stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None, + stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else [], + stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else [], name=msg_obj.name, otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) @@ -1354,8 +1354,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_return=msg, status=msg_obj.tool_returns[0].status if msg_obj.tool_returns else "error", tool_call_id=msg_obj.tool_call_id, - stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None, - stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None, + stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else [], + stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else [], name=msg_obj.name, otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) diff --git a/letta/services/helpers/tool_execution_helper.py b/letta/services/helpers/tool_execution_helper.py index 043594cc..34ea17af 100644 --- a/letta/services/helpers/tool_execution_helper.py +++ 
b/letta/services/helpers/tool_execution_helper.py @@ -60,6 +60,9 @@ def run_subprocess(command: list, env: Optional[Dict[str, str]] = None, fail_msg except subprocess.CalledProcessError as e: logger.error(f"{fail_msg}\nSTDOUT:\n{e.stdout}\nSTDERR:\n{e.stderr}") raise RuntimeError(f"{fail_msg}: {e.stderr.strip()}") from e + except Exception as e: + logger.error(f"{fail_msg}: {e}") + raise RuntimeError(f"{fail_msg}: {e}") def ensure_pip_is_up_to_date(python_exec: str, env: Optional[Dict[str, str]] = None): diff --git a/letta/services/helpers/tool_parser_helper.py b/letta/services/helpers/tool_parser_helper.py new file mode 100644 index 00000000..36798a91 --- /dev/null +++ b/letta/services/helpers/tool_parser_helper.py @@ -0,0 +1,63 @@ +import ast +import base64 +import pickle +from typing import Any + +from letta.schemas.agent import AgentState +from letta.types import JsonValue + + +def parse_stdout_best_effort(text: str | bytes) -> tuple[Any, AgentState | None]: + """ + Decode and unpickle the result from the function execution if possible. + Returns (function_return_value, agent_state). + """ + if not text: + return None, None + if isinstance(text, str): + text = base64.b64decode(text) + result = pickle.loads(text) + agent_state = result["agent_state"] + return result["results"], agent_state + + +def parse_function_arguments(source_code: str, tool_name: str): + """Get arguments of a function from its source code""" + tree = ast.parse(source_code) + args = [] + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == tool_name: + for arg in node.args.args: + args.append(arg.arg) + return args + + +def convert_param_to_str_value(param_type: str, raw_value: JsonValue) -> str: + """ + Convert parameter to Python code representation based on JSON schema type. 
+ TODO (cliandy): increase sanitization checks here to fail at the right place + """ + + valid_types = {"string", "integer", "boolean", "number", "array", "object"} + if param_type not in valid_types: + raise TypeError(f"Unsupported type: {param_type}, raw_value={raw_value}") + if param_type == "string": + # Safely handle python string + return repr(raw_value) + if param_type == "integer": + return str(int(raw_value)) + if param_type == "boolean": + if isinstance(raw_value, bool): + return str(raw_value) + if isinstance(raw_value, int) and raw_value in (0, 1): + return str(bool(raw_value)) + if isinstance(raw_value, str) and raw_value.strip().lower() in ("true", "false"): + return raw_value.strip().lower().capitalize() + raise ValueError(f"Invalid boolean value: {raw_value}") + if param_type == "array": + pass # need more testing here + # if isinstance(raw_value, str): + # if raw_value.strip()[0] != "[" or raw_value.strip()[-1] != "]": + # raise ValueError(f'Invalid array value: "{raw_value}"') + # return raw_value.strip() + return str(raw_value) diff --git a/letta/services/tool_executor/tool_execution_sandbox.py b/letta/services/tool_executor/tool_execution_sandbox.py index e466cd9e..4d60de8f 100644 --- a/letta/services/tool_executor/tool_execution_sandbox.py +++ b/letta/services/tool_executor/tool_execution_sandbox.py @@ -1,4 +1,3 @@ -import ast import base64 import io import os @@ -23,6 +22,7 @@ from letta.services.helpers.tool_execution_helper import ( find_python_executable, install_pip_requirements_for_sandbox, ) +from letta.services.helpers.tool_parser_helper import convert_param_to_str_value, parse_function_arguments from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_manager import ToolManager @@ -52,7 +52,6 @@ class ToolExecutionSandbox: self.tool_name = tool_name self.args = args self.user = user - # get organization self.organization = 
OrganizationManager().get_organization_by_id(self.user.organization_id) self.privileged_tools = self.organization.privileged_tools @@ -476,16 +475,6 @@ class ToolExecutionSandbox: agent_state = result["agent_state"] return result["results"], agent_state - def parse_function_arguments(self, source_code: str, tool_name: str): - """Get arguments of a function from its source code""" - tree = ast.parse(source_code) - args = [] - for node in ast.walk(tree): - if isinstance(node, ast.FunctionDef) and node.name == tool_name: - for arg in node.args.args: - args.append(arg.arg) - return args - def generate_execution_script(self, agent_state: AgentState, wrap_print_with_markers: bool = False) -> str: """ Generate code to run inside of execution sandbox. @@ -498,7 +487,7 @@ class ToolExecutionSandbox: Returns: code (str): The generated code strong """ - if "agent_state" in self.parse_function_arguments(self.tool.source_code, self.tool.name): + if "agent_state" in parse_function_arguments(self.tool.source_code, self.tool.name): inject_agent_state = True else: inject_agent_state = False @@ -546,7 +535,7 @@ class ToolExecutionSandbox: code += ( self.LOCAL_SANDBOX_RESULT_VAR_NAME + ' = {"results": ' - + self.invoke_function_call(inject_agent_state=inject_agent_state) + + self.invoke_function_call(inject_agent_state=inject_agent_state) # this inject_agent_state is the main difference + ', "agent_state": agent_state}\n' ) code += ( @@ -562,24 +551,6 @@ class ToolExecutionSandbox: return code - def _convert_param_to_value(self, param_type: str, raw_value: str) -> str: - - if param_type == "string": - value = "pickle.loads(" + str(pickle.dumps(raw_value)) + ")" - - elif param_type == "integer" or param_type == "boolean" or param_type == "number": - value = raw_value - - elif param_type == "array": - value = raw_value - - elif param_type == "object": - value = raw_value - - else: - raise TypeError(f"Unsupported type: {param_type}, raw_value={raw_value}") - return str(value) - def 
initialize_param(self, name: str, raw_value: str) -> str: params = self.tool.json_schema["parameters"]["properties"] spec = params.get(name) @@ -591,8 +562,7 @@ class ToolExecutionSandbox: if param_type is None and spec.get("parameters"): param_type = spec["parameters"].get("type") - value = self._convert_param_to_value(param_type, raw_value) - + value = convert_param_to_str_value(param_type, raw_value) return name + " = " + value + "\n" def invoke_function_call(self, inject_agent_state: bool) -> str: diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index 0e1fbe01..17866aaa 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -39,6 +39,7 @@ from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal from letta.settings import tool_settings from letta.tracing import trace_method +from letta.types import JsonDict from letta.utils import get_friendly_error_msg logger = get_logger(__name__) @@ -107,11 +108,20 @@ class LettaCoreToolExecutor(ToolExecutor): # Execute the appropriate function function_args_copy = function_args.copy() # Make a copy to avoid modifying the original - function_response = function_map[function_name](agent_state, actor, **function_args_copy) - return ToolExecutionResult( - status="success", - func_return=function_response, - ) + try: + function_response = function_map[function_name](agent_state, actor, **function_args_copy) + return ToolExecutionResult( + status="success", + func_return=function_response, + agent_state=agent_state, + ) + except Exception as e: + return ToolExecutionResult( + status="error", + func_return=e, + agent_state=agent_state, + stderr=[get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))], + ) def send_message(self, agent_state: AgentState, actor: User, 
message: str) -> Optional[str]: """ @@ -708,7 +718,7 @@ class SandboxToolExecutor(ToolExecutor): async def execute( self, function_name: str, - function_args: dict, + function_args: JsonDict, agent_state: AgentState, tool: Tool, actor: User, @@ -749,7 +759,8 @@ class SandboxToolExecutor(ToolExecutor): except Exception as e: return self._handle_execution_error(e, function_name, traceback.format_exc()) - def _prepare_function_args(self, function_args: dict, tool: Tool, function_name: str) -> dict: + @staticmethod + def _prepare_function_args(function_args: JsonDict, tool: Tool, function_name: str) -> dict: """Prepare function arguments with proper type coercion.""" try: # Parse the source code to extract function annotations @@ -761,7 +772,8 @@ class SandboxToolExecutor(ToolExecutor): # This is defensive programming - we try to coerce but fall back if it fails return function_args - def _create_agent_state_copy(self, agent_state: AgentState): + @staticmethod + def _create_agent_state_copy(agent_state: AgentState): """Create a copy of agent state for sandbox execution.""" agent_state_copy = agent_state.__deepcopy__() # Remove tools from copy to prevent nested tool execution @@ -769,8 +781,8 @@ class SandboxToolExecutor(ToolExecutor): agent_state_copy.tool_rules = [] return agent_state_copy + @staticmethod def _handle_execution_error( - self, exception: Exception, function_name: str, stderr: str, @@ -812,6 +824,7 @@ class LettaBuiltinToolExecutor(ToolExecutor): return ToolExecutionResult( status="success", func_return=function_response, + agent_state=agent_state, ) async def run_code(self, code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str: diff --git a/letta/services/tool_sandbox/base.py b/letta/services/tool_sandbox/base.py index 71d76e94..e21d4743 100644 --- a/letta/services/tool_sandbox/base.py +++ b/letta/services/tool_sandbox/base.py @@ -1,9 +1,7 @@ -import ast -import base64 import pickle import uuid from abc import ABC, abstractmethod -from 
typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional from letta.functions.helpers import generate_model_from_args_json_schema from letta.schemas.agent import AgentState @@ -11,20 +9,21 @@ from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args +from letta.services.helpers.tool_parser_helper import convert_param_to_str_value, parse_function_arguments from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_manager import ToolManager +from letta.types import JsonDict, JsonValue class AsyncToolSandboxBase(ABC): NAMESPACE = uuid.NAMESPACE_DNS - LOCAL_SANDBOX_RESULT_START_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-start-marker")) - LOCAL_SANDBOX_RESULT_END_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-end-marker")) + LOCAL_SANDBOX_RESULT_START_MARKER = uuid.uuid5(NAMESPACE, "local-sandbox-result-start-marker").bytes LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt" def __init__( self, tool_name: str, - args: dict, + args: JsonDict, user, tool_object: Optional[Tool] = None, sandbox_config: Optional[SandboxConfig] = None, @@ -48,7 +47,7 @@ class AsyncToolSandboxBase(ABC): self._sandbox_config_manager = None # See if we should inject agent_state or not based on the presence of the "agent_state" arg - if "agent_state" in self.parse_function_arguments(self.tool.source_code, self.tool.name): + if "agent_state" in parse_function_arguments(self.tool.source_code, self.tool.name): self.inject_agent_state = True else: self.inject_agent_state = False @@ -74,83 +73,50 @@ class AsyncToolSandboxBase(ABC): def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str: """ - Generate code to run inside of execution sandbox. 
- Serialize the agent state and arguments, call the tool, - then base64-encode/pickle the result. + Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool, + then base64-encode/pickle the result. Runs a jinja2 template constructing the python file. """ - code = "from typing import *\n" - code += "import pickle\n" - code += "import sys\n" - code += "import base64\n" + from letta.templates.template_helper import render_template - # Additional imports to support agent state - if self.inject_agent_state: - code += "import letta\n" - code += "from letta import * \n" + TEMPLATE_NAME = "sandbox_code_file.py.j2" + + future_import = False + schema_code = None - # Add schema code if available if self.tool.args_json_schema: + # Add schema code if available schema_code = add_imports_and_pydantic_schemas_for_args(self.tool.args_json_schema) if "from __future__ import annotations" in schema_code: schema_code = schema_code.replace("from __future__ import annotations", "").lstrip() - code = "from __future__ import annotations\n\n" + code - code += schema_code + "\n" + future_import = True - # Load the agent state - if self.inject_agent_state: - agent_state_pickle = pickle.dumps(agent_state) - code += f"agent_state = pickle.loads({agent_state_pickle})\n" - else: - code += "agent_state = None\n" - - # Initialize arguments - if self.tool.args_json_schema: + # Initialize arguments args_schema = generate_model_from_args_json_schema(self.tool.args_json_schema) - code += f"args_object = {args_schema.__name__}(**{self.args})\n" + tool_args = f"args_object = {args_schema.__name__}(**{self.args})\n" for param in self.args: - code += f"{param} = args_object.{param}\n" + tool_args += f"{param} = args_object.{param}\n" else: + tool_args = "" for param in self.args: - code += self.initialize_param(param, self.args[param]) + tool_args += self.initialize_param(param, self.args[param]) - # Insert the tool's source code - code += "\n" + 
self.tool.source_code + "\n" + agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None - # Invoke the function and store the result in a global variable - code += ( - f"{self.LOCAL_SANDBOX_RESULT_VAR_NAME}" + ' = {"results": ' + self.invoke_function_call() + ', "agent_state": agent_state}\n' - ) - code += ( - f"{self.LOCAL_SANDBOX_RESULT_VAR_NAME} = base64.b64encode(" - f"pickle.dumps({self.LOCAL_SANDBOX_RESULT_VAR_NAME})" - ").decode('utf-8')\n" + return render_template( + TEMPLATE_NAME, + future_import=future_import, + inject_agent_state=self.inject_agent_state, + schema_imports=schema_code, + agent_state_pickle=agent_state_pickle, + tool_args=tool_args, + tool_source_code=self.tool.source_code, + local_sandbox_result_var_name=self.LOCAL_SANDBOX_RESULT_VAR_NAME, + invoke_function_call=self.invoke_function_call(), + wrap_print_with_markers=wrap_print_with_markers, + start_marker=self.LOCAL_SANDBOX_RESULT_START_MARKER, ) - if wrap_print_with_markers: - code += f"sys.stdout.write('{self.LOCAL_SANDBOX_RESULT_START_MARKER}')\n" - code += f"sys.stdout.write(str({self.LOCAL_SANDBOX_RESULT_VAR_NAME}))\n" - code += f"sys.stdout.write('{self.LOCAL_SANDBOX_RESULT_END_MARKER}')\n" - else: - code += f"{self.LOCAL_SANDBOX_RESULT_VAR_NAME}\n" - - return code - - def _convert_param_to_value(self, param_type: str, raw_value: str) -> str: - """ - Convert parameter to Python code representation based on JSON schema type. - """ - if param_type == "string": - # Safely inject a Python string via pickle - value = "pickle.loads(" + str(pickle.dumps(raw_value)) + ")" - elif param_type in ["integer", "boolean", "number", "array", "object"]: - # This is simplistic. In real usage, ensure correct type-casting or sanitization. 
- value = raw_value - else: - raise TypeError(f"Unsupported type: {param_type}, raw_value={raw_value}") - - return str(value) - - def initialize_param(self, name: str, raw_value: str) -> str: + def initialize_param(self, name: str, raw_value: JsonValue) -> str: """ Produce code for initializing a single parameter in the generated script. """ @@ -164,7 +130,7 @@ class AsyncToolSandboxBase(ABC): if param_type is None and spec.get("parameters"): param_type = spec["parameters"].get("type") - value = self._convert_param_to_value(param_type, raw_value) + value = convert_param_to_str_value(param_type, raw_value) return f"{name} = {value}\n" def invoke_function_call(self) -> str: @@ -184,24 +150,5 @@ class AsyncToolSandboxBase(ABC): func_call_str = self.tool.name + "(" + params + ")" return func_call_str - def parse_best_effort(self, text: str) -> Tuple[Any, Optional[AgentState]]: - """ - Decode and unpickle the result from the function execution if possible. - Returns (function_return_value, agent_state). 
- """ - if not text: - return None, None - - result = pickle.loads(base64.b64decode(text)) - agent_state = result["agent_state"] - return result["results"], agent_state - - def parse_function_arguments(self, source_code: str, tool_name: str): - """Get arguments of a function from its source code""" - tree = ast.parse(source_code) - args = [] - for node in ast.walk(tree): - if isinstance(node, ast.FunctionDef) and node.name == tool_name: - for arg in node.args.args: - args.append(arg.arg) - return args + def _update_env_vars(self): + pass # TODO diff --git a/letta/services/tool_sandbox/e2b_sandbox.py b/letta/services/tool_sandbox/e2b_sandbox.py index 07ab5727..ea07f1f1 100644 --- a/letta/services/tool_sandbox/e2b_sandbox.py +++ b/letta/services/tool_sandbox/e2b_sandbox.py @@ -1,16 +1,23 @@ -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional + +from e2b_code_interpreter import AsyncSandbox from letta.log import get_logger from letta.schemas.agent import AgentState from letta.schemas.sandbox_config import SandboxConfig, SandboxType from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult +from letta.services.helpers.tool_parser_helper import parse_stdout_best_effort from letta.services.tool_sandbox.base import AsyncToolSandboxBase from letta.tracing import log_event, trace_method +from letta.types import JsonDict from letta.utils import get_friendly_error_msg logger = get_logger(__name__) +if TYPE_CHECKING: + from e2b_code_interpreter import Execution + class AsyncToolSandboxE2B(AsyncToolSandboxBase): METADATA_CONFIG_STATE_KEY = "config_state" @@ -18,9 +25,9 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): def __init__( self, tool_name: str, - args: dict, + args: JsonDict, user, - force_recreate=True, + force_recreate: bool = True, tool_object: Optional[Tool] = None, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, @@ -92,7 
+99,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): ) execution = await e2b_sandbox.run_code(code, envs=env_vars) if execution.results: - func_return, agent_state = self.parse_best_effort(execution.results[0].text) + func_return, agent_state = parse_stdout_best_effort(execution.results[0].text) log_event( "e2b_execution_succeeded", { @@ -138,16 +145,15 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): sandbox_config_fingerprint=sbx_config.fingerprint(), ) - def parse_exception_from_e2b_execution(self, e2b_execution: "Execution") -> Exception: + @staticmethod + def parse_exception_from_e2b_execution(e2b_execution: "Execution") -> Exception: builtins_dict = __builtins__ if isinstance(__builtins__, dict) else vars(__builtins__) # Dynamically fetch the exception class from builtins, defaulting to Exception if not found exception_class = builtins_dict.get(e2b_execution.error.name, Exception) return exception_class(e2b_execution.error.value) @trace_method - async def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "Sandbox": - from e2b_code_interpreter import AsyncSandbox - + async def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "AsyncSandbox": state_hash = sandbox_config.fingerprint() e2b_config = sandbox_config.get_e2b_config() @@ -194,8 +200,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): return sbx - async def list_running_e2b_sandboxes(self): - from e2b_code_interpreter import AsyncSandbox - + @staticmethod + async def list_running_e2b_sandboxes(): # List running sandboxes and access metadata. 
return await AsyncSandbox.list() diff --git a/letta/services/tool_sandbox/local_sandbox.py b/letta/services/tool_sandbox/local_sandbox.py index 27640951..8fe34870 100644 --- a/letta/services/tool_sandbox/local_sandbox.py +++ b/letta/services/tool_sandbox/local_sandbox.py @@ -1,8 +1,12 @@ import asyncio +import hashlib import os +import struct import sys import tempfile -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional + +from pydantic.config import JsonDict from letta.schemas.agent import AgentState from letta.schemas.sandbox_config import SandboxConfig, SandboxType @@ -13,10 +17,11 @@ from letta.services.helpers.tool_execution_helper import ( find_python_executable, install_pip_requirements_for_sandbox, ) +from letta.services.helpers.tool_parser_helper import parse_stdout_best_effort from letta.services.tool_sandbox.base import AsyncToolSandboxBase from letta.settings import tool_settings from letta.tracing import log_event, trace_method -from letta.utils import get_friendly_error_msg +from letta.utils import get_friendly_error_msg, parse_stderr_error_msg class AsyncToolSandboxLocal(AsyncToolSandboxBase): @@ -26,7 +31,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): def __init__( self, tool_name: str, - args: dict, + args: JsonDict, user, force_recreate_venv=False, tool_object: Optional[Tool] = None, @@ -123,7 +128,15 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): # If not using venv, use whatever Python we are running on python_executable = sys.executable - exec_env["PYTHONWARNINGS"] = "ignore" + # handle unwanted terminal behavior + exec_env.update( + { + "PYTHONWARNINGS": "ignore", + "NO_COLOR": "1", + "TERM": "dumb", + "PYTHONUNBUFFERED": "1", + } + ) # Execute in subprocess return await self._execute_tool_subprocess( @@ -170,6 +183,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): Execute user code in a subprocess, always capturing stdout and stderr. 
We parse special markers to extract the pickled result string. """ + stdout_text = "" try: log_event(name="start subprocess") @@ -190,13 +204,20 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): raise TimeoutError(f"Executing tool {self.tool_name} timed out after 60 seconds.") - stdout = stdout_bytes.decode("utf-8") if stdout_bytes else "" stderr = stderr_bytes.decode("utf-8") if stderr_bytes else "" log_event(name="finish subprocess") # Parse markers to isolate the function result - func_result, stdout_text = self.parse_out_function_results_markers(stdout) - func_return, agent_state = self.parse_best_effort(func_result) + func_result_bytes, stdout_text = self.parse_out_function_results_markers(stdout_bytes) + func_return, agent_state = parse_stdout_best_effort(func_result_bytes) + + if process.returncode != 0 and func_return is None: + exception_name, msg = parse_stderr_error_msg(stderr) + func_return = get_friendly_error_msg( + function_name=self.tool_name, + exception_name=exception_name, + exception_message=msg, + ) return ToolExecutionResult( func_return=func_return, @@ -213,6 +234,8 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): raise e print(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}") + print(e.__class__.__name__) + print(e.__traceback__) func_return = get_friendly_error_msg( function_name=self.tool_name, exception_name=type(e).__name__, @@ -221,27 +244,32 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): return ToolExecutionResult( func_return=func_return, agent_state=None, - stdout=[], + stdout=[stdout_text], stderr=[str(e)], status="error", sandbox_config_fingerprint=sbx_config.fingerprint(), ) - def parse_out_function_results_markers(self, text: str) -> Tuple[str, str]: + def parse_out_function_results_markers(self, data: bytes) -> tuple[bytes, str]: """ Parse the function results out of the stdout using special markers. - Returns (function_result_str, stripped_stdout). 
+ Returns (function_results_bytes, stripped_stdout_bytes). """ - if self.LOCAL_SANDBOX_RESULT_START_MARKER not in text: - # No markers found, so nothing to parse - return "", text + pos = data.find(self.LOCAL_SANDBOX_RESULT_START_MARKER) + if pos < 0: + return b"", data.decode("utf-8") if data else "" - marker_len = len(self.LOCAL_SANDBOX_RESULT_START_MARKER) - start_index = text.index(self.LOCAL_SANDBOX_RESULT_START_MARKER) + marker_len - end_index = text.index(self.LOCAL_SANDBOX_RESULT_END_MARKER) + DATA_LENGTH_INDICATOR = 4 + CHECKSUM_LENGTH = 32 + pos_start = pos + len(self.LOCAL_SANDBOX_RESULT_START_MARKER) + checksum_start = pos_start + DATA_LENGTH_INDICATOR + message_start = checksum_start + CHECKSUM_LENGTH - # The actual pickled base64 is between start_index and end_index - results_str = text[start_index:end_index] - # The rest of stdout (minus the markers) - remainder = text[: start_index - marker_len] + text[end_index + marker_len :] - return results_str, remainder + message_len = struct.unpack(">I", data[pos_start:checksum_start])[0] + checksum = data[checksum_start:message_start] + message_data = data[message_start : message_start + message_len] + actual_checksum = hashlib.md5(message_data).hexdigest().encode("ascii") + if actual_checksum == checksum: + remainder = data[:pos] + data[message_start + message_len :] + return message_data, (remainder.decode("utf-8") if remainder else "") + raise Exception("Function ran, but output is corrupted.") diff --git a/letta/templates/__init__.py b/letta/templates/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/templates/sandbox_code_file.py.j2 b/letta/templates/sandbox_code_file.py.j2 new file mode 100644 index 00000000..953b8ae8 --- /dev/null +++ b/letta/templates/sandbox_code_file.py.j2 @@ -0,0 +1,47 @@ +{{ 'from __future__ import annotations' if future_import else '' }} +from typing import * +import pickle +import sys +import base64 +import struct +import hashlib + +{# Additional 
imports to support agent state #} +{% if inject_agent_state %} +import letta +from letta import * +{% endif %} + +{# Add schema code if available #} +{{ schema_imports or ''}} + +{# Load agent state #} +agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickle else 'None' }} + +{{ tool_args }} + +{# The tool's source code #} +{{ tool_source_code }} + +{# Invoke the function and store the result in a global variable #} +{{ local_sandbox_result_var_name }} = { + "results": {{ invoke_function_call }}, + "agent_state": agent_state +} + +{{ local_sandbox_result_var_name }}_pkl = pickle.dumps({{ local_sandbox_result_var_name }}) + +{% if wrap_print_with_markers %} +{# Combine everything to flush and write at once. #} +data_checksum = hashlib.md5({{ local_sandbox_result_var_name }}_pkl).hexdigest().encode('ascii') +{{ local_sandbox_result_var_name }}_msg = ( + {{ start_marker }} + + struct.pack('>I', len({{ local_sandbox_result_var_name }}_pkl)) + + data_checksum + + {{ local_sandbox_result_var_name }}_pkl +) +sys.stdout.buffer.write({{ local_sandbox_result_var_name }}_msg) +sys.stdout.buffer.flush() +{% else %} +base64.b64encode({{ local_sandbox_result_var_name }}_pkl).decode('utf-8') +{% endif %} diff --git a/letta/templates/template_helper.py b/letta/templates/template_helper.py new file mode 100644 index 00000000..0d2359ce --- /dev/null +++ b/letta/templates/template_helper.py @@ -0,0 +1,16 @@ +import os + +from jinja2 import Environment, FileSystemLoader, StrictUndefined + +TEMPLATE_DIR = os.path.dirname(__file__) +jinja_env = Environment( + loader=FileSystemLoader(TEMPLATE_DIR), + undefined=StrictUndefined, + trim_blocks=True, + lstrip_blocks=True, +) + + +def render_template(template_name: str, **kwargs): + template = jinja_env.get_template(template_name) + return template.render(**kwargs) diff --git a/letta/types/__init__.py b/letta/types/__init__.py index e69de29b..b0f83c65 100644 --- a/letta/types/__init__.py +++ b/letta/types/__init__.py 
@@ -0,0 +1,7 @@ +from typing import Any, TypeAlias + +from pydantic import JsonValue + +JsonDict: TypeAlias = dict[str, JsonValue] + +__all__ = ["JsonDict", "JsonValue"] diff --git a/letta/utils.py b/letta/utils.py index a23735b2..069d0a1a 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -1034,6 +1034,20 @@ def get_friendly_error_msg(function_name: str, exception_name: str, exception_me return error_msg +def parse_stderr_error_msg(stderr_txt: str, last_n_lines: int = 3) -> tuple[str, str]: + """ + Parses out from the last `last_n_lines` lines of `stderr_txt` the Exception type and message. + """ + index = -(last_n_lines + 1) + pattern = r"(\w+(?:Error|Exception)): (.+?)$" + for line in stderr_txt.split("\n")[:index:-1]: + if "Error" in line or "Exception" in line: + match = re.search(pattern, line) + if match: + return match.group(1), match.group(2) + return "", "" + + + def run_async_task(coro: Coroutine[Any, Any, Any]) -> Any: """ Safely runs an asynchronous coroutine in a synchronous context. 
diff --git a/tests/test_tool_schema_parsing_files/list_of_pydantic_example.json b/tests/test_tool_schema_parsing_files/list_of_pydantic_example.json index d2aeb6bd..e272884d 100644 --- a/tests/test_tool_schema_parsing_files/list_of_pydantic_example.json +++ b/tests/test_tool_schema_parsing_files/list_of_pydantic_example.json @@ -1,6 +1,6 @@ { "name": "create_task_plan", - "description": "Creates a task plan for the current task.", + "description": "It takes in a list of steps, and updates the task with the new steps provided.\nIf there are any current steps, they will be overwritten.\nEach step in the list should have the following format:\n{\n \"name\": -- Name of the step.\n \"key\": -- Unique identifier for the step.\n \"description\": -- An exhaustic description of what this step is trying to achieve and accomplish.\n}", "parameters": { "type": "object", "properties": { diff --git a/tests/test_tool_schema_parsing_files/list_of_pydantic_example_so.json b/tests/test_tool_schema_parsing_files/list_of_pydantic_example_so.json index f4b8a930..69b92c1c 100644 --- a/tests/test_tool_schema_parsing_files/list_of_pydantic_example_so.json +++ b/tests/test_tool_schema_parsing_files/list_of_pydantic_example_so.json @@ -1,6 +1,6 @@ { "name": "create_task_plan", - "description": "Creates a task plan for the current task.", + "description": "It takes in a list of steps, and updates the task with the new steps provided.\nIf there are any current steps, they will be overwritten.\nEach step in the list should have the following format:\n{\n \"name\": -- Name of the step.\n \"key\": -- Unique identifier for the step.\n \"description\": -- An exhaustic description of what this step is trying to achieve and accomplish.\n}", "strict": true, "parameters": { "type": "object", diff --git a/tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example.json b/tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example.json index 087182cf..c3c55550 100644 --- 
a/tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example.json +++ b/tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example.json @@ -1,6 +1,6 @@ { "name": "create_task_plan", - "description": "Creates a task plan for the current task.", + "description": "It takes in a list of steps, and updates the task with the new steps provided.\nIf there are any current steps, they will be overwritten.\nEach step in the list should have the following format:\n{\n \"name\": -- Name of the step.\n \"key\": -- Unique identifier for the step.\n \"description\": -- An exhaustic description of what this step is trying to achieve and accomplish.\n}", "parameters": { "type": "object", "properties": { diff --git a/tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example_so.json b/tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example_so.json index 0a311389..e1543c1e 100644 --- a/tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example_so.json +++ b/tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example_so.json @@ -1,6 +1,6 @@ { "name": "create_task_plan", - "description": "Creates a task plan for the current task.", + "description": "It takes in a list of steps, and updates the task with the new steps provided.\nIf there are any current steps, they will be overwritten.\nEach step in the list should have the following format:\n{\n \"name\": -- Name of the step.\n \"key\": -- Unique identifier for the step.\n \"description\": -- An exhaustic description of what this step is trying to achieve and accomplish.\n}", "strict": true, "parameters": { "type": "object", diff --git a/tests/test_tool_schema_parsing_files/simple_d20.json b/tests/test_tool_schema_parsing_files/simple_d20.json index 7d660baf..2764eead 100644 --- a/tests/test_tool_schema_parsing_files/simple_d20.json +++ b/tests/test_tool_schema_parsing_files/simple_d20.json @@ -1,6 +1,6 @@ { "name": "roll_d20", - "description": "Simulate the roll of a 20-sided die 
(d20).", + "description": "This function generates a random integer between 1 and 20, inclusive,\nwhich represents the outcome of a single roll of a d20.", "parameters": { "type": "object", "properties": {}, diff --git a/tests/test_tool_schema_parsing_files/simple_d20_so.json b/tests/test_tool_schema_parsing_files/simple_d20_so.json index 2f3ddeab..68b74cec 100644 --- a/tests/test_tool_schema_parsing_files/simple_d20_so.json +++ b/tests/test_tool_schema_parsing_files/simple_d20_so.json @@ -1,6 +1,6 @@ { "name": "roll_d20", - "description": "Simulate the roll of a 20-sided die (d20).", + "description": "This function generates a random integer between 1 and 20, inclusive,\nwhich represents the outcome of a single roll of a d20.", "strict": true, "parameters": { "type": "object",