feat: add args schema for tool object (#1160)

This commit is contained in:
cthomas
2025-03-01 14:53:10 -08:00
committed by GitHub
parent 2bfbbeb9b8
commit 6a6e50a4f4
20 changed files with 1011 additions and 489 deletions

View File

@@ -0,0 +1,31 @@
"""add args schema to tools
Revision ID: 54f2311edb62
Revises: b183663c6769
Create Date: 2025-02-27 16:45:50.835081
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "54f2311edb62"
down_revision: Union[str, None] = "b183663c6769"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the nullable ``args_json_schema`` JSON column to the ``tools`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("tools", sa.Column("args_json_schema", sa.JSON(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the ``args_json_schema`` column, reverting this migration."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("tools", "args_json_schema")
    # ### end Alembic commands ###

View File

@@ -1,6 +1,8 @@
import ast
import json
from typing import Dict
from typing import Dict, Optional, Tuple
from letta.errors import LettaToolCreateError
# Registry of known types for annotation resolution
BUILTIN_TYPES = {
@@ -103,3 +105,50 @@ def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str,
except (TypeError, ValueError, json.JSONDecodeError, SyntaxError) as e:
raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}")
return coerced_args
def get_function_name_and_description(source_code: str, name: Optional[str] = None) -> Tuple[str, str]:
    """Get the name and docstring for a function by parsing its source with the AST.

    The source is parsed, never executed. If multiple function definitions are
    present, the last one encountered by the AST walk is used.

    Args:
        source_code: The source code to parse.
        name: Optional override for the function name.

    Returns:
        Tuple of (function_name, docstring).

    Raises:
        LettaToolCreateError: If no function definition is found, the name
            cannot be determined, the docstring is missing, or parsing fails.
    """
    try:
        # Parse the source code into an AST
        tree = ast.parse(source_code)

        # Find the last function definition
        function_def = None
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                function_def = node
        if not function_def:
            raise LettaToolCreateError("No function definition found in source code")

        # Get the function name (caller-provided override wins)
        function_name = name if name is not None else function_def.name

        # Get the docstring if it exists
        docstring = ast.get_docstring(function_def)

        if not function_name:
            raise LettaToolCreateError("Could not determine function name")
        if not docstring:
            raise LettaToolCreateError("Docstring is missing")

        return function_name, docstring
    except LettaToolCreateError:
        # Already a well-formed domain error — re-raise as-is instead of
        # wrapping it a second time (the old code double-wrapped these).
        raise
    except Exception as e:
        raise LettaToolCreateError(f"Failed to parse function name and docstring: {str(e)}") from e

View File

@@ -1,11 +1,11 @@
import asyncio
import threading
from random import uniform
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, Union
import humps
from composio.constants import DEFAULT_ENTITY_ID
from pydantic import BaseModel
from pydantic import BaseModel, Field, create_model
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.functions.interface import MultiAgentMessagingInterface
@@ -561,3 +561,86 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async finish", message=message, tags=tags)
return final
def _build_field_definitions(model_schema: Dict[str, Any], nested_models: Optional[Dict[str, Type[BaseModel]]] = None) -> Dict[str, Any]:
    """Build ``create_model`` field definitions from one schema's ``properties``.

    Required fields get ``Field(...)`` (no default); optional fields default to
    ``None``. Field descriptions are carried over (empty string when absent).
    """
    required_names = set(model_schema.get("required", []))
    fields = {}
    for field_name, field_schema in model_schema["properties"].items():
        field_type = _get_field_type(field_schema, nested_models)
        description = field_schema.get("description", "")  # Get description or empty string
        default = ... if field_name in required_names else None
        fields[field_name] = (field_type, Field(default, description=description))
    return fields


def generate_model_from_args_json_schema(schema: Dict[str, Any]) -> Type[BaseModel]:
    """Creates a Pydantic model from a JSON schema.

    Args:
        schema: The JSON schema dictionary (may carry nested object models
            under ``$defs``).

    Returns:
        A Pydantic model class.
    """
    # First create any nested models from $defs so top-level $ref lookups resolve.
    # NOTE(review): a $defs entry referencing another $defs entry only resolves
    # if the referenced entry appears earlier in iteration order — confirm.
    nested_models: Dict[str, Type[BaseModel]] = {}
    if "$defs" in schema:
        for name, model_schema in schema["$defs"].items():
            nested_models[name] = create_model(name, **_build_field_definitions(model_schema))

    # Create and return the main model (shares the same field-building logic).
    return create_model(schema.get("title", "DynamicModel"), **_build_field_definitions(schema, nested_models))
def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] = None) -> Any:
"""Helper to convert JSON schema types to Python types."""
if field_schema.get("type") == "string":
return str
elif field_schema.get("type") == "integer":
return int
elif field_schema.get("type") == "number":
return float
elif field_schema.get("type") == "boolean":
return bool
elif field_schema.get("type") == "array":
item_type = field_schema["items"].get("$ref", "").split("/")[-1]
if item_type and nested_models and item_type in nested_models:
return List[nested_models[item_type]]
return List[_get_field_type(field_schema["items"])]
elif field_schema.get("type") == "object":
if "$ref" in field_schema:
ref_type = field_schema["$ref"].split("/")[-1]
if nested_models and ref_type in nested_models:
return nested_models[ref_type]
elif "additionalProperties" in field_schema:
value_type = _get_field_type(field_schema["additionalProperties"], nested_models)
return Dict[str, value_type]
return dict
elif field_schema.get("$ref") is not None:
ref_type = field_schema["$ref"].split("/")[-1]
if nested_models and ref_type in nested_models:
return nested_models[ref_type]
else:
raise ValueError(f"Reference {ref_type} not found in nested models")
elif field_schema.get("anyOf") is not None:
types = []
has_null = False
for type_option in field_schema["anyOf"]:
if type_option.get("type") == "null":
has_null = True
else:
types.append(_get_field_type(type_option, nested_models))
# If we have exactly one type and null, make it Optional
if has_null and len(types) == 1:
return Optional[types[0]]
# Otherwise make it a Union of all types
else:
return Union[tuple(types)]
raise ValueError(f"Unable to convert pydantic field schema to type: {field_schema}")

View File

@@ -43,6 +43,6 @@ class Tool(SqlalchemyBase, OrganizationMixin):
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
args_json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The JSON schema of the function arguments.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")

View File

@@ -8,8 +8,9 @@ from letta.constants import (
LETTA_CORE_TOOL_MODULE_NAME,
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
)
from letta.functions.ast_parsers import get_function_name_and_description
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper, generate_model_from_args_json_schema
from letta.functions.schema_generator import generate_schema_from_args_schema_v2, generate_tool_schema_for_composio
from letta.log import get_logger
from letta.orm.enums import ToolType
@@ -46,6 +47,7 @@ class Tool(BaseTool):
# code
source_code: Optional[str] = Field(None, description="The source code of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")
# tool configuration
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
@@ -70,7 +72,16 @@ class Tool(BaseTool):
# TODO: Instead of checking the tag, we should having `COMPOSIO` as a specific ToolType
# TODO: We skip this for Composio bc composio json schemas are derived differently
if not (COMPOSIO_TOOL_TAG_NAME in self.tags):
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
if self.args_json_schema is not None:
name, description = get_function_name_and_description(self.source_code, self.name)
args_schema = generate_model_from_args_json_schema(self.args_json_schema)
self.json_schema = generate_schema_from_args_schema_v2(
args_schema=args_schema,
name=name,
description=description,
)
else:
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}:
# If it's letta core tool, we generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
@@ -107,6 +118,7 @@ class ToolCreate(LettaBase):
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
@classmethod
@@ -189,6 +201,7 @@ class ToolUpdate(LettaBase):
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")
return_char_limit: Optional[int] = Field(None, description="The maximum number of characters in the response.")
class Config:
@@ -202,3 +215,4 @@ class ToolRunFromSource(LettaBase):
env_vars: Dict[str, str] = Field(None, description="The environment variables to pass to the tool.")
name: Optional[str] = Field(None, description="The name of the tool to run.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")

View File

@@ -190,6 +190,7 @@ def run_tool_from_source(
tool_args=request.args,
tool_env_vars=request.env_vars,
tool_name=request.name,
tool_args_json_schema=request.args_json_schema,
actor=actor,
)
except LettaToolCreateError as e:

View File

@@ -6,7 +6,7 @@ import traceback
import warnings
from abc import abstractmethod
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from composio.client import Composio
from composio.client.collections import ActionModel, AppModel
@@ -1104,6 +1104,7 @@ class SyncServer(Server):
tool_env_vars: Optional[Dict[str, str]] = None,
tool_source_type: Optional[str] = None,
tool_name: Optional[str] = None,
tool_args_json_schema: Optional[Dict[str, Any]] = None,
) -> ToolReturnMessage:
"""Run a tool from source code"""
if tool_source_type is not None and tool_source_type != "python":
@@ -1113,6 +1114,7 @@ class SyncServer(Server):
tool = Tool(
name=tool_name,
source_code=tool_source,
args_json_schema=tool_args_json_schema,
)
assert tool.name is not None, "Failed to create tool object"

View File

@@ -4,6 +4,10 @@ import subprocess
import venv
from typing import Dict, Optional
from datamodel_code_generator import DataModelType, PythonVersion
from datamodel_code_generator.model import get_data_model_types
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser
from letta.log import get_logger
from letta.schemas.sandbox_config import LocalSandboxConfig
@@ -153,3 +157,17 @@ def create_venv_for_local_sandbox(sandbox_dir_path: str, venv_path: str, env: Di
except subprocess.CalledProcessError as e:
logger.error(f"Error while setting up the virtual environment: {e}")
raise RuntimeError(f"Failed to set up the virtual environment: {e}")
def add_imports_and_pydantic_schemas_for_args(args_json_schema: dict) -> str:
    """Generate Python source (imports + pydantic models) for an args JSON schema.

    Runs datamodel-code-generator's JSON-schema parser configured for pydantic
    v2 models targeting Python 3.11, and returns the emitted module source.
    """
    model_types = get_data_model_types(DataModelType.PydanticV2BaseModel, target_python_version=PythonVersion.PY_311)
    schema_parser = JsonSchemaParser(
        str(args_json_schema),
        data_model_type=model_types.data_model,
        data_model_root_type=model_types.root_model,
        data_model_field_type=model_types.field_model,
        data_type_manager_type=model_types.data_type_manager,
        dump_resolve_reference_action=model_types.dump_resolve_reference_action,
    )
    return schema_parser.parse()

View File

@@ -11,12 +11,14 @@ import traceback
import uuid
from typing import Any, Dict, Optional
from letta.functions.helpers import generate_model_from_args_json_schema
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.helpers.tool_execution_helper import (
add_imports_and_pydantic_schemas_for_args,
create_venv_for_local_sandbox,
find_python_executable,
install_pip_requirements_for_sandbox,
@@ -408,20 +410,35 @@ class ToolExecutionSandbox:
code += "import sys\n"
code += "import base64\n"
# Load the agent state data into the program
# imports to support agent state
if agent_state:
code += "import letta\n"
code += "from letta import * \n"
import pickle
if self.tool.args_json_schema:
schema_code = add_imports_and_pydantic_schemas_for_args(self.tool.args_json_schema)
if "from __future__ import annotations" in schema_code:
schema_code = schema_code.replace("from __future__ import annotations", "").lstrip()
code = "from __future__ import annotations\n\n" + code
code += schema_code + "\n"
# load the agent state
if agent_state:
agent_state_pickle = pickle.dumps(agent_state)
code += f"agent_state = pickle.loads({agent_state_pickle})\n"
else:
# agent state is None
code += "agent_state = None\n"
for param in self.args:
code += self.initialize_param(param, self.args[param])
if self.tool.args_json_schema:
args_schema = generate_model_from_args_json_schema(self.tool.args_json_schema)
code += f"args_object = {args_schema.__name__}(**{self.args})\n"
for param in self.args:
code += f"{param} = args_object.{param}\n"
else:
for param in self.args:
code += self.initialize_param(param, self.args[param])
if "agent_state" in self.parse_function_arguments(self.tool.source_code, self.tool.name):
inject_agent_state = True

View File

@@ -42,7 +42,7 @@ class ToolManager:
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
if tool:
# Put to dict and remove fields that should not be reset
update_data = pydantic_tool.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
# If there's anything to update
if update_data:

973
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -87,6 +87,7 @@ faker = "^36.1.0"
colorama = "^0.4.6"
marshmallow-sqlalchemy = "^1.4.1"
boto3 = {version = "^1.36.24", optional = true}
datamodel-code-generator = {extras = ["http"], version = "^0.25.0"}
[tool.poetry.extras]

View File

@@ -1,11 +1,14 @@
import importlib.util
import inspect
import json
import os
import pytest
from pydantic import BaseModel
from letta.functions.functions import derive_openai_json_schema
from letta.llm_api.helpers import convert_to_structured_output, make_post_request
from letta.schemas.tool import ToolCreate
from letta.schemas.tool import Tool, ToolCreate
def _clean_diff(d1, d2):
@@ -233,3 +236,53 @@ def test_langchain_tool_schema_generation(openai_model: str, structured_output:
print(f"Failed to call OpenAI using schema {schema} generated from {langchain_tool.name}\n\n")
raise
@pytest.mark.parametrize("openai_model", ["gpt-4", "gpt-4o"])
@pytest.mark.parametrize("structured_output", [True, False])
def test_valid_schemas_with_pydantic_args_schema(openai_model: str, structured_output: bool):
    """Test that we can send the schemas to OpenAI and get a tool call back.

    Each fixture module defines a tool function and (optionally) a pydantic
    args-schema model; the Tool derives its JSON schema from that model when
    present, and the schema is sent to OpenAI via _openai_payload.
    """
    for filename in [
        "pydantic_as_single_arg_example",
        "list_of_pydantic_example",
        "nested_pydantic_as_arg_example",
        "simple_d20",
        "all_python_complex",
        "all_python_complex_nodict",
    ]:
        # Import the module dynamically.
        # BUG FIX: the path must interpolate the loop variable, not a literal.
        file_path = os.path.join(os.path.dirname(__file__), f"test_tool_schema_parsing_files/{filename}.py")
        spec = importlib.util.spec_from_file_location(filename, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Find the last function definition and args schema if defined
        # (getmembers is alphabetical, so "last" means last by name).
        last_function_name, last_function_source, last_model_class = None, None, None
        for name, obj in inspect.getmembers(module):
            if inspect.isfunction(obj) and obj.__module__ == module.__name__:
                last_function_name = name
                last_function_source = inspect.getsource(obj)  # only import the function, not the whole file
            if inspect.isclass(obj) and obj.__module__ == module.__name__ and issubclass(obj, BaseModel):
                last_model_class = obj

        # Get the ArgsSchema if it exists
        args_schema = None
        if last_model_class:
            args_schema = last_model_class.model_json_schema()

        tool = Tool(
            name=last_function_name,
            source_code=last_function_source,
            args_json_schema=args_schema,
        )
        schema = tool.json_schema

        print(f"==== TESTING OPENAI PAYLOAD FOR {openai_model} + {filename} ====")
        # We should expect the all_python_complex one to fail when structured_output=True
        if filename == "all_python_complex" and structured_output:
            with pytest.raises(ValueError):
                _openai_payload(openai_model, schema, structured_output)
        else:
            _openai_payload(openai_model, schema, structured_output)

View File

@@ -1,5 +1,34 @@
from typing import List, Optional
from pydantic import BaseModel, Field
class ArgsSchema(BaseModel):
    # Pydantic args-schema model for the tool function in this fixture file;
    # the schema-parsing tests collect the last BaseModel subclass and pass
    # its model_json_schema() as the tool's args_json_schema.
    # NOTE(review): intentionally no class docstring — pydantic would emit it
    # as the model "description" and alter the generated JSON schema.
    order_number: int = Field(
        ...,
        description="The order number to check on.",
    )
    customer_name: str = Field(
        ...,
        description="The customer name to check on.",
    )
    related_tickets: List[str] = Field(
        ...,
        description="A list of related ticket numbers.",
    )
    related_ticket_reasons: dict = Field(
        ...,
        description="A dictionary of reasons for each related ticket.",
    )
    severity: float = Field(
        ...,
        description="The severity of the order.",
    )
    metadata: Optional[dict] = Field(
        None,
        description="Optional metadata about the order.",
    )
def check_order_status(
order_number: int,

View File

@@ -1,5 +1,30 @@
from typing import List, Optional
from pydantic import BaseModel, Field
class ArgsSchema(BaseModel):
    # Pydantic args-schema model for the tool function in this fixture file;
    # the schema-parsing tests collect the last BaseModel subclass and pass
    # its model_json_schema() as the tool's args_json_schema.
    # NOTE(review): intentionally no class docstring — pydantic would emit it
    # as the model "description" and alter the generated JSON schema.
    order_number: int = Field(
        ...,
        description="The order number to check on.",
    )
    customer_name: str = Field(
        ...,
        description="The customer name to check on.",
    )
    related_tickets: List[str] = Field(
        ...,
        description="A list of related ticket numbers.",
    )
    severity: float = Field(
        ...,
        description="The severity of the order.",
    )
    metadata: Optional[str] = Field(
        None,
        description="Optional metadata about the order.",
    )
def check_order_status(
order_number: int,

View File

@@ -16,6 +16,13 @@ class Step(BaseModel):
)
class ArgsSchema(BaseModel):
    # Args schema for create_task_plan: a list of nested Step models.
    # No class docstring on purpose — pydantic would surface it as the model
    # "description" and change the generated JSON schema.
    steps: list[Step] = Field(
        ...,
        description="List of steps to add to the task plan.",
    )
def create_task_plan(steps: list[Step]) -> str:
"""
Creates a task plan for the current task.

View File

@@ -1,39 +1,43 @@
{
"name": "create_task_plan",
"description": "Creates a task plan for the current task.",
"parameters": {
"type": "object",
"properties": {
"steps": {
"type": "object",
"description": "List of steps to add to the task plan.",
"properties": {
"steps": {
"type": "array",
"description": "A list of steps to add to the task plan.",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name of the step."
},
"key": {
"type": "string",
"description": "Unique identifier for the step."
},
"description": {
"type": "string",
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
}
"name": "create_task_plan",
"description": "Creates a task plan for the current task.",
"parameters": {
"type": "object",
"properties": {
"steps": {
"type": "object",
"description": "List of steps to add to the task plan.",
"properties": {
"steps": {
"type": "array",
"description": "A list of steps to add to the task plan.",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name of the step."
},
"required": ["name", "key", "description"]
}
"key": {
"type": "string",
"description": "Unique identifier for the step."
},
"description": {
"type": "string",
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
}
},
"required": ["name", "key", "description"]
}
},
"required": ["steps"]
}
}
},
"required": ["steps"]
},
"required": ["steps"]
}
"completed": {
"type": "integer",
"description": "The number of steps to add as completed to the task plan."
}
},
"required": ["steps", "completed"]
}
}

View File

@@ -25,7 +25,18 @@ class Steps(BaseModel):
)
def create_task_plan(steps: Steps) -> str:
class ArgsSchema(BaseModel):
    # Args schema for create_task_plan: a nested Steps model plus a completed
    # count. No class docstring on purpose — pydantic would surface it as the
    # model "description" and change the generated JSON schema.
    steps: Steps = Field(
        ...,
        description="A list of steps to add to the task plan.",
    )
    completed: int = Field(
        ...,
        description="The number of steps to add as completed to the task plan.",
    )
def create_task_plan(steps: Steps, completed: int) -> str:
"""
Creates a task plan for the current task.
It takes in a list of steps, and updates the task with the new steps provided.
@@ -39,6 +50,7 @@ def create_task_plan(steps: Steps) -> str:
Args:
steps: List of steps to add to the task plan.
completed: The number of steps to add as completed to the task plan.
Returns:
str: A summary of the updated task plan after deletion

View File

@@ -1,43 +1,47 @@
{
"name": "create_task_plan",
"description": "Creates a task plan for the current task.",
"strict": true,
"parameters": {
"type": "object",
"properties": {
"steps": {
"type": "object",
"description": "List of steps to add to the task plan.",
"properties": {
"steps": {
"type": "array",
"description": "A list of steps to add to the task plan.",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name of the step."
},
"key": {
"type": "string",
"description": "Unique identifier for the step."
},
"description": {
"type": "string",
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
}
"name": "create_task_plan",
"description": "Creates a task plan for the current task.",
"strict": true,
"parameters": {
"type": "object",
"properties": {
"steps": {
"type": "object",
"description": "List of steps to add to the task plan.",
"properties": {
"steps": {
"type": "array",
"description": "A list of steps to add to the task plan.",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name of the step."
},
"additionalProperties": false,
"required": ["name", "key", "description"]
}
"key": {
"type": "string",
"description": "Unique identifier for the step."
},
"description": {
"type": "string",
"description": "An exhaustic description of what this step is trying to achieve and accomplish."
}
},
"additionalProperties": false,
"required": ["name", "key", "description"]
}
},
"additionalProperties": false,
"required": ["steps"]
}
}
},
"additionalProperties": false,
"required": ["steps"]
},
"additionalProperties": false,
"required": ["steps"]
}
"completed": {
"type": "integer",
"description": "The number of steps to add as completed to the task plan."
}
},
"additionalProperties": false,
"required": ["steps", "completed"]
}
}

View File

@@ -16,6 +16,13 @@ class Step(BaseModel):
)
class ArgsSchema(BaseModel):
    # Args schema for create_step: a single nested Step model.
    # No class docstring on purpose — pydantic would surface it as the model
    # "description" and change the generated JSON schema.
    step: Step = Field(
        ...,
        description="A step to add to the task plan.",
    )
def create_step(step: Step) -> str:
"""
Creates a step for the current task.