diff --git a/letta/agent.py b/letta/agent.py index c2335ddd..29575ad5 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1,4 +1,3 @@ -import inspect import json import time import traceback @@ -20,6 +19,7 @@ from letta.constants import ( REQ_HEARTBEAT_MESSAGE, ) from letta.errors import ContextWindowExceededError +from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.functions.functions import get_function_from_module from letta.helpers import ToolRulesSolver from letta.interface import AgentInterface @@ -223,15 +223,10 @@ class Agent(BaseAgent): function_response = callable_func(**function_args) self.update_memory_if_changed(agent_state_copy.memory) else: - # TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args. - env = {} - env.update(globals()) - exec(target_letta_tool.source_code, env) - callable_func = env[target_letta_tool.json_schema["name"]] - spec = inspect.getfullargspec(callable_func).annotations - for name, arg in function_args.items(): - if isinstance(function_args[name], dict): - function_args[name] = spec[name](**function_args[name]) + # Parse the source code to extract function annotations + annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name) + # Coerce the function arguments to the correct types based on the annotations + function_args = coerce_dict_args_by_annotations(function_args, annotations) # execute tool in a sandbox # TODO: allow agent_state to specify which sandbox to execute tools in diff --git a/letta/functions/ast_parsers.py b/letta/functions/ast_parsers.py new file mode 100644 index 00000000..fbe8a06b --- /dev/null +++ b/letta/functions/ast_parsers.py @@ -0,0 +1,105 @@ +import ast +import json +from typing import Dict + +# Registry of known types for annotation resolution +BUILTIN_TYPES = { + "int": int, + "float": float, + "str": str, + "dict": dict, + "list": list, + "set": set, + "tuple": tuple, + "bool": bool, +} + + +def resolve_type(annotation: str): + """ + Resolve a type annotation string into a Python type. + + Args: + annotation (str): The annotation string (e.g., 'int', 'list', etc.). + + Returns: + type: The corresponding Python type. + + Raises: + ValueError: If the annotation is unsupported or invalid. + """ + if annotation in BUILTIN_TYPES: + return BUILTIN_TYPES[annotation] + + try: + parsed = ast.literal_eval(annotation) + if isinstance(parsed, type): + return parsed + raise ValueError(f"Annotation '{annotation}' is not a recognized type.") + except (ValueError, SyntaxError): + raise ValueError(f"Unsupported annotation: {annotation}") + + +def get_function_annotations_from_source(source_code: str, function_name: str) -> Dict[str, str]: + """ + Parse the source code to extract annotations for a given function name. + + Args: + source_code (str): The Python source code containing the function. + function_name (str): The name of the function to extract annotations for. + + Returns: + Dict[str, str]: A dictionary of argument names to their annotation strings. + + Raises: + ValueError: If the function is not found in the source code. + """ + tree = ast.parse(source_code) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.FunctionDef) and node.name == function_name: + annotations = {} + for arg in node.args.args: + if arg.annotation is not None: + annotation_str = ast.unparse(arg.annotation) + annotations[arg.arg] = annotation_str + return annotations + raise ValueError(f"Function '{function_name}' not found in the provided source code.") + + +def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, str]) -> dict: + """ + Coerce arguments in a dictionary to their annotated types. + + Args: + function_args (dict): The original function arguments. + annotations (Dict[str, str]): Argument annotations as strings. + + Returns: + dict: The updated dictionary with coerced argument types. + + Raises: + ValueError: If type coercion fails for an argument. + """ + coerced_args = dict(function_args) # Shallow copy for mutation safety + + for arg_name, value in coerced_args.items(): + if arg_name in annotations: + annotation_str = annotations[arg_name] + try: + # Resolve the type from the annotation + arg_type = resolve_type(annotation_str) + + # Handle JSON-like inputs for dict and list types + if arg_type in {dict, list} and isinstance(value, str): + try: + # First, try JSON parsing + value = json.loads(value) + except json.JSONDecodeError: + # Fall back to literal_eval for Python-specific literals + value = ast.literal_eval(value) + + # Coerce the value to the resolved type + coerced_args[arg_name] = arg_type(value) + except (TypeError, ValueError, json.JSONDecodeError, SyntaxError) as e: + raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}") + return coerced_args diff --git a/tests/test_ast_parsing.py b/tests/test_ast_parsing.py new file mode 100644 index 00000000..938cef20 --- /dev/null +++ b/tests/test_ast_parsing.py @@ -0,0 +1,216 @@ +import pytest + +from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source + +# ----------------------------------------------------------------------- +# Example source code for testing multiple scenarios, including: +# 1) A class-based custom type (which we won't handle properly). +# 2) Functions with multiple argument types. +# 3) A function with default arguments. +# 4) A function with no arguments. +# 5) A function that shares the same name as another symbol. +# ----------------------------------------------------------------------- +example_source_code = r""" +class CustomClass: + def __init__(self, x): + self.x = x + +def unrelated_symbol(): + pass + +def no_args_func(): + pass + +def default_args_func(x: int = 5, y: str = "hello"): + return x, y + +def my_function(a: int, b: float, c: str, d: list, e: dict, f: CustomClass = None): + pass + +def my_function_duplicate(): + # This function shares the name "my_function" partially, but isn't an exact match + pass +""" + + +# --------------------- get_function_annotations_from_source TESTS --------------------- # + + +def test_get_function_annotations_found(): + """ + Test that we correctly parse annotations for a function + that includes multiple argument types and a custom class. + """ + annotations = get_function_annotations_from_source(example_source_code, "my_function") + assert annotations == { + "a": "int", + "b": "float", + "c": "str", + "d": "list", + "e": "dict", + "f": "CustomClass", + } + + +def test_get_function_annotations_not_found(): + """ + If the requested function name doesn't exist exactly, + we should raise a ValueError. + """ + with pytest.raises(ValueError, match="Function 'missing_function' not found"): + get_function_annotations_from_source(example_source_code, "missing_function") + + +def test_get_function_annotations_no_args(): + """ + Check that a function without arguments returns an empty annotations dict. + """ + annotations = get_function_annotations_from_source(example_source_code, "no_args_func") + assert annotations == {} + + +def test_get_function_annotations_with_default_values(): + """ + Ensure that a function with default arguments still captures the annotations. + """ + annotations = get_function_annotations_from_source(example_source_code, "default_args_func") + assert annotations == {"x": "int", "y": "str"} + + +def test_get_function_annotations_partial_name_collision(): + """ + Ensure we only match the exact function name, not partial collisions. + """ + # This will match 'my_function' exactly, ignoring 'my_function_duplicate' + annotations = get_function_annotations_from_source(example_source_code, "my_function") + assert "a" in annotations # Means it matched the correct function + # No error expected here, just making sure we didn't accidentally parse "my_function_duplicate". + + +# --------------------- coerce_dict_args_by_annotations TESTS --------------------- # + + +def test_coerce_dict_args_success(): + """ + Basic success scenario with standard types: + int, float, str, list, dict. + """ + annotations = {"a": "int", "b": "float", "c": "str", "d": "list", "e": "dict"} + function_args = {"a": "42", "b": "3.14", "c": 123, "d": "[1, 2, 3]", "e": '{"key": "value"}'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["b"] == 3.14 + assert coerced_args["c"] == "123" + assert coerced_args["d"] == [1, 2, 3] + assert coerced_args["e"] == {"key": "value"} + + +def test_coerce_dict_args_invalid_type(): + """ + If the value cannot be coerced into the annotation, + a ValueError should be raised. + """ + annotations = {"a": "int"} + function_args = {"a": "invalid_int"} + + with pytest.raises(ValueError, match="Failed to coerce argument 'a' to int"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_no_annotations(): + """ + If there are no annotations, we do no coercion. + """ + annotations = {} + function_args = {"a": 42, "b": "hello"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args == function_args # Exactly the same dict back + + +def test_coerce_dict_args_partial_annotations(): + """ + Only coerce annotated arguments; leave unannotated ones unchanged. + """ + annotations = {"a": "int"} + function_args = {"a": "42", "b": "no_annotation"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["b"] == "no_annotation" + + +def test_coerce_dict_args_with_missing_args(): + """ + If function_args lacks some keys listed in annotations, + those are simply not coerced. (We do not add them.) + """ + annotations = {"a": "int", "b": "float"} + function_args = {"a": "42"} # Missing 'b' + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert "b" not in coerced_args + + +def test_coerce_dict_args_unexpected_keys(): + """ + If function_args has extra keys not in annotations, + we leave them alone. + """ + annotations = {"a": "int"} + function_args = {"a": "42", "z": 999} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["z"] == 999 # unchanged + + +def test_coerce_dict_args_unsupported_custom_class(): + """ + If someone tries to pass an annotation that isn't supported (like a custom class), + we should raise a ValueError (or similarly handle the error) rather than silently + accept it. + """ + annotations = {"f": "CustomClass"} # We can't resolve this + function_args = {"f": {"x": 1}} + with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass: Unsupported annotation: CustomClass"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_with_complex_types(): + """ + Confirm the ability to parse built-in complex data (lists, dicts, etc.) + when given as strings. + """ + annotations = {"big_list": "list", "nested_dict": "dict"} + function_args = {"big_list": "[1, 2, [3, 4], {'five': 5}]", "nested_dict": '{"alpha": [10, 20], "beta": {"x": 1, "y": 2}}'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["big_list"] == [1, 2, [3, 4], {"five": 5}] + assert coerced_args["nested_dict"] == { + "alpha": [10, 20], + "beta": {"x": 1, "y": 2}, + } + + +def test_coerce_dict_args_non_string_keys(): + """ + Validate behavior if `function_args` includes non-string keys. + (We should simply skip annotation checks for them.) + """ + annotations = {"a": "int"} + function_args = {123: "42", "a": "42"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + # 'a' is coerced to int + assert coerced_args["a"] == 42 + # 123 remains untouched + assert coerced_args[123] == "42" + + +def test_coerce_dict_args_non_parseable_list_or_dict(): + """ + Test passing incorrectly formatted JSON for a 'list' or 'dict' annotation. + """ + annotations = {"bad_list": "list", "bad_dict": "dict"} + function_args = {"bad_list": "[1, 2, 3", "bad_dict": '{"key": "value"'} # missing brackets + + with pytest.raises(ValueError, match="Failed to coerce argument 'bad_list' to list"): + coerce_dict_args_by_annotations(function_args, annotations)