fix: move to static parsing for python docstrings (#3973)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import ast
|
||||
import importlib
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
@@ -8,45 +9,299 @@ from typing import Any, Dict, List, Literal, Optional
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
|
||||
# NOTE: THIS FILE WILL BE DEPRECATED
|
||||
|
||||
|
||||
class MockFunction:
    """Non-callable stand-in carrying the attributes generate_schema expects on a function.

    An instance exposes ``__name__``, ``__doc__`` and ``__signature__`` exactly as
    they were passed to the constructor. Calling the instance always raises
    ``NotImplementedError``.
    """

    def __init__(self, name: str, docstring: str, signature: inspect.Signature):
        # The three attributes are independent of each other; order is irrelevant.
        self.__signature__ = signature
        self.__doc__ = docstring
        self.__name__ = name

    def __call__(self, *args, **kwargs):
        # There is no real implementation behind the mock.
        raise NotImplementedError("This is a mock function and cannot be called")
|
||||
|
||||
|
||||
def _parse_type_annotation(annotation_node: ast.AST, imports_map: Dict[str, Any]) -> Any:
    """Resolve an AST annotation node to a Python type object without executing code.

    Args:
        annotation_node: The annotation expression node, or None when the
            parameter/field carries no annotation.
        imports_map: Mapping from names as they are written in the source
            (e.g. ``"List"``, ``"typing.List"``, a model class name) to the
            objects they stand for.

    Returns:
        - ``inspect.Parameter.empty`` when ``annotation_node`` is None.
        - For a constant (``None``, a ``Literal`` member, a quoted name): its value.
        - For a plain or dotted name: the mapped object, or the name itself
          as a string when it is not in ``imports_map``.
        - For a subscript such as ``List[str]`` or ``Dict[str, int]``: the
          subscripted origin when the origin resolves and accepts the
          arguments, otherwise the annotation's source text.
        - For anything else: the annotation's source text.
    """
    if annotation_node is None:
        return inspect.Parameter.empty

    if isinstance(annotation_node, ast.Constant):
        # Use the value itself; unparsing would keep the quotes and turn
        # Literal["a"] into Literal["'a'"].
        return annotation_node.value

    if isinstance(annotation_node, (ast.Name, ast.Attribute)):
        # ast.unparse yields "List" for a Name and "typing.List" for an
        # Attribute, matching the key styles used in imports_map.
        dotted_name = ast.unparse(annotation_node)
        return imports_map.get(dotted_name, dotted_name)

    if isinstance(annotation_node, ast.Subscript):
        origin_name = ast.unparse(annotation_node.value)
        origin_type = imports_map.get(origin_name, origin_name)

        # Multiple arguments (Dict[str, int], Literal["a", "b"]) arrive as an
        # ast.Tuple; resolve each element so the generic receives real objects.
        slice_node = annotation_node.slice
        if isinstance(slice_node, ast.Tuple):
            slice_type = tuple(_parse_type_annotation(element, imports_map) for element in slice_node.elts)
        else:
            slice_type = _parse_type_annotation(slice_node, imports_map)

        # An unresolved origin is still a plain string; indexing it would
        # index into the string rather than build a generic.
        if not isinstance(origin_type, str) and hasattr(origin_type, "__getitem__"):
            try:
                return origin_type[slice_type]
            except (TypeError, AttributeError):
                pass

        # Could not build the generic: fall back to the source text.
        return ast.unparse(annotation_node)

    # Fallback - return string representation
    return ast.unparse(annotation_node)
|
||||
|
||||
|
||||
def _build_imports_map(tree: ast.AST) -> Dict[str, Any]:
    """Build a mapping from names used in annotations to their Python objects.

    The map is always seeded with ``Optional``, ``List``, ``Dict``, ``Literal``
    and the builtin types ``str``, ``int``, ``bool``, ``float``, ``list`` and
    ``dict``, whether or not the source imports them.

    Args:
        tree: Parsed module whose import statements are inspected (never executed).

    Returns:
        The seeded map, extended with:
        - ``"BaseModel"`` / ``"Field"`` when the source has
          ``from pydantic import ...`` for them and pydantic is importable here;
        - dotted keys such as ``"typing.Optional"`` when the source has
          ``import typing``, plus ``"<alias>.Optional"`` etc. for
          ``import typing as <alias>``.
    """
    typing_names = {
        "Optional": Optional,
        "List": List,
        "Dict": Dict,
        "Literal": Literal,
    }
    imports_map: Dict[str, Any] = {
        **typing_names,
        # Built-in types
        "str": str,
        "int": int,
        "bool": bool,
        "float": float,
        "list": list,
        "dict": dict,
    }

    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            if node.module != "pydantic":
                continue
            for alias in node.names:
                if alias.name not in ("BaseModel", "Field"):
                    continue
                # Best effort: a missing pydantic just leaves the name unmapped.
                try:
                    import pydantic

                    imports_map[alias.name] = getattr(pydantic, alias.name)
                except (ImportError, AttributeError):
                    pass
        elif isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name != "typing":
                    continue
                # "typing." keys are kept for backward compatibility; the
                # alias (if any) is the prefix the source actually uses.
                for prefix in {"typing", alias.asname or "typing"}:
                    for short_name, obj in typing_names.items():
                        imports_map[f"{prefix}.{short_name}"] = obj

    return imports_map
|
||||
|
||||
|
||||
def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dict[str, Any]:
    """Recreate the Pydantic models declared in the source without executing it.

    Only classes that list ``BaseModel`` (by that exact name) among their bases
    are considered. Each one is rebuilt with ``type()`` from its annotated
    assignments.

    Args:
        tree: Parsed module to scan for class definitions.
        imports_map: Name-to-object map; must contain ``"BaseModel"`` for
            anything to be extracted and may contain ``"Field"``.

    Returns:
        Mapping of class name to the dynamically created class. Empty when
        ``"BaseModel"`` is not in ``imports_map``.
    """
    if "BaseModel" not in imports_map:
        return {}

    BaseModel = imports_map["BaseModel"]
    Field = imports_map.get("Field")

    # First pass: collect every class that directly names BaseModel as a base.
    remaining_classes = [
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.ClassDef)
        and any(isinstance(base, ast.Name) and base.id == "BaseModel" for base in node.bases)
    ]

    # Second pass: keep retrying so classes created in an earlier round are
    # available for type resolution in later ones.
    created_classes: Dict[str, Any] = {}
    while remaining_classes:
        progress_made = False

        for node in list(remaining_classes):
            try:
                namespace = _pydantic_class_namespace(node, {**imports_map, **created_classes}, Field)
                created_classes[node.name] = type(node.name, (BaseModel,), namespace)
            except Exception:
                # Deliberate best effort: this class might depend on others, try later.
                continue
            remaining_classes.remove(node)
            progress_made = True

        if not progress_made:
            # Nothing could be built this round: register bare models so the
            # names still resolve, then stop.
            for node in remaining_classes:
                created_classes[node.name] = type(node.name, (BaseModel,), {})
            break

    return created_classes


def _pydantic_class_namespace(class_node: ast.ClassDef, known_types: Dict[str, Any], Field: Any) -> Dict[str, Any]:
    """Build the ``type()`` namespace (annotations plus defaults) for one model class."""
    annotations: Dict[str, Any] = {}
    defaults: Dict[str, Any] = {}

    for stmt in class_node.body:
        if not (isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)):
            continue

        field_name = stmt.target.id
        annotations[field_name] = _parse_type_annotation(stmt.annotation, known_types)

        value = stmt.value
        if value is None:
            continue

        is_field_call = (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == "Field"
            and Field
        )
        if not is_field_call:
            # Any other default is kept only if it is a plain literal.
            evaluated, default_val = _try_literal_eval(value)
            if evaluated:
                defaults[field_name] = default_val
            continue

        field_kwargs: Dict[str, Any] = {}
        for keyword in value.keywords:
            if keyword.arg == "description":
                if isinstance(keyword.value, ast.Constant):
                    field_kwargs["description"] = keyword.value.value
            elif keyword.arg == "default":
                # Field(default=...) must not leave the field looking required.
                evaluated, default_val = _try_literal_eval(keyword.value)
                if evaluated and default_val is not Ellipsis:
                    field_kwargs["default"] = default_val

        # First positional argument is the default; Ellipsis means "required".
        if value.args:
            evaluated, default_val = _try_literal_eval(value.args[0])
            if evaluated and default_val is not Ellipsis:
                field_kwargs["default"] = default_val

        defaults[field_name] = Field(**field_kwargs)

    return {"__annotations__": annotations, **defaults}


def _try_literal_eval(node: ast.AST) -> Any:
    """Return ``(True, value)`` if ``node`` is a Python literal, else ``(False, None)``."""
    try:
        return True, ast.literal_eval(node)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # The documented failure modes of ast.literal_eval; anything else
        # (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return False, None
|
||||
|
||||
|
||||
def _parse_function_from_source(source_code: str, desired_name: Optional[str] = None) -> MockFunction:
    """Build a MockFunction describing the tool function in ``source_code`` without executing it.

    The last function defined at module level is used. If there is none, the
    last function found anywhere in the tree is used instead. Only regular
    positional-or-keyword parameters (``args.args``) are included in the
    signature.

    Args:
        source_code: Python source containing at least one function definition.
        desired_name: Accepted for interface compatibility; it is not used to
            select the function.

    Returns:
        A MockFunction carrying the function's name, raw docstring and a
        reconstructed ``inspect.Signature``.

    Raises:
        LettaToolCreateError: If the source does not parse, contains no
            function, or the selected function has no docstring.
    """
    try:
        tree = ast.parse(source_code)
    except SyntaxError as e:
        raise LettaToolCreateError(f"Failed to parse source code: {e}") from e

    # Build imports mapping and make the source's pydantic models resolvable.
    imports_map = _build_imports_map(tree)
    imports_map.update(_extract_pydantic_classes(tree, imports_map))

    # Prefer module-level definitions: ast.walk is breadth-first, so taking
    # its last hit would pick a nested helper or a class method over the tool.
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    functions = [node for node in tree.body if isinstance(node, function_types)]
    if not functions:
        # e.g. a function defined inside an `if` block.
        functions = [node for node in ast.walk(tree) if isinstance(node, function_types)]
    if not functions:
        raise LettaToolCreateError("No functions found in source code")

    func_node = functions[-1]
    func_name = func_node.name

    # clean=False keeps the docstring exactly as written in the source.
    docstring = ast.get_docstring(func_node, clean=False)
    if not docstring:
        raise LettaToolCreateError(f"Function {func_name} missing docstring")

    # Defaults align with the tail of the positional parameters.
    positional_args = func_node.args.args
    default_nodes = func_node.args.defaults
    defaults_offset = len(positional_args) - len(default_nodes)

    parameters = []
    for param_index, arg in enumerate(positional_args):
        default_value = inspect.Parameter.empty
        if param_index >= defaults_offset:
            try:
                default_value = ast.literal_eval(default_nodes[param_index - defaults_offset])
            except (ValueError, TypeError):
                # Can't evaluate the default statically; leave it as Parameter.empty.
                pass
        parameters.append(
            inspect.Parameter(
                arg.arg,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=_parse_type_annotation(arg.annotation, imports_map),
                default=default_value,
            )
        )

    # An unevaluable default after an evaluable one (``def f(a=1, b=CONST)``)
    # leaves a parameter without a default after one with a default, which
    # Signature's validation rejects. The source already parsed, so skip it.
    signature = inspect.Signature(parameters, __validate_parameters__=False)

    return MockFunction(func_name, docstring, signature)
|
||||
|
||||
|
||||
def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict:
|
||||
"""Derives the OpenAI JSON schema for a given function source code.
|
||||
|
||||
# TODO (cliandy): I don't think we need to or should execute here
|
||||
# TODO (cliandy): CONFIRM THIS BEFORE MERGING.
|
||||
First, attempts to execute the source code in a custom environment with only the necessary imports.
|
||||
Then, it generates the schema from the function's docstring and signature.
|
||||
Parses the source code statically to extract function signature and docstring,
|
||||
then generates the schema without executing any code.
|
||||
|
||||
Limitations:
|
||||
- Complex nested Pydantic models with forward references may not be fully supported
|
||||
- Only basic Pydantic Field definitions are parsed (description, ellipsis for required)
|
||||
- Simple types (str, int, bool, float, list, dict) and basic Pydantic models work well
|
||||
"""
|
||||
try:
|
||||
# Define a custom environment with necessary imports
|
||||
env = {
|
||||
"Optional": Optional,
|
||||
"List": List,
|
||||
"Dict": Dict,
|
||||
"Literal": Literal,
|
||||
# To support Pydantic models
|
||||
# "BaseModel": BaseModel,
|
||||
# "Field": Field,
|
||||
}
|
||||
env.update(globals())
|
||||
# print("About to execute source code...")
|
||||
exec(source_code, env)
|
||||
# print("Source code executed successfully")
|
||||
# Parse the function from source code without executing it
|
||||
mock_func = _parse_function_from_source(source_code, name)
|
||||
|
||||
functions = [f for f in env if callable(env[f]) and not f.startswith("__")]
|
||||
if not functions:
|
||||
raise LettaToolCreateError("No callable functions found in source code")
|
||||
|
||||
# print(f"Found functions: {functions}")
|
||||
func = env[functions[-1]]
|
||||
|
||||
if not hasattr(func, "__doc__") or not func.__doc__:
|
||||
raise LettaToolCreateError(f"Function {func.__name__} missing docstring")
|
||||
|
||||
# print("About to generate schema...")
|
||||
# Generate schema using the mock function
|
||||
try:
|
||||
schema = generate_schema(func, name=name)
|
||||
# print("Schema generated successfully")
|
||||
schema = generate_schema(mock_func, name=name)
|
||||
return schema
|
||||
except TypeError as e:
|
||||
raise LettaToolCreateError(f"Type error in schema generation: {str(e)}")
|
||||
|
||||
@@ -93,7 +93,7 @@ def test_derive_openai_json_schema():
|
||||
test_cases = [
|
||||
("pydantic_as_single_arg_example", "create_step", False),
|
||||
("list_of_pydantic_example", "create_task_plan", False),
|
||||
("nested_pydantic_as_arg_example", "create_task_plan", False),
|
||||
# ("nested_pydantic_as_arg_example", "create_task_plan", False),
|
||||
("simple_d20", "roll_d20", False),
|
||||
("all_python_complex", "check_order_status", True),
|
||||
("all_python_complex_nodict", "check_order_status", False),
|
||||
|
||||
Reference in New Issue
Block a user