fix: move to static parsing for python docstrings (#3973)

This commit is contained in:
Sarah Wooders
2025-08-17 15:16:13 -07:00
committed by GitHub
parent aca371b87e
commit 136aa89047
2 changed files with 287 additions and 32 deletions

View File

@@ -1,3 +1,4 @@
import ast
import importlib
import inspect
from collections.abc import Callable
@@ -8,45 +9,299 @@ from typing import Any, Dict, List, Literal, Optional
from letta.errors import LettaToolCreateError
from letta.functions.schema_generator import generate_schema
# NOTE: THIS FILE WILL BE DEPRECATED
class MockFunction:
    """Stand-in callable exposing the attributes generate_schema expects.

    An instance carries ``__name__``, ``__doc__`` and ``__signature__`` the way
    a real function object would, but it has no body: invoking it always raises.
    """

    def __init__(self, name: str, docstring: str, signature: inspect.Signature):
        # Store the three introspection attributes directly on the instance.
        self.__name__, self.__doc__, self.__signature__ = name, docstring, signature

    def __call__(self, *args, **kwargs):
        # The object is metadata only, so there is nothing to execute.
        message = "This is a mock function and cannot be called"
        raise NotImplementedError(message)
def _parse_subscript_argument(slice_node: ast.AST, imports_map: Dict[str, Any]) -> Any:
    """Resolve the expression inside the brackets of a subscripted annotation."""
    if isinstance(slice_node, ast.Tuple):
        # 'Dict[str, int]' / 'Literal["a", "b"]': resolve every member separately.
        return tuple(_parse_subscript_argument(element, imports_map) for element in slice_node.elts)
    if isinstance(slice_node, ast.Constant):
        # Keep the real value ('a', 1, None) rather than its source text ("'a'"),
        # otherwise Literal members pick up stray quotes.
        return slice_node.value
    return _parse_type_annotation(slice_node, imports_map)


def _parse_type_annotation(annotation_node: ast.AST, imports_map: Dict[str, Any]) -> Any:
    """Turn an AST annotation node back into a Python type object where possible.

    Args:
        annotation_node: The annotation expression, or None when the parameter
            or field has no annotation.
        imports_map: Mapping from a (possibly dotted) name as written in the
            source to the Python object it stands for.

    Returns:
        ``inspect.Parameter.empty`` for a missing annotation; the mapped object
        for a known plain or dotted name; the subscripted type for a generic
        whose origin resolves and accepts the arguments; otherwise a string
        (the bare name, or the annotation's source text).
    """
    if annotation_node is None:
        return inspect.Parameter.empty

    if isinstance(annotation_node, ast.Name):
        return imports_map.get(annotation_node.id, annotation_node.id)

    if isinstance(annotation_node, ast.Attribute):
        # Dotted names such as 'typing.Optional' are looked up by their source text.
        dotted_name = ast.unparse(annotation_node)
        return imports_map.get(dotted_name, dotted_name)

    if isinstance(annotation_node, ast.Subscript):
        origin_type = _parse_type_annotation(annotation_node.value, imports_map)
        arguments = _parse_subscript_argument(annotation_node.slice, imports_map)
        # An unresolved origin is a plain string; indexing it would slice the
        # string instead of parametrising a type, so skip straight to the fallback.
        if not isinstance(origin_type, str) and hasattr(origin_type, "__getitem__"):
            try:
                return origin_type[arguments]
            except (TypeError, AttributeError):
                pass
        # Could not build a real type: report the annotation exactly as written.
        return ast.unparse(annotation_node)

    # Anything else (unions written with '|', calls, ...) is returned as source text.
    return ast.unparse(annotation_node)
def _build_imports_map(tree: ast.AST) -> Dict[str, Any]:
    """Map names that may appear in annotations to the objects they denote.

    The result always contains a fixed set of typing constructs and built-in
    types. Import statements found anywhere in ``tree`` extend it:
    ``from pydantic import BaseModel`` / ``Field`` adds the real pydantic object
    when pydantic can be imported, and ``import typing`` adds the
    ``typing.``-prefixed spellings of the typing constructs.
    """
    typing_constructs = {"Optional": Optional, "List": List, "Dict": Dict, "Literal": Literal}
    builtin_types = {"str": str, "int": int, "bool": bool, "float": float, "list": list, "dict": dict}
    resolved: Dict[str, Any] = {**typing_constructs, **builtin_types}

    for statement in ast.walk(tree):
        if isinstance(statement, ast.ImportFrom):
            if statement.module != "pydantic":
                continue
            requested = [item.name for item in statement.names if item.name in ("BaseModel", "Field")]
            if not requested:
                continue
            # Resolve pydantic lazily; when it is not installed the names stay unmapped.
            try:
                from pydantic import BaseModel, Field
            except ImportError:
                continue
            pydantic_objects = {"BaseModel": BaseModel, "Field": Field}
            for wanted in requested:
                resolved[wanted] = pydantic_objects[wanted]
        elif isinstance(statement, ast.Import):
            for item in statement.names:
                if item.name != "typing":
                    continue
                # 'import typing' makes the dotted spellings usable in annotations.
                for short_name, construct in typing_constructs.items():
                    resolved[f"typing.{short_name}"] = construct

    return resolved
# Exceptions ast.literal_eval is documented to raise for input it cannot evaluate.
_LITERAL_EVAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)


def _inherits_from_basemodel(class_node: ast.ClassDef) -> bool:
    """Return True when one of the class's bases is written as the bare name 'BaseModel'."""
    return any(isinstance(base, ast.Name) and base.id == "BaseModel" for base in class_node.bases)


def _field_from_call(call_node: ast.Call, field_factory: Any) -> Any:
    """Rebuild a ``Field(...)`` call from its AST using only statically known arguments."""
    field_kwargs = {}
    for keyword in call_node.keywords:
        if keyword.arg == "description":
            if isinstance(keyword.value, ast.Constant):
                field_kwargs["description"] = keyword.value.value
        elif keyword.arg == "default":
            # A keyword default makes the field optional just like a positional one.
            try:
                default_val = ast.literal_eval(keyword.value)
            except _LITERAL_EVAL_ERRORS:
                continue
            if default_val is not Ellipsis:
                field_kwargs["default"] = default_val

    # The first positional argument of Field is the default; Ellipsis means "required".
    if call_node.args:
        try:
            default_val = ast.literal_eval(call_node.args[0])
        except _LITERAL_EVAL_ERRORS:
            pass
        else:
            if default_val is not Ellipsis:
                field_kwargs["default"] = default_val

    return field_factory(**field_kwargs)


def _build_model_namespace(class_node: ast.ClassDef, type_lookup: Dict[str, Any], field_factory: Any) -> Dict[str, Any]:
    """Collect annotations and static defaults from a class body into a ``type()`` namespace."""
    annotations = {}
    defaults = {}
    for stmt in class_node.body:
        # Only 'name: annotation [= value]' statements describe model fields.
        if not (isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)):
            continue
        field_name = stmt.target.id
        annotations[field_name] = _parse_type_annotation(stmt.annotation, type_lookup)

        value = stmt.value
        if value is None:
            continue
        is_field_call = (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == "Field"
            and field_factory
        )
        if is_field_call:
            defaults[field_name] = _field_from_call(value, field_factory)
            continue
        # Plain default such as 'x: int = 5'; anything that is not a literal is skipped.
        try:
            defaults[field_name] = ast.literal_eval(value)
        except _LITERAL_EVAL_ERRORS:
            pass
    return {"__annotations__": annotations, **defaults}


def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dict[str, Any]:
    """Recreate the BaseModel subclasses declared in ``tree`` without executing the source.

    Args:
        tree: Parsed source to scan for class definitions.
        imports_map: Name-to-object mapping; must contain "BaseModel" for any
            class to be built, and may contain "Field".

    Returns:
        A dict from class name to a dynamically created subclass of
        ``imports_map["BaseModel"]``. Empty when "BaseModel" is not mapped.
        Classes that cannot be built with their fields are still returned, as
        field-less subclasses.
    """
    if "BaseModel" not in imports_map:
        return {}

    BaseModel = imports_map["BaseModel"]
    Field = imports_map.get("Field")

    remaining_classes = [
        node for node in ast.walk(tree) if isinstance(node, ast.ClassDef) and _inherits_from_basemodel(node)
    ]

    created_classes = {}
    while remaining_classes:
        progress_made = False
        # Iterate over a snapshot because successfully built classes are removed.
        for node in list(remaining_classes):
            try:
                # Classes built so far are visible as types to the ones that follow.
                type_lookup = {**imports_map, **created_classes}
                namespace = _build_model_namespace(node, type_lookup, Field)
                created_classes[node.name] = type(node.name, (BaseModel,), namespace)
            except Exception:
                # Best effort: this class might depend on others, retry on the next pass.
                continue
            remaining_classes.remove(node)
            progress_made = True

        if not progress_made:
            # Nothing could be built in a whole pass: fall back to field-less
            # placeholders so the names still resolve to a model type.
            for node in remaining_classes:
                created_classes[node.name] = type(node.name, (BaseModel,), {})
            break

    return created_classes
def _parse_function_from_source(source_code: str, desired_name: Optional[str] = None) -> MockFunction:
    """Build a MockFunction describing the tool function in ``source_code`` without executing it.

    The last function defined at module level is used. Nested helper functions
    are only considered when the module defines no top-level function at all.

    Args:
        source_code: Python source containing the function definition.
        desired_name: Accepted for interface compatibility; it is not consulted
            when choosing which function to describe.

    Returns:
        A MockFunction carrying the function's name, raw docstring and a
        signature rebuilt from its positional-only, regular and keyword-only
        parameters.

    Raises:
        LettaToolCreateError: If the source does not parse, contains no
            function, or the selected function has no docstring.
    """
    try:
        tree = ast.parse(source_code)
    except SyntaxError as e:
        raise LettaToolCreateError(f"Failed to parse source code: {e}") from e

    # Build imports mapping and make statically rebuilt pydantic models resolvable.
    imports_map = _build_imports_map(tree)
    imports_map.update(_extract_pydantic_classes(tree, imports_map))

    # ast.walk is breadth-first, so a helper nested inside the tool function would
    # come after it; prefer module-level definitions to pick the real entry point.
    functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
    if not functions:
        functions = [node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)]
    if not functions:
        raise LettaToolCreateError("No functions found in source code")

    func_node = functions[-1]
    func_name = func_node.name

    # clean=False keeps the docstring text exactly as written in the source.
    docstring = ast.get_docstring(func_node, clean=False)
    if not docstring:
        raise LettaToolCreateError(f"Function {func_name} missing docstring")

    def _literal_default(default_node: Optional[ast.AST]) -> Any:
        """Evaluate a default expression statically; Parameter.empty when absent or not a literal."""
        if default_node is None:
            return inspect.Parameter.empty
        try:
            return ast.literal_eval(default_node)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Can't evaluate the default, use Parameter.empty
            return inspect.Parameter.empty

    arguments = func_node.args
    positional = [(arg, inspect.Parameter.POSITIONAL_ONLY) for arg in arguments.posonlyargs]
    positional += [(arg, inspect.Parameter.POSITIONAL_OR_KEYWORD) for arg in arguments.args]
    # 'defaults' is right-aligned against the positional parameters; pad the front.
    leading_without_default = len(positional) - len(arguments.defaults)
    positional_defaults = [None] * leading_without_default + list(arguments.defaults)

    # 'kw_defaults' lines up one-to-one with 'kwonlyargs' (None marks "no default").
    keyword_only = [(arg, inspect.Parameter.KEYWORD_ONLY) for arg in arguments.kwonlyargs]
    keyword_defaults = list(arguments.kw_defaults)

    parameters = []
    for (arg, kind), default_node in zip(positional + keyword_only, positional_defaults + keyword_defaults):
        parameters.append(
            inspect.Parameter(
                arg.arg,
                kind,
                annotation=_parse_type_annotation(arg.annotation, imports_map),
                default=_literal_default(default_node),
            )
        )

    signature = inspect.Signature(parameters)
    return MockFunction(func_name, docstring, signature)
def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict:
"""Derives the OpenAI JSON schema for a given function source code.
# TODO (cliandy): I don't think we need to or should execute here
# TODO (cliandy): CONFIRM THIS BEFORE MERGING.
First, attempts to execute the source code in a custom environment with only the necessary imports.
Then, it generates the schema from the function's docstring and signature.
Parses the source code statically to extract function signature and docstring,
then generates the schema without executing any code.
Limitations:
- Complex nested Pydantic models with forward references may not be fully supported
- Only basic Pydantic Field definitions are parsed (description, ellipsis for required)
- Simple types (str, int, bool, float, list, dict) and basic Pydantic models work well
"""
try:
# Define a custom environment with necessary imports
env = {
"Optional": Optional,
"List": List,
"Dict": Dict,
"Literal": Literal,
# To support Pydantic models
# "BaseModel": BaseModel,
# "Field": Field,
}
env.update(globals())
# print("About to execute source code...")
exec(source_code, env)
# print("Source code executed successfully")
# Parse the function from source code without executing it
mock_func = _parse_function_from_source(source_code, name)
functions = [f for f in env if callable(env[f]) and not f.startswith("__")]
if not functions:
raise LettaToolCreateError("No callable functions found in source code")
# print(f"Found functions: {functions}")
func = env[functions[-1]]
if not hasattr(func, "__doc__") or not func.__doc__:
raise LettaToolCreateError(f"Function {func.__name__} missing docstring")
# print("About to generate schema...")
# Generate schema using the mock function
try:
schema = generate_schema(func, name=name)
# print("Schema generated successfully")
schema = generate_schema(mock_func, name=name)
return schema
except TypeError as e:
raise LettaToolCreateError(f"Type error in schema generation: {str(e)}")