fix: safer type coersion for tools (#8990)
* mvp * perfrom type coercion in sandbox * fix: safely resolve typing annotations on host Use an AST whitelist for generic annotations to avoid eval while keeping list/dict coercion working. 👾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> --------- Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -7,8 +7,52 @@ from typing import Dict, Optional, Tuple
|
|||||||
from letta.errors import LettaToolCreateError
|
from letta.errors import LettaToolCreateError
|
||||||
from letta.types import JsonDict
|
from letta.types import JsonDict
|
||||||
|
|
||||||
|
_ALLOWED_TYPING_NAMES = {name: obj for name, obj in vars(typing).items() if not name.startswith("_")}
|
||||||
|
_ALLOWED_BUILTIN_TYPES = {name: obj for name, obj in vars(builtins).items() if isinstance(obj, type)}
|
||||||
|
_ALLOWED_TYPE_NAMES = {**_ALLOWED_TYPING_NAMES, **_ALLOWED_BUILTIN_TYPES, "typing": typing}
|
||||||
|
|
||||||
def resolve_type(annotation: str):
|
|
||||||
|
def _resolve_annotation_node(node: ast.AST):
|
||||||
|
if isinstance(node, ast.Name):
|
||||||
|
if node.id == "None":
|
||||||
|
return type(None)
|
||||||
|
if node.id in _ALLOWED_TYPE_NAMES:
|
||||||
|
return _ALLOWED_TYPE_NAMES[node.id]
|
||||||
|
raise ValueError(f"Unsupported annotation name: {node.id}")
|
||||||
|
|
||||||
|
if isinstance(node, ast.Attribute):
|
||||||
|
if isinstance(node.value, ast.Name) and node.value.id == "typing" and node.attr in _ALLOWED_TYPING_NAMES:
|
||||||
|
return _ALLOWED_TYPING_NAMES[node.attr]
|
||||||
|
raise ValueError("Unsupported annotation attribute")
|
||||||
|
|
||||||
|
if isinstance(node, ast.Subscript):
|
||||||
|
origin = _resolve_annotation_node(node.value)
|
||||||
|
args = _resolve_subscript_slice(node.slice)
|
||||||
|
return origin[args]
|
||||||
|
|
||||||
|
if isinstance(node, ast.Tuple):
|
||||||
|
return tuple(_resolve_annotation_node(elt) for elt in node.elts)
|
||||||
|
|
||||||
|
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||||
|
left = _resolve_annotation_node(node.left)
|
||||||
|
right = _resolve_annotation_node(node.right)
|
||||||
|
return left | right
|
||||||
|
|
||||||
|
if isinstance(node, ast.Constant) and node.value is None:
|
||||||
|
return type(None)
|
||||||
|
|
||||||
|
raise ValueError("Unsupported annotation expression")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_subscript_slice(slice_node: ast.AST):
|
||||||
|
if isinstance(slice_node, ast.Index):
|
||||||
|
slice_node = slice_node.value
|
||||||
|
if isinstance(slice_node, ast.Tuple):
|
||||||
|
return tuple(_resolve_annotation_node(elt) for elt in slice_node.elts)
|
||||||
|
return _resolve_annotation_node(slice_node)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_type(annotation: str, *, allow_unsafe_eval: bool = False, extra_globals: Optional[Dict[str, object]] = None):
|
||||||
"""
|
"""
|
||||||
Resolve a type annotation string into a Python type.
|
Resolve a type annotation string into a Python type.
|
||||||
Previously, primitive support for int, float, str, dict, list, set, tuple, bool.
|
Previously, primitive support for int, float, str, dict, list, set, tuple, bool.
|
||||||
@@ -23,15 +67,23 @@ def resolve_type(annotation: str):
|
|||||||
ValueError: If the annotation is unsupported or invalid.
|
ValueError: If the annotation is unsupported or invalid.
|
||||||
"""
|
"""
|
||||||
python_types = {**vars(typing), **vars(builtins)}
|
python_types = {**vars(typing), **vars(builtins)}
|
||||||
|
if extra_globals:
|
||||||
|
python_types.update(extra_globals)
|
||||||
|
|
||||||
if annotation in python_types:
|
if annotation in python_types:
|
||||||
return python_types[annotation]
|
return python_types[annotation]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Allow use of typing and builtins in a safe eval context
|
parsed = ast.parse(annotation, mode="eval")
|
||||||
return eval(annotation, python_types)
|
return _resolve_annotation_node(parsed.body)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError(f"Unsupported annotation: {annotation}")
|
if allow_unsafe_eval:
|
||||||
|
try:
|
||||||
|
return eval(annotation, python_types)
|
||||||
|
except Exception as exc:
|
||||||
|
raise ValueError(f"Unsupported annotation: {annotation}") from exc
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported annotation: {annotation}")
|
||||||
|
|
||||||
|
|
||||||
# TODO :: THIS MUST BE EDITED TO HANDLE THINGS
|
# TODO :: THIS MUST BE EDITED TO HANDLE THINGS
|
||||||
@@ -62,14 +114,34 @@ def get_function_annotations_from_source(source_code: str, function_name: str) -
|
|||||||
|
|
||||||
|
|
||||||
# NOW json_loads -> ast.literal_eval -> typing.get_origin
|
# NOW json_loads -> ast.literal_eval -> typing.get_origin
|
||||||
def coerce_dict_args_by_annotations(function_args: JsonDict, annotations: Dict[str, str]) -> dict:
|
def coerce_dict_args_by_annotations(
|
||||||
|
function_args: JsonDict,
|
||||||
|
annotations: Dict[str, object],
|
||||||
|
*,
|
||||||
|
allow_unsafe_eval: bool = False,
|
||||||
|
extra_globals: Optional[Dict[str, object]] = None,
|
||||||
|
) -> dict:
|
||||||
coerced_args = dict(function_args) # Shallow copy
|
coerced_args = dict(function_args) # Shallow copy
|
||||||
|
|
||||||
for arg_name, value in coerced_args.items():
|
for arg_name, value in coerced_args.items():
|
||||||
if arg_name in annotations:
|
if arg_name in annotations:
|
||||||
annotation_str = annotations[arg_name]
|
annotation_str = annotations[arg_name]
|
||||||
try:
|
try:
|
||||||
arg_type = resolve_type(annotation_str)
|
annotation_value = annotations[arg_name]
|
||||||
|
if isinstance(annotation_value, str):
|
||||||
|
arg_type = resolve_type(
|
||||||
|
annotation_value,
|
||||||
|
allow_unsafe_eval=allow_unsafe_eval,
|
||||||
|
extra_globals=extra_globals,
|
||||||
|
)
|
||||||
|
elif isinstance(annotation_value, typing.ForwardRef):
|
||||||
|
arg_type = resolve_type(
|
||||||
|
annotation_value.__forward_arg__,
|
||||||
|
allow_unsafe_eval=allow_unsafe_eval,
|
||||||
|
extra_globals=extra_globals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
arg_type = annotation_value
|
||||||
|
|
||||||
# Always parse strings using literal_eval or json if possible
|
# Always parse strings using literal_eval or json if possible
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
|
|||||||
@@ -533,6 +533,24 @@ class ToolExecutionSandbox:
|
|||||||
|
|
||||||
code += "\n" + self.tool.source_code + "\n"
|
code += "\n" + self.tool.source_code + "\n"
|
||||||
|
|
||||||
|
if self.args:
|
||||||
|
raw_args = ", ".join([f"{name!r}: {name}" for name in self.args])
|
||||||
|
code += f"__letta_raw_args = {{{raw_args}}}\n"
|
||||||
|
code += "try:\n"
|
||||||
|
code += " from letta.functions.ast_parsers import coerce_dict_args_by_annotations\n"
|
||||||
|
code += f" __letta_func = {self.tool.name}\n"
|
||||||
|
code += " __letta_annotations = getattr(__letta_func, '__annotations__', {})\n"
|
||||||
|
code += " __letta_coerced_args = coerce_dict_args_by_annotations(\n"
|
||||||
|
code += " __letta_raw_args,\n"
|
||||||
|
code += " __letta_annotations,\n"
|
||||||
|
code += " allow_unsafe_eval=True,\n"
|
||||||
|
code += " extra_globals=__letta_func.__globals__,\n"
|
||||||
|
code += " )\n"
|
||||||
|
for name in self.args:
|
||||||
|
code += f" {name} = __letta_coerced_args.get({name!r}, {name})\n"
|
||||||
|
code += "except Exception:\n"
|
||||||
|
code += " pass\n"
|
||||||
|
|
||||||
# TODO: handle wrapped print
|
# TODO: handle wrapped print
|
||||||
|
|
||||||
code += (
|
code += (
|
||||||
|
|||||||
@@ -167,6 +167,19 @@ def modal_tool_wrapper(tool: PydanticTool, actor: PydanticUser, sandbox_env_vars
|
|||||||
if "agent_state" in tool_func.__code__.co_varnames:
|
if "agent_state" in tool_func.__code__.co_varnames:
|
||||||
kwargs["agent_state"] = reconstructed_agent_state
|
kwargs["agent_state"] = reconstructed_agent_state
|
||||||
|
|
||||||
|
try:
|
||||||
|
from letta.functions.ast_parsers import coerce_dict_args_by_annotations
|
||||||
|
|
||||||
|
annotations = getattr(tool_func, "__annotations__", {})
|
||||||
|
kwargs = coerce_dict_args_by_annotations(
|
||||||
|
kwargs,
|
||||||
|
annotations,
|
||||||
|
allow_unsafe_eval=True,
|
||||||
|
extra_globals=tool_func.__globals__,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Execute the tool function (async or sync)
|
# Execute the tool function (async or sync)
|
||||||
if is_async:
|
if is_async:
|
||||||
result = asyncio.run(tool_func(**kwargs))
|
result = asyncio.run(tool_func(**kwargs))
|
||||||
|
|||||||
@@ -259,6 +259,27 @@ class AsyncToolSandboxBase(ABC):
|
|||||||
if tool_source_code:
|
if tool_source_code:
|
||||||
lines.append(tool_source_code.rstrip())
|
lines.append(tool_source_code.rstrip())
|
||||||
|
|
||||||
|
if self.args:
|
||||||
|
raw_args = ", ".join([f"{name!r}: {name}" for name in self.args])
|
||||||
|
lines.extend(
|
||||||
|
[
|
||||||
|
f"__letta_raw_args = {{{raw_args}}}",
|
||||||
|
"try:",
|
||||||
|
" from letta.functions.ast_parsers import coerce_dict_args_by_annotations",
|
||||||
|
f" __letta_func = {self.tool.name}",
|
||||||
|
" __letta_annotations = getattr(__letta_func, '__annotations__', {})",
|
||||||
|
" __letta_coerced_args = coerce_dict_args_by_annotations(",
|
||||||
|
" __letta_raw_args,",
|
||||||
|
" __letta_annotations,",
|
||||||
|
" allow_unsafe_eval=True,",
|
||||||
|
" extra_globals=__letta_func.__globals__,",
|
||||||
|
" )",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for name in self.args:
|
||||||
|
lines.append(f" {name} = __letta_coerced_args.get({name!r}, {name})")
|
||||||
|
lines.extend(["except Exception:", " pass"])
|
||||||
|
|
||||||
if not self.is_async_function:
|
if not self.is_async_function:
|
||||||
# sync variant
|
# sync variant
|
||||||
lines.append(f"_function_result = {invoke_function_call}")
|
lines.append(f"_function_result = {invoke_function_call}")
|
||||||
|
|||||||
@@ -122,6 +122,19 @@ class ModalFunctionExecutor:
|
|||||||
if inject_agent_state:
|
if inject_agent_state:
|
||||||
kwargs["agent_state"] = agent_state
|
kwargs["agent_state"] = agent_state
|
||||||
|
|
||||||
|
try:
|
||||||
|
from letta.functions.ast_parsers import coerce_dict_args_by_annotations
|
||||||
|
|
||||||
|
annotations = getattr(func, "__annotations__", {})
|
||||||
|
kwargs = coerce_dict_args_by_annotations(
|
||||||
|
kwargs,
|
||||||
|
annotations,
|
||||||
|
allow_unsafe_eval=True,
|
||||||
|
extra_globals=func.__globals__,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
if is_async:
|
if is_async:
|
||||||
result = asyncio.run(func(**kwargs))
|
result = asyncio.run(func(**kwargs))
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user