fix: safer type coersion for tools (#8990)

* mvp

* perfrom type coercion in sandbox

* fix: safely resolve typing annotations on host

Use an AST whitelist for generic annotations to avoid eval while keeping list/dict coercion working.

👾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

---------

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
Kian Jones
2026-01-20 19:12:15 -08:00
committed by Caren Thomas
parent 2e826577d9
commit 1ab21af725
5 changed files with 143 additions and 6 deletions

View File

@@ -7,8 +7,52 @@ from typing import Dict, Optional, Tuple
from letta.errors import LettaToolCreateError from letta.errors import LettaToolCreateError
from letta.types import JsonDict from letta.types import JsonDict
_ALLOWED_TYPING_NAMES = {name: obj for name, obj in vars(typing).items() if not name.startswith("_")}
_ALLOWED_BUILTIN_TYPES = {name: obj for name, obj in vars(builtins).items() if isinstance(obj, type)}
_ALLOWED_TYPE_NAMES = {**_ALLOWED_TYPING_NAMES, **_ALLOWED_BUILTIN_TYPES, "typing": typing}
def resolve_type(annotation: str):
def _resolve_annotation_node(node: ast.AST):
if isinstance(node, ast.Name):
if node.id == "None":
return type(None)
if node.id in _ALLOWED_TYPE_NAMES:
return _ALLOWED_TYPE_NAMES[node.id]
raise ValueError(f"Unsupported annotation name: {node.id}")
if isinstance(node, ast.Attribute):
if isinstance(node.value, ast.Name) and node.value.id == "typing" and node.attr in _ALLOWED_TYPING_NAMES:
return _ALLOWED_TYPING_NAMES[node.attr]
raise ValueError("Unsupported annotation attribute")
if isinstance(node, ast.Subscript):
origin = _resolve_annotation_node(node.value)
args = _resolve_subscript_slice(node.slice)
return origin[args]
if isinstance(node, ast.Tuple):
return tuple(_resolve_annotation_node(elt) for elt in node.elts)
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
left = _resolve_annotation_node(node.left)
right = _resolve_annotation_node(node.right)
return left | right
if isinstance(node, ast.Constant) and node.value is None:
return type(None)
raise ValueError("Unsupported annotation expression")
def _resolve_subscript_slice(slice_node: ast.AST):
if isinstance(slice_node, ast.Index):
slice_node = slice_node.value
if isinstance(slice_node, ast.Tuple):
return tuple(_resolve_annotation_node(elt) for elt in slice_node.elts)
return _resolve_annotation_node(slice_node)
def resolve_type(annotation: str, *, allow_unsafe_eval: bool = False, extra_globals: Optional[Dict[str, object]] = None):
""" """
Resolve a type annotation string into a Python type. Resolve a type annotation string into a Python type.
Previously, primitive support for int, float, str, dict, list, set, tuple, bool. Previously, primitive support for int, float, str, dict, list, set, tuple, bool.
@@ -23,15 +67,23 @@ def resolve_type(annotation: str):
ValueError: If the annotation is unsupported or invalid. ValueError: If the annotation is unsupported or invalid.
""" """
python_types = {**vars(typing), **vars(builtins)} python_types = {**vars(typing), **vars(builtins)}
if extra_globals:
python_types.update(extra_globals)
if annotation in python_types: if annotation in python_types:
return python_types[annotation] return python_types[annotation]
try: try:
# Allow use of typing and builtins in a safe eval context parsed = ast.parse(annotation, mode="eval")
return eval(annotation, python_types) return _resolve_annotation_node(parsed.body)
except Exception: except Exception:
raise ValueError(f"Unsupported annotation: {annotation}") if allow_unsafe_eval:
try:
return eval(annotation, python_types)
except Exception as exc:
raise ValueError(f"Unsupported annotation: {annotation}") from exc
raise ValueError(f"Unsupported annotation: {annotation}")
# TODO :: THIS MUST BE EDITED TO HANDLE THINGS # TODO :: THIS MUST BE EDITED TO HANDLE THINGS
@@ -62,14 +114,34 @@ def get_function_annotations_from_source(source_code: str, function_name: str) -
# NOW json_loads -> ast.literal_eval -> typing.get_origin # NOW json_loads -> ast.literal_eval -> typing.get_origin
def coerce_dict_args_by_annotations(function_args: JsonDict, annotations: Dict[str, str]) -> dict: def coerce_dict_args_by_annotations(
function_args: JsonDict,
annotations: Dict[str, object],
*,
allow_unsafe_eval: bool = False,
extra_globals: Optional[Dict[str, object]] = None,
) -> dict:
coerced_args = dict(function_args) # Shallow copy coerced_args = dict(function_args) # Shallow copy
for arg_name, value in coerced_args.items(): for arg_name, value in coerced_args.items():
if arg_name in annotations: if arg_name in annotations:
annotation_str = annotations[arg_name] annotation_str = annotations[arg_name]
try: try:
arg_type = resolve_type(annotation_str) annotation_value = annotations[arg_name]
if isinstance(annotation_value, str):
arg_type = resolve_type(
annotation_value,
allow_unsafe_eval=allow_unsafe_eval,
extra_globals=extra_globals,
)
elif isinstance(annotation_value, typing.ForwardRef):
arg_type = resolve_type(
annotation_value.__forward_arg__,
allow_unsafe_eval=allow_unsafe_eval,
extra_globals=extra_globals,
)
else:
arg_type = annotation_value
# Always parse strings using literal_eval or json if possible # Always parse strings using literal_eval or json if possible
if isinstance(value, str): if isinstance(value, str):

View File

@@ -533,6 +533,24 @@ class ToolExecutionSandbox:
code += "\n" + self.tool.source_code + "\n" code += "\n" + self.tool.source_code + "\n"
if self.args:
raw_args = ", ".join([f"{name!r}: {name}" for name in self.args])
code += f"__letta_raw_args = {{{raw_args}}}\n"
code += "try:\n"
code += " from letta.functions.ast_parsers import coerce_dict_args_by_annotations\n"
code += f" __letta_func = {self.tool.name}\n"
code += " __letta_annotations = getattr(__letta_func, '__annotations__', {})\n"
code += " __letta_coerced_args = coerce_dict_args_by_annotations(\n"
code += " __letta_raw_args,\n"
code += " __letta_annotations,\n"
code += " allow_unsafe_eval=True,\n"
code += " extra_globals=__letta_func.__globals__,\n"
code += " )\n"
for name in self.args:
code += f" {name} = __letta_coerced_args.get({name!r}, {name})\n"
code += "except Exception:\n"
code += " pass\n"
# TODO: handle wrapped print # TODO: handle wrapped print
code += ( code += (

View File

@@ -167,6 +167,19 @@ def modal_tool_wrapper(tool: PydanticTool, actor: PydanticUser, sandbox_env_vars
if "agent_state" in tool_func.__code__.co_varnames: if "agent_state" in tool_func.__code__.co_varnames:
kwargs["agent_state"] = reconstructed_agent_state kwargs["agent_state"] = reconstructed_agent_state
try:
from letta.functions.ast_parsers import coerce_dict_args_by_annotations
annotations = getattr(tool_func, "__annotations__", {})
kwargs = coerce_dict_args_by_annotations(
kwargs,
annotations,
allow_unsafe_eval=True,
extra_globals=tool_func.__globals__,
)
except Exception:
pass
# Execute the tool function (async or sync) # Execute the tool function (async or sync)
if is_async: if is_async:
result = asyncio.run(tool_func(**kwargs)) result = asyncio.run(tool_func(**kwargs))

View File

@@ -259,6 +259,27 @@ class AsyncToolSandboxBase(ABC):
if tool_source_code: if tool_source_code:
lines.append(tool_source_code.rstrip()) lines.append(tool_source_code.rstrip())
if self.args:
raw_args = ", ".join([f"{name!r}: {name}" for name in self.args])
lines.extend(
[
f"__letta_raw_args = {{{raw_args}}}",
"try:",
" from letta.functions.ast_parsers import coerce_dict_args_by_annotations",
f" __letta_func = {self.tool.name}",
" __letta_annotations = getattr(__letta_func, '__annotations__', {})",
" __letta_coerced_args = coerce_dict_args_by_annotations(",
" __letta_raw_args,",
" __letta_annotations,",
" allow_unsafe_eval=True,",
" extra_globals=__letta_func.__globals__,",
" )",
]
)
for name in self.args:
lines.append(f" {name} = __letta_coerced_args.get({name!r}, {name})")
lines.extend(["except Exception:", " pass"])
if not self.is_async_function: if not self.is_async_function:
# sync variant # sync variant
lines.append(f"_function_result = {invoke_function_call}") lines.append(f"_function_result = {invoke_function_call}")

View File

@@ -122,6 +122,19 @@ class ModalFunctionExecutor:
if inject_agent_state: if inject_agent_state:
kwargs["agent_state"] = agent_state kwargs["agent_state"] = agent_state
try:
from letta.functions.ast_parsers import coerce_dict_args_by_annotations
annotations = getattr(func, "__annotations__", {})
kwargs = coerce_dict_args_by_annotations(
kwargs,
annotations,
allow_unsafe_eval=True,
extra_globals=func.__globals__,
)
except Exception:
pass
if is_async: if is_async:
result = asyncio.run(func(**kwargs)) result = asyncio.run(func(**kwargs))
else: else: