fix: safer type coersion for tools (#8990)
* mvp * perfrom type coercion in sandbox * fix: safely resolve typing annotations on host Use an AST whitelist for generic annotations to avoid eval while keeping list/dict coercion working. 👾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> --------- Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -7,8 +7,52 @@ from typing import Dict, Optional, Tuple
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.types import JsonDict
|
||||
|
||||
_ALLOWED_TYPING_NAMES = {name: obj for name, obj in vars(typing).items() if not name.startswith("_")}
|
||||
_ALLOWED_BUILTIN_TYPES = {name: obj for name, obj in vars(builtins).items() if isinstance(obj, type)}
|
||||
_ALLOWED_TYPE_NAMES = {**_ALLOWED_TYPING_NAMES, **_ALLOWED_BUILTIN_TYPES, "typing": typing}
|
||||
|
||||
def resolve_type(annotation: str):
|
||||
|
||||
def _resolve_annotation_node(node: ast.AST):
|
||||
if isinstance(node, ast.Name):
|
||||
if node.id == "None":
|
||||
return type(None)
|
||||
if node.id in _ALLOWED_TYPE_NAMES:
|
||||
return _ALLOWED_TYPE_NAMES[node.id]
|
||||
raise ValueError(f"Unsupported annotation name: {node.id}")
|
||||
|
||||
if isinstance(node, ast.Attribute):
|
||||
if isinstance(node.value, ast.Name) and node.value.id == "typing" and node.attr in _ALLOWED_TYPING_NAMES:
|
||||
return _ALLOWED_TYPING_NAMES[node.attr]
|
||||
raise ValueError("Unsupported annotation attribute")
|
||||
|
||||
if isinstance(node, ast.Subscript):
|
||||
origin = _resolve_annotation_node(node.value)
|
||||
args = _resolve_subscript_slice(node.slice)
|
||||
return origin[args]
|
||||
|
||||
if isinstance(node, ast.Tuple):
|
||||
return tuple(_resolve_annotation_node(elt) for elt in node.elts)
|
||||
|
||||
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||
left = _resolve_annotation_node(node.left)
|
||||
right = _resolve_annotation_node(node.right)
|
||||
return left | right
|
||||
|
||||
if isinstance(node, ast.Constant) and node.value is None:
|
||||
return type(None)
|
||||
|
||||
raise ValueError("Unsupported annotation expression")
|
||||
|
||||
|
||||
def _resolve_subscript_slice(slice_node: ast.AST):
|
||||
if isinstance(slice_node, ast.Index):
|
||||
slice_node = slice_node.value
|
||||
if isinstance(slice_node, ast.Tuple):
|
||||
return tuple(_resolve_annotation_node(elt) for elt in slice_node.elts)
|
||||
return _resolve_annotation_node(slice_node)
|
||||
|
||||
|
||||
def resolve_type(annotation: str, *, allow_unsafe_eval: bool = False, extra_globals: Optional[Dict[str, object]] = None):
|
||||
"""
|
||||
Resolve a type annotation string into a Python type.
|
||||
Previously, primitive support for int, float, str, dict, list, set, tuple, bool.
|
||||
@@ -23,15 +67,23 @@ def resolve_type(annotation: str):
|
||||
ValueError: If the annotation is unsupported or invalid.
|
||||
"""
|
||||
python_types = {**vars(typing), **vars(builtins)}
|
||||
if extra_globals:
|
||||
python_types.update(extra_globals)
|
||||
|
||||
if annotation in python_types:
|
||||
return python_types[annotation]
|
||||
|
||||
try:
|
||||
# Allow use of typing and builtins in a safe eval context
|
||||
return eval(annotation, python_types)
|
||||
parsed = ast.parse(annotation, mode="eval")
|
||||
return _resolve_annotation_node(parsed.body)
|
||||
except Exception:
|
||||
raise ValueError(f"Unsupported annotation: {annotation}")
|
||||
if allow_unsafe_eval:
|
||||
try:
|
||||
return eval(annotation, python_types)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Unsupported annotation: {annotation}") from exc
|
||||
|
||||
raise ValueError(f"Unsupported annotation: {annotation}")
|
||||
|
||||
|
||||
# TODO :: THIS MUST BE EDITED TO HANDLE THINGS
|
||||
@@ -62,14 +114,34 @@ def get_function_annotations_from_source(source_code: str, function_name: str) -
|
||||
|
||||
|
||||
# NOW json_loads -> ast.literal_eval -> typing.get_origin
|
||||
def coerce_dict_args_by_annotations(function_args: JsonDict, annotations: Dict[str, str]) -> dict:
|
||||
def coerce_dict_args_by_annotations(
|
||||
function_args: JsonDict,
|
||||
annotations: Dict[str, object],
|
||||
*,
|
||||
allow_unsafe_eval: bool = False,
|
||||
extra_globals: Optional[Dict[str, object]] = None,
|
||||
) -> dict:
|
||||
coerced_args = dict(function_args) # Shallow copy
|
||||
|
||||
for arg_name, value in coerced_args.items():
|
||||
if arg_name in annotations:
|
||||
annotation_str = annotations[arg_name]
|
||||
try:
|
||||
arg_type = resolve_type(annotation_str)
|
||||
annotation_value = annotations[arg_name]
|
||||
if isinstance(annotation_value, str):
|
||||
arg_type = resolve_type(
|
||||
annotation_value,
|
||||
allow_unsafe_eval=allow_unsafe_eval,
|
||||
extra_globals=extra_globals,
|
||||
)
|
||||
elif isinstance(annotation_value, typing.ForwardRef):
|
||||
arg_type = resolve_type(
|
||||
annotation_value.__forward_arg__,
|
||||
allow_unsafe_eval=allow_unsafe_eval,
|
||||
extra_globals=extra_globals,
|
||||
)
|
||||
else:
|
||||
arg_type = annotation_value
|
||||
|
||||
# Always parse strings using literal_eval or json if possible
|
||||
if isinstance(value, str):
|
||||
|
||||
@@ -533,6 +533,24 @@ class ToolExecutionSandbox:
|
||||
|
||||
code += "\n" + self.tool.source_code + "\n"
|
||||
|
||||
if self.args:
|
||||
raw_args = ", ".join([f"{name!r}: {name}" for name in self.args])
|
||||
code += f"__letta_raw_args = {{{raw_args}}}\n"
|
||||
code += "try:\n"
|
||||
code += " from letta.functions.ast_parsers import coerce_dict_args_by_annotations\n"
|
||||
code += f" __letta_func = {self.tool.name}\n"
|
||||
code += " __letta_annotations = getattr(__letta_func, '__annotations__', {})\n"
|
||||
code += " __letta_coerced_args = coerce_dict_args_by_annotations(\n"
|
||||
code += " __letta_raw_args,\n"
|
||||
code += " __letta_annotations,\n"
|
||||
code += " allow_unsafe_eval=True,\n"
|
||||
code += " extra_globals=__letta_func.__globals__,\n"
|
||||
code += " )\n"
|
||||
for name in self.args:
|
||||
code += f" {name} = __letta_coerced_args.get({name!r}, {name})\n"
|
||||
code += "except Exception:\n"
|
||||
code += " pass\n"
|
||||
|
||||
# TODO: handle wrapped print
|
||||
|
||||
code += (
|
||||
|
||||
@@ -167,6 +167,19 @@ def modal_tool_wrapper(tool: PydanticTool, actor: PydanticUser, sandbox_env_vars
|
||||
if "agent_state" in tool_func.__code__.co_varnames:
|
||||
kwargs["agent_state"] = reconstructed_agent_state
|
||||
|
||||
try:
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations
|
||||
|
||||
annotations = getattr(tool_func, "__annotations__", {})
|
||||
kwargs = coerce_dict_args_by_annotations(
|
||||
kwargs,
|
||||
annotations,
|
||||
allow_unsafe_eval=True,
|
||||
extra_globals=tool_func.__globals__,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Execute the tool function (async or sync)
|
||||
if is_async:
|
||||
result = asyncio.run(tool_func(**kwargs))
|
||||
|
||||
@@ -259,6 +259,27 @@ class AsyncToolSandboxBase(ABC):
|
||||
if tool_source_code:
|
||||
lines.append(tool_source_code.rstrip())
|
||||
|
||||
if self.args:
|
||||
raw_args = ", ".join([f"{name!r}: {name}" for name in self.args])
|
||||
lines.extend(
|
||||
[
|
||||
f"__letta_raw_args = {{{raw_args}}}",
|
||||
"try:",
|
||||
" from letta.functions.ast_parsers import coerce_dict_args_by_annotations",
|
||||
f" __letta_func = {self.tool.name}",
|
||||
" __letta_annotations = getattr(__letta_func, '__annotations__', {})",
|
||||
" __letta_coerced_args = coerce_dict_args_by_annotations(",
|
||||
" __letta_raw_args,",
|
||||
" __letta_annotations,",
|
||||
" allow_unsafe_eval=True,",
|
||||
" extra_globals=__letta_func.__globals__,",
|
||||
" )",
|
||||
]
|
||||
)
|
||||
for name in self.args:
|
||||
lines.append(f" {name} = __letta_coerced_args.get({name!r}, {name})")
|
||||
lines.extend(["except Exception:", " pass"])
|
||||
|
||||
if not self.is_async_function:
|
||||
# sync variant
|
||||
lines.append(f"_function_result = {invoke_function_call}")
|
||||
|
||||
@@ -122,6 +122,19 @@ class ModalFunctionExecutor:
|
||||
if inject_agent_state:
|
||||
kwargs["agent_state"] = agent_state
|
||||
|
||||
try:
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations
|
||||
|
||||
annotations = getattr(func, "__annotations__", {})
|
||||
kwargs = coerce_dict_args_by_annotations(
|
||||
kwargs,
|
||||
annotations,
|
||||
allow_unsafe_eval=True,
|
||||
extra_globals=func.__globals__,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_async:
|
||||
result = asyncio.run(func(**kwargs))
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user