Files
letta-server/letta/functions/functions.py
Kian Jones f5c4ab50f4 chore: add ty + pre-commit hook and repeal even more ruff rules (#9504)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import

* remove all ignores, add FastAPI rules and Ruff rules

* add ty and precommit

* ruff stuff

* ty check fixes

* ty check fixes pt 2

* error on invalid
2026-02-24 10:55:11 -08:00

413 lines
16 KiB
Python

import ast
import importlib
import inspect
from collections.abc import Callable
from textwrap import dedent # remove indentation
from types import ModuleType
from typing import Any, Dict, List, Literal, Optional
from letta.errors import LettaToolCreateError
from letta.functions.schema_generator import generate_schema
# NOTE: THIS FILE WILL BE DEPRECATED
class MockFunction:
    """Non-callable stand-in that carries a function's metadata for schema generation.

    It mimics the attributes ``generate_schema`` reads from a real function: ``__name__``,
    ``__doc__`` and ``__signature__``. Invoking an instance always raises
    ``NotImplementedError``, so no user code can run through it.
    """

    def __init__(self, name: str, docstring: str, signature: inspect.Signature):
        # inspect.signature() returns an explicit __signature__ attribute as-is
        self.__signature__ = signature
        self.__name__ = name
        self.__doc__ = docstring

    def __call__(self, *_args, **_kwargs):
        raise NotImplementedError("This is a mock function and cannot be called")
def _literal_argument(node: ast.AST) -> Any:
    """Return the value of one ``Literal[...]`` argument, or its source text if it is not a literal."""
    try:
        return ast.literal_eval(node)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return ast.unparse(node)


def _parse_type_annotation(annotation_node: Optional[ast.AST], imports_map: Dict[str, Any]) -> Any:
    """Parse an AST type annotation node back into a Python type object, without executing code.

    Resolution rules:
    * Names (``int``) and dotted names (``typing.List``) are looked up in ``imports_map``.
    * String annotations are forward references: the string is parsed and resolved recursively.
    * Subscripted generics are rebuilt by subscripting the resolved origin. A multi-argument
      slice (``Dict[str, int]``) is passed as a tuple. ``Literal[...]`` arguments are kept as
      literal values rather than being resolved as types.

    Args:
        annotation_node: The annotation node, or ``None`` when the parameter is unannotated.
        imports_map: Mapping from names used in the source to the objects they denote.

    Returns:
        ``inspect.Parameter.empty`` for a missing annotation, the resolved object when resolution
        succeeds, and otherwise a best-effort string rendering of the annotation.
    """
    if annotation_node is None:
        return inspect.Parameter.empty

    if isinstance(annotation_node, (ast.Name, ast.Attribute)):
        # unparse yields the id for a Name and the dotted path (e.g. "typing.List") for an Attribute
        type_name = ast.unparse(annotation_node)
        return imports_map.get(type_name, type_name)

    if isinstance(annotation_node, ast.Constant) and isinstance(annotation_node.value, str):
        # Forward reference such as "MyModel" or "List[MyModel]"
        try:
            inner = ast.parse(annotation_node.value, mode="eval").body
        except SyntaxError:
            return annotation_node.value
        return _parse_type_annotation(inner, imports_map)

    if isinstance(annotation_node, ast.Subscript):
        # Generic type like 'List[str]', 'Optional[int]', 'Dict[str, int]'
        origin_name = ast.unparse(annotation_node.value)
        origin_type = imports_map.get(origin_name, origin_name)

        slice_node = annotation_node.slice
        element_nodes = slice_node.elts if isinstance(slice_node, ast.Tuple) else [slice_node]
        if origin_type is Literal:
            args = tuple(_literal_argument(node) for node in element_nodes)
        else:
            args = tuple(_parse_type_annotation(node, imports_map) for node in element_nodes)

        # X[a] takes a single argument; X[a, b] receives the tuple (a, b)
        subscript = args[0] if len(args) == 1 else args
        try:
            return origin_type[subscript]
        except (TypeError, AttributeError):
            # Origin is unresolved (a plain string) or not subscriptable with these arguments
            pass
        slice_text = ast.unparse(slice_node) if isinstance(slice_node, ast.Tuple) else str(args[0])
        return f"{origin_type}[{slice_text}]"

    # Fallback - return string representation (e.g. "str | None")
    return ast.unparse(annotation_node)
def _build_imports_map(tree: ast.AST) -> Dict[str, Any]:
    """Map names usable in annotations to the Python objects they refer to.

    The result always contains the common ``typing`` aliases and builtin scalar/container
    types. Import statements found anywhere in ``tree`` (inspected, never executed) extend it:

    * ``from pydantic import BaseModel`` / ``Field`` adds that name, resolved against the
      locally installed pydantic. It is skipped silently if the import fails.
    * ``import typing`` adds the dotted spellings ``typing.Optional`` etc.

    Args:
        tree: Parsed module whose import statements are scanned.

    Returns:
        A fresh dict from source-level name to resolved object.
    """
    resolved: Dict[str, Any] = {
        "Optional": Optional,
        "List": List,
        "Dict": Dict,
        "Literal": Literal,
        # Built-in types
        "str": str,
        "int": int,
        "bool": bool,
        "float": float,
        "list": list,
        "dict": dict,
    }
    # Qualified spellings, only added once the source imports the typing module itself
    typing_qualified = {
        "typing.Optional": Optional,
        "typing.List": List,
        "typing.Dict": Dict,
        "typing.Literal": Literal,
    }
    supported_pydantic_names = ("BaseModel", "Field")

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            if any(alias.name == "typing" for alias in node.names):
                resolved.update(typing_qualified)
        elif isinstance(node, ast.ImportFrom) and node.module == "pydantic":
            for alias in node.names:
                if alias.name not in supported_pydantic_names:
                    continue
                # Equivalent to `from pydantic import <name>`: a missing package or name is ignored
                try:
                    pydantic_module = importlib.import_module("pydantic")
                    resolved[alias.name] = getattr(pydantic_module, alias.name)
                except (ImportError, AttributeError):
                    pass
    return resolved
def _parse_field_kwargs(call: ast.Call) -> Dict[str, Any]:
    """Rebuild the literal ``description`` and ``default`` arguments of a ``Field(...)`` call.

    The default may be given positionally (first argument) or as ``default=``. It is dropped
    when it is ``...`` (Ellipsis, which marks the field as required) or is not a Python
    literal. All other arguments are ignored.
    """
    field_kwargs: Dict[str, Any] = {}
    default_nodes = list(call.args[:1])
    for keyword in call.keywords:
        if keyword.arg == "description" and isinstance(keyword.value, ast.Constant):
            field_kwargs["description"] = keyword.value.value
        elif keyword.arg == "default":
            default_nodes.append(keyword.value)

    for default_node in default_nodes:
        try:
            default_val = ast.literal_eval(default_node)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            continue
        if default_val is not ...:  # Ellipsis means required: no default
            field_kwargs["default"] = default_val
    return field_kwargs


def _collect_model_fields(
    class_node: ast.ClassDef, type_namespace: Dict[str, Any], field_factory: Any
) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Return ``(annotations, defaults)`` for the annotated attributes in a class body.

    Annotations are resolved with ``_parse_type_annotation`` against ``type_namespace``. A
    ``Field(...)`` value is rebuilt via ``field_factory`` when one is available. Any other
    value is kept only if it is a Python literal.
    """
    annotations: Dict[str, Any] = {}
    defaults: Dict[str, Any] = {}
    for stmt in class_node.body:
        if not (isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)):
            continue
        field_name = stmt.target.id
        annotations[field_name] = _parse_type_annotation(stmt.annotation, type_namespace)

        value = stmt.value
        if value is None:
            continue  # annotation only, no default given
        is_field_call = (
            isinstance(value, ast.Call) and isinstance(value.func, ast.Name) and value.func.id == "Field"
        )
        if is_field_call and field_factory:
            defaults[field_name] = field_factory(**_parse_field_kwargs(value))
        else:
            # Plain default such as `count: int = 3`; non-literal values are skipped
            try:
                defaults[field_name] = ast.literal_eval(value)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                pass
    return annotations, defaults


def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dict[str, Any]:
    """Recreate the Pydantic models declared in ``tree`` as real classes, without executing the source.

    Only classes that name ``BaseModel`` directly among their bases are recreated. Their
    annotated attributes become fields (see ``_collect_model_fields``). Names resolve against
    ``imports_map`` plus the models created so far.

    A model whose construction raises is retried on a later pass, since it may depend on a
    model not built yet. Once a pass makes no progress, the remaining models are created as
    empty subclasses of ``BaseModel``.

    Args:
        tree: Parsed module to scan.
        imports_map: Mapping from source-level name to object. It must contain ``"BaseModel"``
            for any model to be created; ``"Field"`` is optional.

    Returns:
        Class name -> dynamically created model class. Empty if ``BaseModel`` is not available.
    """
    if "BaseModel" not in imports_map:
        return {}

    base_model = imports_map["BaseModel"]
    field_factory = imports_map.get("Field")

    # First pass: collect all class definitions that inherit from BaseModel directly
    remaining_classes = [
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.ClassDef)
        and any(isinstance(base, ast.Name) and base.id == "BaseModel" for base in node.bases)
    ]

    created_classes: Dict[str, Any] = {}
    while remaining_classes:
        progress_made = False
        for node in list(remaining_classes):
            # Models built so far are resolvable in this model's annotations
            type_namespace = {**imports_map, **created_classes}
            try:
                annotations, defaults = _collect_model_fields(node, type_namespace, field_factory)
                model = type(node.name, (base_model,), {"__annotations__": annotations, **defaults})
            except Exception:
                # This class might depend on others, try later
                continue
            created_classes[node.name] = model
            remaining_classes.remove(node)
            progress_made = True

        if not progress_made:
            # No further progress possible: create remaining classes without any fields
            for node in remaining_classes:
                created_classes[node.name] = type(node.name, (base_model,), {})
            break

    return created_classes
def _select_function_node(tree: ast.Module, desired_name: Optional[str]) -> ast.FunctionDef:
    """Choose the function definition in ``tree`` that describes the tool.

    Module-level definitions are preferred, because ``ast.walk`` would also yield helpers
    nested inside the tool and methods of model classes. If there are none, every definition
    (in ``ast.walk`` order) is considered. Among the candidates, the last one named
    ``desired_name`` wins; otherwise the last candidate is used.

    Raises:
        LettaToolCreateError: If the source defines no function at all.
    """
    candidates = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
    if not candidates:
        candidates = [node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)]
    if not candidates:
        raise LettaToolCreateError("No functions found in source code")

    if desired_name is not None:
        named = [node for node in candidates if node.name == desired_name]
        if named:
            return named[-1]
    # Use the last function (matching original behavior)
    return candidates[-1]


def _make_parameter(
    arg: ast.arg, kind: Any, default_node: Optional[ast.expr], imports_map: Dict[str, Any]
) -> inspect.Parameter:
    """Convert one ``ast.arg`` into an ``inspect.Parameter``.

    ``default_node`` is ``None`` when the parameter has no default.
    """
    annotation = _parse_type_annotation(arg.annotation, imports_map)
    if default_node is None:
        return inspect.Parameter(arg.arg, kind, annotation=annotation)
    try:
        default_value = ast.literal_eval(default_node)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Can't evaluate the default statically, use Parameter.empty
        default_value = inspect.Parameter.empty
    return inspect.Parameter(arg.arg, kind, annotation=annotation, default=default_value)


def _build_parameters(arguments: ast.arguments, imports_map: Dict[str, Any]) -> List[inspect.Parameter]:
    """Translate a function's ``ast.arguments`` into ``inspect.Parameter`` objects, in signature order.

    Positional-only, regular and keyword-only parameters are included. ``*args`` and
    ``**kwargs`` are skipped.
    """
    positional = [(arg, inspect.Parameter.POSITIONAL_ONLY) for arg in arguments.posonlyargs]
    positional += [(arg, inspect.Parameter.POSITIONAL_OR_KEYWORD) for arg in arguments.args]
    # `defaults` belongs to the *last* len(defaults) positional parameters
    positional_defaults = [None] * (len(positional) - len(arguments.defaults)) + list(arguments.defaults)

    parameters = [
        _make_parameter(arg, kind, default_node, imports_map)
        for (arg, kind), default_node in zip(positional, positional_defaults)
    ]
    # kw_defaults is aligned with kwonlyargs and holds None where there is no default
    parameters.extend(
        _make_parameter(arg, inspect.Parameter.KEYWORD_ONLY, default_node, imports_map)
        for arg, default_node in zip(arguments.kwonlyargs, arguments.kw_defaults)
    )
    return parameters


def _parse_function_from_source(source_code: str, desired_name: Optional[str] = None) -> MockFunction:
    """Parse a function from source code without executing it.

    Nothing in ``source_code`` runs. The source is parsed with ``ast``, annotations are
    resolved through ``_build_imports_map`` plus any Pydantic models declared in the source,
    and literal defaults are recovered with ``ast.literal_eval``.

    Args:
        source_code: Python source defining the tool function, optionally with helpers/models.
        desired_name: Preferred function name. If no candidate function has this name, the last
            candidate is used.

    Returns:
        A MockFunction carrying the chosen function's name, raw docstring and signature.

    Raises:
        LettaToolCreateError: If the source does not parse, defines no function, or the chosen
            function has no docstring.
    """
    try:
        tree = ast.parse(source_code)
    except (SyntaxError, ValueError) as e:  # ValueError: null bytes on older Pythons
        raise LettaToolCreateError(f"Failed to parse source code: {e}") from e

    # Build imports mapping and make source-declared Pydantic models resolvable
    imports_map = _build_imports_map(tree)
    imports_map.update(_extract_pydantic_classes(tree, imports_map))

    func_node = _select_function_node(tree, desired_name)
    func_name = func_node.name

    # clean=False keeps the docstring exactly as written in the source
    docstring = ast.get_docstring(func_node, clean=False)
    if not docstring:
        raise LettaToolCreateError(f"Function {func_name} missing docstring")

    signature = inspect.Signature(_build_parameters(func_node.args, imports_map))
    return MockFunction(func_name, docstring, signature)
def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict:
    """Derives the OpenAI JSON schema for a given function source code.

    Parses the source code statically to extract function signature and docstring,
    then generates the schema without executing any code.

    Args:
        source_code: Python source containing the tool function.
        name: Optional name, forwarded both to the source parser and to ``generate_schema``.

    Returns:
        The schema dict produced by ``generate_schema``.

    Raises:
        LettaToolCreateError: On any failure. The message is prefixed with
            "Schema generation failed: " and the underlying exception is chained as ``__cause__``.

    Limitations:
    - Complex nested Pydantic models with forward references may not be fully supported
    - Only basic Pydantic Field definitions are parsed (description, ellipsis for required)
    - Simple types (str, int, bool, float, list, dict) and basic Pydantic models work well
    """
    import logging

    try:
        # Parse the function from source code without executing it
        mock_func = _parse_function_from_source(source_code, name)
        # Generate schema using the mock function
        try:
            return generate_schema(mock_func, name=name)
        except TypeError as e:
            raise LettaToolCreateError(f"Type error in schema generation: {str(e)}") from e
        except ValueError as e:
            raise LettaToolCreateError(f"Value error in schema generation: {str(e)}") from e
        except Exception as e:
            raise LettaToolCreateError(f"Unexpected error in schema generation: {str(e)}") from e
    except Exception as e:
        # Record the traceback through logging rather than printing directly to stderr
        logging.getLogger(__name__).exception("Schema generation failed")
        raise LettaToolCreateError(f"Schema generation failed: {str(e)}") from e
def parse_source_code(func) -> str:
    """Return the source text of ``func`` with its common leading indentation removed.

    Dedenting means the source of a method or nested function can be parsed as top-level code.
    """
    raw_source = inspect.getsource(func)
    return dedent(raw_source)
# TODO (cliandy) refactor below two funcs
def get_function_from_module(module_name: str, function_name: str) -> Callable[..., Any]:
    """
    Dynamically imports a function from a specified module.

    Args:
        module_name (str): The name of the module to import (e.g., 'base').
        function_name (str): The name of the function to retrieve.

    Returns:
        Callable: The imported function.

    Raises:
        ModuleNotFoundError: If the specified module cannot be found (the original error is
            chained as ``__cause__``).
        AttributeError: If the function is not found in the module.
    """
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(f"Module '{module_name}' not found.") from e

    # Guard only the lookup, so AttributeErrors raised while importing are not relabeled
    try:
        return getattr(module, function_name)
    except AttributeError as e:
        raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.") from e
def get_json_schema_from_module(module_name: str, function_name: str) -> dict:
    """
    Dynamically loads a specific function from a module and generates its JSON schema.

    Args:
        module_name (str): The name of the module to import (e.g., 'base').
        function_name (str): The name of the function to retrieve.

    Returns:
        dict: The JSON schema for the specified function.

    Raises:
        ModuleNotFoundError: If the specified module cannot be found (the original error is
            chained as ``__cause__``).
        ValueError: If the module has no attribute ``function_name``, or the attribute is not a
            function defined in that module.
    """
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(f"Module '{module_name}' not found.") from e

    attr = getattr(module, function_name, None)
    # Only plain functions defined in this very module qualify (not re-exports, classes, etc.)
    if not (inspect.isfunction(attr) and attr.__module__ == module.__name__):
        raise ValueError(f"'{function_name}' is not a user-defined function in module '{module_name}'")

    # Errors raised by schema generation propagate unchanged instead of being relabeled as lookup failures
    return generate_schema(attr)
def load_function_set(module: ModuleType) -> dict:
    """Load the functions and generate schema for them, given a module object.

    Only plain functions whose ``__module__`` is ``module`` itself are included, so names the
    module merely imports are skipped.

    Args:
        module: An already-imported module object.

    Returns:
        Mapping from function name to a dict with keys ``"module"`` (the full source of
        ``module``), ``"python_function"`` (the function object) and ``"json_schema"``
        (the result of ``generate_schema``).

    Raises:
        ValueError: If the module defines no functions.
    """
    function_dict = {}
    # Module source is read once, and only if the module actually defines a function
    module_source = None
    # dir() yields each name once, so duplicate names cannot occur here
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        # Check if it's a plain function defined in this module (not a builtin or re-export)
        if not (inspect.isfunction(attr) and attr.__module__ == module.__name__):
            continue
        generated_schema = generate_schema(attr)
        if module_source is None:
            module_source = inspect.getsource(module)
        function_dict[attr_name] = {
            "module": module_source,
            "python_function": attr,
            "json_schema": generated_schema,
        }

    if not function_dict:
        raise ValueError(f"No functions found in module {module}")
    return function_dict