diff --git a/letta/functions/ast_parsers.py b/letta/functions/ast_parsers.py index 627b7fdb..14eed2fa 100644 --- a/letta/functions/ast_parsers.py +++ b/letta/functions/ast_parsers.py @@ -7,8 +7,52 @@ from typing import Dict, Optional, Tuple from letta.errors import LettaToolCreateError from letta.types import JsonDict +_ALLOWED_TYPING_NAMES = {name: obj for name, obj in vars(typing).items() if not name.startswith("_")} +_ALLOWED_BUILTIN_TYPES = {name: obj for name, obj in vars(builtins).items() if isinstance(obj, type)} +_ALLOWED_TYPE_NAMES = {**_ALLOWED_TYPING_NAMES, **_ALLOWED_BUILTIN_TYPES, "typing": typing} -def resolve_type(annotation: str): + +def _resolve_annotation_node(node: ast.AST): + if isinstance(node, ast.Name): + if node.id == "None": + return type(None) + if node.id in _ALLOWED_TYPE_NAMES: + return _ALLOWED_TYPE_NAMES[node.id] + raise ValueError(f"Unsupported annotation name: {node.id}") + + if isinstance(node, ast.Attribute): + if isinstance(node.value, ast.Name) and node.value.id == "typing" and node.attr in _ALLOWED_TYPING_NAMES: + return _ALLOWED_TYPING_NAMES[node.attr] + raise ValueError("Unsupported annotation attribute") + + if isinstance(node, ast.Subscript): + origin = _resolve_annotation_node(node.value) + args = _resolve_subscript_slice(node.slice) + return origin[args] + + if isinstance(node, ast.Tuple): + return tuple(_resolve_annotation_node(elt) for elt in node.elts) + + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + left = _resolve_annotation_node(node.left) + right = _resolve_annotation_node(node.right) + return left | right + + if isinstance(node, ast.Constant) and node.value is None: + return type(None) + + raise ValueError("Unsupported annotation expression") + + +def _resolve_subscript_slice(slice_node: ast.AST): + if isinstance(slice_node, ast.Index): + slice_node = slice_node.value + if isinstance(slice_node, ast.Tuple): + return tuple(_resolve_annotation_node(elt) for elt in slice_node.elts) + return _resolve_annotation_node(slice_node) + + +def resolve_type(annotation: str, *, allow_unsafe_eval: bool = False, extra_globals: Optional[Dict[str, object]] = None): """ Resolve a type annotation string into a Python type. Previously, primitive support for int, float, str, dict, list, set, tuple, bool. @@ -23,15 +67,23 @@ def resolve_type(annotation: str): ValueError: If the annotation is unsupported or invalid. """ python_types = {**vars(typing), **vars(builtins)} + if extra_globals: + python_types.update(extra_globals) if annotation in python_types: return python_types[annotation] try: - # Allow use of typing and builtins in a safe eval context - return eval(annotation, python_types) + parsed = ast.parse(annotation, mode="eval") + return _resolve_annotation_node(parsed.body) except Exception: - raise ValueError(f"Unsupported annotation: {annotation}") + if allow_unsafe_eval: + try: + return eval(annotation, python_types) + except Exception as exc: + raise ValueError(f"Unsupported annotation: {annotation}") from exc + + raise ValueError(f"Unsupported annotation: {annotation}") # TODO :: THIS MUST BE EDITED TO HANDLE THINGS @@ -62,14 +114,34 @@ def get_function_annotations_from_source(source_code: str, function_name: str) - # NOW json_loads -> ast.literal_eval -> typing.get_origin -def coerce_dict_args_by_annotations(function_args: JsonDict, annotations: Dict[str, str]) -> dict: +def coerce_dict_args_by_annotations( + function_args: JsonDict, + annotations: Dict[str, object], + *, + allow_unsafe_eval: bool = False, + extra_globals: Optional[Dict[str, object]] = None, +) -> dict: coerced_args = dict(function_args) # Shallow copy for arg_name, value in coerced_args.items(): if arg_name in annotations: annotation_str = annotations[arg_name] try: - arg_type = resolve_type(annotation_str) + annotation_value = annotations[arg_name] + if isinstance(annotation_value, str): + arg_type = resolve_type( + annotation_value, + allow_unsafe_eval=allow_unsafe_eval, + extra_globals=extra_globals, + ) + elif isinstance(annotation_value, typing.ForwardRef): + arg_type = resolve_type( + annotation_value.__forward_arg__, + allow_unsafe_eval=allow_unsafe_eval, + extra_globals=extra_globals, + ) + else: + arg_type = annotation_value # Always parse strings using literal_eval or json if possible if isinstance(value, str): diff --git a/letta/services/tool_executor/tool_execution_sandbox.py b/letta/services/tool_executor/tool_execution_sandbox.py index bc35618b..48b52fe8 100644 --- a/letta/services/tool_executor/tool_execution_sandbox.py +++ b/letta/services/tool_executor/tool_execution_sandbox.py @@ -533,6 +533,24 @@ class ToolExecutionSandbox: code += "\n" + self.tool.source_code + "\n" + if self.args: + raw_args = ", ".join([f"{name!r}: {name}" for name in self.args]) + code += f"__letta_raw_args = {{{raw_args}}}\n" + code += "try:\n" + code += " from letta.functions.ast_parsers import coerce_dict_args_by_annotations\n" + code += f" __letta_func = {self.tool.name}\n" + code += " __letta_annotations = getattr(__letta_func, '__annotations__', {})\n" + code += " __letta_coerced_args = coerce_dict_args_by_annotations(\n" + code += " __letta_raw_args,\n" + code += " __letta_annotations,\n" + code += " allow_unsafe_eval=True,\n" + code += " extra_globals=__letta_func.__globals__,\n" + code += " )\n" + for name in self.args: + code += f" {name} = __letta_coerced_args.get({name!r}, {name})\n" + code += "except Exception:\n" + code += " pass\n" + # TODO: handle wrapped print code += ( diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 9a1e10a3..f243580b 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -167,6 +167,19 @@ def modal_tool_wrapper(tool: PydanticTool, actor: PydanticUser, sandbox_env_vars if "agent_state" in tool_func.__code__.co_varnames: kwargs["agent_state"] = reconstructed_agent_state + try: + from letta.functions.ast_parsers import coerce_dict_args_by_annotations + + annotations = getattr(tool_func, "__annotations__", {}) + kwargs = coerce_dict_args_by_annotations( + kwargs, + annotations, + allow_unsafe_eval=True, + extra_globals=tool_func.__globals__, + ) + except Exception: + pass + # Execute the tool function (async or sync) if is_async: result = asyncio.run(tool_func(**kwargs)) diff --git a/letta/services/tool_sandbox/base.py b/letta/services/tool_sandbox/base.py index 14ded2cb..9b290f8c 100644 --- a/letta/services/tool_sandbox/base.py +++ b/letta/services/tool_sandbox/base.py @@ -259,6 +259,27 @@ class AsyncToolSandboxBase(ABC): if tool_source_code: lines.append(tool_source_code.rstrip()) + if self.args: + raw_args = ", ".join([f"{name!r}: {name}" for name in self.args]) + lines.extend( + [ + f"__letta_raw_args = {{{raw_args}}}", + "try:", + " from letta.functions.ast_parsers import coerce_dict_args_by_annotations", + f" __letta_func = {self.tool.name}", + " __letta_annotations = getattr(__letta_func, '__annotations__', {})", + " __letta_coerced_args = coerce_dict_args_by_annotations(", + " __letta_raw_args,", + " __letta_annotations,", + " allow_unsafe_eval=True,", + " extra_globals=__letta_func.__globals__,", + " )", + ] + ) + for name in self.args: + lines.append(f" {name} = __letta_coerced_args.get({name!r}, {name})") + lines.extend(["except Exception:", " pass"]) + if not self.is_async_function: # sync variant lines.append(f"_function_result = {invoke_function_call}") diff --git a/sandbox/modal_executor.py b/sandbox/modal_executor.py index 8ee22d09..2b759967 100644 --- a/sandbox/modal_executor.py +++ b/sandbox/modal_executor.py @@ -122,6 +122,19 @@ class ModalFunctionExecutor: if inject_agent_state: kwargs["agent_state"] = agent_state + try: + from letta.functions.ast_parsers import coerce_dict_args_by_annotations + + annotations = getattr(func, "__annotations__", {}) + kwargs = coerce_dict_args_by_annotations( + kwargs, + annotations, + allow_unsafe_eval=True, + extra_globals=func.__globals__, + ) + except Exception: + pass + if is_async: result = asyncio.run(func(**kwargs)) else: