fix: remove json schema generation from tool validation [LET-4509] (#4964)
* patch integration test * create default user and org * rm * patch * add testing * remove validation for schemas from pydantic object * add file * patch tests * fix more tests * fix managers * fix sdk test * patch schema tests * Comment out name in update * patch test * patch * add another test
This commit is contained in:
committed by
Caren Thomas
parent
4bdf85b883
commit
d0d36a4b07
@@ -16,12 +16,10 @@ from letta.constants import (
|
||||
# MCP Tool metadata constants for schema health status
|
||||
MCP_TOOL_METADATA_SCHEMA_STATUS = f"{MCP_TOOL_TAG_NAME_PREFIX}:SCHEMA_STATUS"
|
||||
MCP_TOOL_METADATA_SCHEMA_WARNINGS = f"{MCP_TOOL_TAG_NAME_PREFIX}:SCHEMA_WARNINGS"
|
||||
from letta.functions.ast_parsers import get_function_name_and_docstring
|
||||
from letta.functions.composio_helpers import generate_composio_tool_wrapper
|
||||
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
|
||||
from letta.functions.functions import get_json_schema_from_module
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.functions.schema_generator import (
|
||||
generate_schema_from_args_schema_v2,
|
||||
generate_tool_schema_for_composio,
|
||||
generate_tool_schema_for_mcp,
|
||||
)
|
||||
@@ -80,46 +78,19 @@ class Tool(BaseTool):
|
||||
def refresh_source_code_and_json_schema(self):
|
||||
"""
|
||||
Refresh name, description, source_code, and json_schema.
|
||||
|
||||
Note: Schema generation for custom tools is now handled at creation/update time in ToolManager.
|
||||
This method only handles built-in Letta tools.
|
||||
"""
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
|
||||
if self.tool_type == ToolType.CUSTOM and not self.json_schema:
|
||||
# attempt various fallbacks to get the JSON schema
|
||||
if not self.source_code:
|
||||
logger.error("Custom tool with id=%s is missing source_code field", self.id)
|
||||
raise ValueError(f"Custom tool with id={self.id} is missing source_code field.")
|
||||
|
||||
if self.source_type == ToolSourceType.typescript:
|
||||
# TypeScript tools don't support args_json_schema, only direct schema generation
|
||||
if not self.json_schema:
|
||||
try:
|
||||
from letta.functions.typescript_parser import derive_typescript_json_schema
|
||||
|
||||
self.json_schema = derive_typescript_json_schema(source_code=self.source_code)
|
||||
except Exception as e:
|
||||
logger.error("Failed to derive TypeScript json schema for tool with id=%s name=%s: %s", self.id, self.name, e)
|
||||
elif (
|
||||
self.source_type == ToolSourceType.python or self.source_type is None
|
||||
): # default to python if not provided for backwards compatability
|
||||
# Python tool handling
|
||||
# Always derive json_schema for freshest possible json_schema
|
||||
if self.args_json_schema is not None:
|
||||
name, description = get_function_name_and_docstring(self.source_code, self.name)
|
||||
args_schema = generate_model_from_args_json_schema(self.args_json_schema)
|
||||
self.json_schema = generate_schema_from_args_schema_v2(
|
||||
args_schema=args_schema,
|
||||
name=name,
|
||||
description=description,
|
||||
append_heartbeat=False,
|
||||
)
|
||||
else: # elif not self.json_schema: # TODO: JSON schema is not being derived correctly the first time?
|
||||
# If there's not a json_schema provided, then we need to re-derive
|
||||
try:
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
except Exception as e:
|
||||
logger.error("Failed to derive json schema for tool with id=%s name=%s: %s", self.id, self.name, e)
|
||||
else:
|
||||
raise ValueError(f"Unknown tool source type: {self.source_type}")
|
||||
if self.tool_type == ToolType.CUSTOM:
|
||||
# Custom tools should already have their schema set during creation/update
|
||||
# No schema generation happens here anymore
|
||||
if not self.json_schema:
|
||||
logger.warning(
|
||||
"Custom tool with id=%s name=%s is missing json_schema. Schema should be set during creation/update.",
|
||||
self.id,
|
||||
self.name,
|
||||
)
|
||||
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE, ToolType.LETTA_SLEEPTIME_CORE}:
|
||||
# If it's letta core tool, we generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
|
||||
@@ -139,23 +110,6 @@ class Tool(BaseTool):
|
||||
# Composio schemas handled separately
|
||||
pass
|
||||
|
||||
# At this point, we need to validate that at least json_schema is populated
|
||||
if not self.json_schema:
|
||||
logger.error("Tool with id=%s name=%s tool_type=%s is missing a json_schema", self.id, self.name, self.tool_type)
|
||||
raise ValueError(f"Tool with id={self.id} name={self.name} tool_type={self.tool_type} is missing a json_schema.")
|
||||
|
||||
# Derive name from the JSON schema if not provided
|
||||
if not self.name:
|
||||
# TODO: This in theory could error, but name should always be on json_schema
|
||||
# TODO: Make JSON schema a typed pydantic object
|
||||
self.name = self.json_schema.get("name")
|
||||
|
||||
# Derive description from the JSON schema if not provided
|
||||
if not self.description:
|
||||
# TODO: This in theory could error, but description should always be on json_schema
|
||||
# TODO: Make JSON schema a typed pydantic object
|
||||
self.description = self.json_schema.get("description")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@@ -253,6 +207,7 @@ class ToolUpdate(LettaBase):
|
||||
npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional metadata for the tool.")
|
||||
default_requires_approval: Optional[bool] = Field(None, description="Whether or not to require approval before executing this tool.")
|
||||
# name: Optional[str] = Field(None, description="The name of the tool (must match the JSON schema name and source code function name).")
|
||||
|
||||
model_config = ConfigDict(extra="ignore") # Allows extra fields without validation errors
|
||||
# TODO: Remove this, and clean usage of ToolUpdate everywhere else
|
||||
|
||||
@@ -20,6 +20,8 @@ from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import (
|
||||
BedrockPermissionError,
|
||||
LettaAgentNotFoundError,
|
||||
LettaToolCreateError,
|
||||
LettaToolNameConflictError,
|
||||
LettaUserNotFoundError,
|
||||
LLMAuthenticationError,
|
||||
LLMError,
|
||||
@@ -240,6 +242,8 @@ def create_application() -> "FastAPI":
|
||||
app.add_exception_handler(LettaUserNotFoundError, _error_handler_404_user)
|
||||
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
|
||||
app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409)
|
||||
app.add_exception_handler(LettaToolCreateError, _error_handler_400)
|
||||
app.add_exception_handler(LettaToolNameConflictError, _error_handler_400)
|
||||
|
||||
@app.exception_handler(IncompatibleAgentType)
|
||||
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
|
||||
|
||||
@@ -287,19 +287,9 @@ async def create_tool(
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
tool = Tool(**request.model_dump(exclude_unset=True))
|
||||
return await server.tool_manager.create_tool_async(pydantic_tool=tool, actor=actor)
|
||||
except UniqueConstraintViolationError as e:
|
||||
clean_error_message = "Tool with this name already exists."
|
||||
raise HTTPException(status_code=409, detail=clean_error_message)
|
||||
except LettaToolCreateError as e:
|
||||
# HTTP 400 == Bad Request
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors and raise an internal server error
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
tool = Tool(**request.model_dump(exclude_unset=True))
|
||||
return await server.tool_manager.create_or_update_tool_async(pydantic_tool=tool, actor=actor)
|
||||
|
||||
|
||||
@router.put("/", response_model=Tool, operation_id="upsert_tool")
|
||||
@@ -311,21 +301,9 @@ async def upsert_tool(
|
||||
"""
|
||||
Create or update a tool
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
tool = await server.tool_manager.create_or_update_tool_async(
|
||||
pydantic_tool=Tool(**request.model_dump(exclude_unset=True)), actor=actor
|
||||
)
|
||||
return tool
|
||||
except UniqueConstraintViolationError as e:
|
||||
# Log the error and raise a conflict exception
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except LettaToolCreateError as e:
|
||||
# HTTP 400 == Bad Request
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors and raise an internal server error
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
tool = await server.tool_manager.create_or_update_tool_async(pydantic_tool=Tool(**request.model_dump(exclude_unset=True)), actor=actor)
|
||||
return tool
|
||||
|
||||
|
||||
@router.patch("/{tool_id}", response_model=Tool, operation_id="modify_tool")
|
||||
@@ -338,19 +316,9 @@ async def modify_tool(
|
||||
"""
|
||||
Update an existing tool
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
tool = await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor)
|
||||
return tool
|
||||
except LettaToolNameConflictError as e:
|
||||
# HTTP 409 == Conflict
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except LettaToolCreateError as e:
|
||||
# HTTP 400 == Bad Request
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors and raise an internal server error
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
tool = await server.tool_manager.update_tool_by_id_async(tool_id=tool_id, tool_update=request, actor=actor)
|
||||
return tool
|
||||
|
||||
|
||||
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
|
||||
@@ -376,25 +344,17 @@ async def run_tool_from_source(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
try:
|
||||
return await server.run_tool_from_source(
|
||||
tool_source=request.source_code,
|
||||
tool_source_type=request.source_type,
|
||||
tool_args=request.args,
|
||||
tool_env_vars=request.env_vars,
|
||||
tool_name=request.name,
|
||||
tool_args_json_schema=request.args_json_schema,
|
||||
tool_json_schema=request.json_schema,
|
||||
pip_requirements=request.pip_requirements,
|
||||
actor=actor,
|
||||
)
|
||||
except LettaToolCreateError as e:
|
||||
# HTTP 400 == Bad Request
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors and raise an internal server error
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
return await server.run_tool_from_source(
|
||||
tool_source=request.source_code,
|
||||
tool_source_type=request.source_type,
|
||||
tool_args=request.args,
|
||||
tool_env_vars=request.env_vars,
|
||||
tool_name=request.name,
|
||||
tool_args_json_schema=request.args_json_schema,
|
||||
tool_json_schema=request.json_schema,
|
||||
pip_requirements=request.pip_requirements,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
# Specific routes for Composio
|
||||
|
||||
@@ -1181,6 +1181,8 @@ class SyncServer(object):
|
||||
) -> ToolReturnMessage:
|
||||
"""Run a tool from source code"""
|
||||
|
||||
from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update
|
||||
|
||||
if tool_source_type not in (None, ToolSourceType.python, ToolSourceType.typescript):
|
||||
raise ValueError("Tool source type is not supported at this time. Found {tool_source_type}")
|
||||
|
||||
@@ -1203,6 +1205,11 @@ class SyncServer(object):
|
||||
source_type=tool_source_type,
|
||||
)
|
||||
|
||||
# try to get the schema
|
||||
if not tool.name:
|
||||
if not tool.json_schema:
|
||||
tool.json_schema = generate_schema_for_tool_creation(tool)
|
||||
tool.name = tool.json_schema.get("name")
|
||||
assert tool.name is not None, "Failed to create tool object"
|
||||
|
||||
# TODO eventually allow using agent state in tools
|
||||
|
||||
@@ -32,6 +32,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.helpers.agent_manager_helper import calculate_multi_agent_tools
|
||||
from letta.services.mcp.types import SSEServerConfig, StdioServerConfig
|
||||
from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
@@ -47,8 +48,29 @@ class ToolManager:
|
||||
self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False
|
||||
) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool_id:
|
||||
if pydantic_tool.tool_type == ToolType.CUSTOM and not pydantic_tool.json_schema:
|
||||
generated_schema = generate_schema_for_tool_creation(pydantic_tool)
|
||||
if generated_schema:
|
||||
pydantic_tool.json_schema = generated_schema
|
||||
else:
|
||||
raise ValueError("Failed to generate schema for tool", pydantic_tool.source_code)
|
||||
|
||||
print("SCHEMA", pydantic_tool.json_schema)
|
||||
|
||||
# make sure the name matches the json_schema
|
||||
if not pydantic_tool.name:
|
||||
pydantic_tool.name = pydantic_tool.json_schema.get("name")
|
||||
else:
|
||||
if pydantic_tool.name != pydantic_tool.json_schema.get("name"):
|
||||
raise LettaToolNameSchemaMismatchError(
|
||||
tool_name=pydantic_tool.name,
|
||||
json_schema_name=pydantic_tool.json_schema.get("name"),
|
||||
source_code=pydantic_tool.source_code,
|
||||
)
|
||||
|
||||
# check if the tool name already exists
|
||||
current_tool = await self.get_tool_by_name_async(tool_name=pydantic_tool.name, actor=actor)
|
||||
if current_tool:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data["organization_id"] = actor.organization_id
|
||||
@@ -61,17 +83,17 @@ class ToolManager:
|
||||
if "tool_type" in update_data:
|
||||
updated_tool_type = update_data.get("tool_type")
|
||||
tool = await self.update_tool_by_id_async(
|
||||
tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
|
||||
current_tool.id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type
|
||||
)
|
||||
else:
|
||||
printd(
|
||||
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
||||
)
|
||||
tool = await self.get_tool_by_id_async(tool_id, actor=actor)
|
||||
else:
|
||||
tool = await self.create_tool_async(pydantic_tool, actor=actor)
|
||||
tool = await self.get_tool_by_id_async(current_tool.id, actor=actor)
|
||||
return tool
|
||||
|
||||
return await self.create_tool_async(pydantic_tool, actor=actor)
|
||||
|
||||
return tool
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_server(
|
||||
@@ -115,9 +137,11 @@ class ToolManager:
|
||||
@trace_method
|
||||
async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Generate schema only if not provided (only for custom tools)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None:
|
||||
if pydantic_tool.description is None and pydantic_tool.json_schema:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
tool_data = pydantic_tool.model_dump(to_orm=True)
|
||||
# Set the organization id at the ORM layer
|
||||
@@ -166,6 +190,11 @@ class ToolManager:
|
||||
if not pydantic_tools:
|
||||
return []
|
||||
|
||||
# get schemas if not provided
|
||||
for tool in pydantic_tools:
|
||||
if tool.json_schema is None:
|
||||
tool.json_schema = generate_schema_for_tool_creation(tool)
|
||||
|
||||
# auto-generate descriptions if not provided
|
||||
for tool in pydantic_tools:
|
||||
if tool.description is None:
|
||||
@@ -494,34 +523,56 @@ class ToolManager:
|
||||
bypass_name_check: bool = False,
|
||||
) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
# First, check if source code update would cause a name conflict
|
||||
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
|
||||
new_name = None
|
||||
new_schema = None
|
||||
|
||||
# Fetch current tool early to allow conditional logic based on tool type
|
||||
current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor)
|
||||
|
||||
# Do NOT derive schema from Python source. Trust provided JSON schema.
|
||||
# Prefer provided json_schema; fall back to current
|
||||
if "json_schema" in update_data:
|
||||
new_schema = update_data["json_schema"].copy()
|
||||
# Handle schema updates for custom tools
|
||||
new_schema = None
|
||||
if current_tool.tool_type == ToolType.CUSTOM:
|
||||
if tool_update.json_schema is not None:
|
||||
new_schema = tool_update.json_schema
|
||||
elif tool_update.args_json_schema is not None:
|
||||
# Generate full schema from args_json_schema
|
||||
generated_schema = generate_schema_for_tool_update(
|
||||
current_tool=current_tool,
|
||||
json_schema=None,
|
||||
args_json_schema=tool_update.args_json_schema,
|
||||
source_code=tool_update.source_code,
|
||||
source_type=tool_update.source_type,
|
||||
)
|
||||
if generated_schema:
|
||||
tool_update.json_schema = generated_schema
|
||||
new_schema = generated_schema
|
||||
|
||||
# Now model_dump with the potentially updated schema
|
||||
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
|
||||
|
||||
# Determine the final schema and name
|
||||
if new_schema:
|
||||
new_name = new_schema.get("name", current_tool.name)
|
||||
elif "json_schema" in update_data:
|
||||
new_schema = update_data["json_schema"]
|
||||
new_name = new_schema.get("name", current_tool.name)
|
||||
else:
|
||||
# Keep existing schema
|
||||
new_schema = current_tool.json_schema
|
||||
new_name = current_tool.name
|
||||
|
||||
# original tool may no have a JSON schema at all for legacy reasons
|
||||
# in this case, fallback to dangerous schema generation
|
||||
if new_schema is None:
|
||||
# Get source_type from update_data if present, otherwise use current tool's source_type
|
||||
source_type = update_data.get("source_type", current_tool.source_type)
|
||||
if source_type == "typescript":
|
||||
from letta.functions.typescript_parser import derive_typescript_json_schema
|
||||
|
||||
new_schema = derive_typescript_json_schema(source_code=update_data["source_code"])
|
||||
else:
|
||||
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
# Handle explicit name updates
|
||||
if "name" in update_data and update_data["name"] != current_tool.name:
|
||||
# Name is being explicitly changed
|
||||
new_name = update_data["name"]
|
||||
# Update the json_schema name to match if there's a schema
|
||||
if new_schema:
|
||||
new_schema = new_schema.copy()
|
||||
new_schema["name"] = new_name
|
||||
update_data["json_schema"] = new_schema
|
||||
elif new_schema and new_name != current_tool.name:
|
||||
# Schema provides a different name but name wasn't explicitly changed
|
||||
update_data["name"] = new_name
|
||||
#raise ValueError(
|
||||
# f"JSON schema name '{new_name}' conflicts with current tool name '{current_tool.name}'. Update the name field explicitly if you want to rename the tool."
|
||||
#)
|
||||
|
||||
# If name changes, enforce uniqueness
|
||||
if new_name != current_tool.name:
|
||||
@@ -532,7 +583,9 @@ class ToolManager:
|
||||
# NOTE: EXTREMELEY HACKY, we need to stop making assumptions about the source_code
|
||||
if "source_code" in update_data and f"def {new_name}" not in update_data.get("source_code", ""):
|
||||
raise LettaToolNameSchemaMismatchError(
|
||||
tool_name=new_name, json_schema_name=new_schema.get("name"), source_code=update_data.get("source_code")
|
||||
tool_name=new_name,
|
||||
json_schema_name=new_schema.get("name") if new_schema else None,
|
||||
source_code=update_data.get("source_code"),
|
||||
)
|
||||
|
||||
# Now perform the update within the session
|
||||
|
||||
123
letta/services/tool_schema_generator.py
Normal file
123
letta/services/tool_schema_generator.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Schema generation utilities for tool creation and updates."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from letta.functions.ast_parsers import get_function_name_and_docstring
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ToolSourceType, ToolType
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def generate_schema_for_tool_creation(
|
||||
tool: PydanticTool,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Generate JSON schema for tool creation based on the provided parameters.
|
||||
|
||||
Args:
|
||||
tool: The tool being created
|
||||
|
||||
Returns:
|
||||
Generated JSON schema or None if not applicable
|
||||
"""
|
||||
# Only generate schema for custom tools
|
||||
if tool.tool_type != ToolType.CUSTOM:
|
||||
return None
|
||||
|
||||
# If json_schema is already provided, use it
|
||||
if tool.json_schema:
|
||||
return tool.json_schema
|
||||
|
||||
# Must have source code for custom tools
|
||||
if not tool.source_code:
|
||||
logger.error("Custom tool is missing source_code field")
|
||||
raise ValueError("Custom tool is missing source_code field.")
|
||||
|
||||
# TypeScript tools
|
||||
if tool.source_type == ToolSourceType.typescript:
|
||||
try:
|
||||
from letta.functions.typescript_parser import derive_typescript_json_schema
|
||||
|
||||
return derive_typescript_json_schema(source_code=tool.source_code)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to derive TypeScript json schema: {e}")
|
||||
raise ValueError(f"Failed to derive TypeScript json schema: {e}")
|
||||
|
||||
# Python tools (default if not specified for backwards compatibility)
|
||||
elif tool.source_type == ToolSourceType.python or tool.source_type is None:
|
||||
# If args_json_schema is provided, use it to generate full schema
|
||||
if tool.args_json_schema:
|
||||
name, description = get_function_name_and_docstring(tool.source_code, tool.name)
|
||||
args_schema = generate_model_from_args_json_schema(tool.args_json_schema)
|
||||
return generate_schema_from_args_schema_v2(
|
||||
args_schema=args_schema,
|
||||
name=name,
|
||||
description=description,
|
||||
append_heartbeat=False,
|
||||
)
|
||||
# Otherwise, attempt to parse from docstring with best effort
|
||||
else:
|
||||
try:
|
||||
return derive_openai_json_schema(source_code=tool.source_code)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to derive json schema: {e}")
|
||||
raise ValueError(f"Failed to derive json schema: {e}")
|
||||
else:
|
||||
# TODO: convert to explicit error
|
||||
raise ValueError(f"Unknown tool source type: {tool.source_type}")
|
||||
|
||||
|
||||
def generate_schema_for_tool_update(
|
||||
current_tool: PydanticTool,
|
||||
json_schema: Optional[dict] = None,
|
||||
args_json_schema: Optional[dict] = None,
|
||||
source_code: Optional[str] = None,
|
||||
source_type: Optional[ToolSourceType] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Generate JSON schema for tool update based on the provided parameters.
|
||||
|
||||
Args:
|
||||
current_tool: The current tool being updated
|
||||
json_schema: Directly provided JSON schema (takes precedence)
|
||||
args_json_schema: Schema for just the arguments
|
||||
source_code: Updated source code (only used if explicitly updating source)
|
||||
source_type: Source type for the tool
|
||||
|
||||
Returns:
|
||||
Updated JSON schema or None if no update needed
|
||||
"""
|
||||
# Only handle custom tools
|
||||
if current_tool.tool_type != ToolType.CUSTOM:
|
||||
return None
|
||||
|
||||
# If json_schema is directly provided, use it
|
||||
if json_schema is not None:
|
||||
# If args_json_schema is also provided, that's an error
|
||||
if args_json_schema is not None:
|
||||
raise ValueError("Cannot provide both json_schema and args_json_schema in update")
|
||||
return json_schema
|
||||
|
||||
# If args_json_schema is provided, generate full schema from it
|
||||
if args_json_schema is not None:
|
||||
# Use updated source_code if provided, otherwise use current
|
||||
code_to_parse = source_code if source_code is not None else current_tool.source_code
|
||||
if not code_to_parse:
|
||||
raise ValueError("Source code required when updating with args_json_schema")
|
||||
|
||||
name, description = get_function_name_and_docstring(code_to_parse, current_tool.name)
|
||||
args_schema = generate_model_from_args_json_schema(args_json_schema)
|
||||
return generate_schema_from_args_schema_v2(
|
||||
args_schema=args_schema,
|
||||
name=name,
|
||||
description=description,
|
||||
append_heartbeat=False,
|
||||
)
|
||||
|
||||
# Otherwise, no schema updates (don't parse docstring)
|
||||
return None
|
||||
@@ -96,6 +96,7 @@ from letta.server.server import SyncServer
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools, validate_agent_exists_async
|
||||
from letta.services.step_manager import FeedbackType
|
||||
from letta.services.tool_schema_generator import generate_schema_for_tool_creation
|
||||
from letta.settings import settings, tool_settings
|
||||
from letta.utils import calculate_file_defaults_based_on_context_window
|
||||
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
|
||||
@@ -689,23 +690,25 @@ async def test_list_tools_with_tool_types(server: SyncServer, default_user):
|
||||
|
||||
# create custom tools
|
||||
custom_tool1 = PydanticTool(
|
||||
name="calculator",
|
||||
name="calculator_tool",
|
||||
description="Math tool",
|
||||
source_code=parse_source_code(calculator_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
custom_tool1.json_schema = derive_openai_json_schema(source_code=custom_tool1.source_code, name=custom_tool1.name)
|
||||
# Use generate_schema_for_tool_creation to generate schema
|
||||
custom_tool1.json_schema = generate_schema_for_tool_creation(custom_tool1)
|
||||
custom_tool1 = await server.tool_manager.create_or_update_tool_async(custom_tool1, actor=default_user)
|
||||
|
||||
custom_tool2 = PydanticTool(
|
||||
name="weather",
|
||||
# name="weather_tool",
|
||||
description="Weather tool",
|
||||
source_code=parse_source_code(weather_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
custom_tool2.json_schema = derive_openai_json_schema(source_code=custom_tool2.source_code, name=custom_tool2.name)
|
||||
# Use generate_schema_for_tool_creation to generate schema
|
||||
custom_tool2.json_schema = generate_schema_for_tool_creation(custom_tool2)
|
||||
custom_tool2 = await server.tool_manager.create_or_update_tool_async(custom_tool2, actor=default_user)
|
||||
|
||||
# test filtering by single tool type
|
||||
@@ -744,13 +747,13 @@ async def test_list_tools_with_exclude_tool_types(server: SyncServer, default_us
|
||||
return msg
|
||||
|
||||
special = PydanticTool(
|
||||
name="special",
|
||||
name="special_tool",
|
||||
description="Special tool",
|
||||
source_code=parse_source_code(special_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
special.json_schema = derive_openai_json_schema(source_code=special.source_code, name=special.name)
|
||||
special.json_schema = generate_schema_for_tool_creation(special)
|
||||
special = await server.tool_manager.create_or_update_tool_async(special, actor=default_user)
|
||||
|
||||
# test excluding EXTERNAL_MCP (should get all tools since none are MCP)
|
||||
@@ -796,15 +799,15 @@ async def test_list_tools_with_names(server: SyncServer, default_user):
|
||||
return "gamma"
|
||||
|
||||
alpha = PydanticTool(name="alpha_tool", description="Alpha", source_code=parse_source_code(alpha_tool), source_type="python")
|
||||
alpha.json_schema = derive_openai_json_schema(source_code=alpha.source_code, name=alpha.name)
|
||||
alpha.json_schema = generate_schema_for_tool_creation(alpha)
|
||||
alpha = await server.tool_manager.create_or_update_tool_async(alpha, actor=default_user)
|
||||
|
||||
beta = PydanticTool(name="beta_tool", description="Beta", source_code=parse_source_code(beta_tool), source_type="python")
|
||||
beta.json_schema = derive_openai_json_schema(source_code=beta.source_code, name=beta.name)
|
||||
beta.json_schema = generate_schema_for_tool_creation(beta)
|
||||
beta = await server.tool_manager.create_or_update_tool_async(beta, actor=default_user)
|
||||
|
||||
gamma = PydanticTool(name="gamma_tool", description="Gamma", source_code=parse_source_code(gamma_tool), source_type="python")
|
||||
gamma.json_schema = derive_openai_json_schema(source_code=gamma.source_code, name=gamma.name)
|
||||
gamma.json_schema = generate_schema_for_tool_creation(gamma)
|
||||
gamma = await server.tool_manager.create_or_update_tool_async(gamma, actor=default_user)
|
||||
|
||||
# test filtering by single name
|
||||
@@ -852,15 +855,15 @@ async def test_list_tools_with_tool_ids(server: SyncServer, default_user):
|
||||
return "3"
|
||||
|
||||
t1 = PydanticTool(name="tool1", description="First", source_code=parse_source_code(tool1), source_type="python")
|
||||
t1.json_schema = derive_openai_json_schema(source_code=t1.source_code, name=t1.name)
|
||||
t1.json_schema = generate_schema_for_tool_creation(t1)
|
||||
t1 = await server.tool_manager.create_or_update_tool_async(t1, actor=default_user)
|
||||
|
||||
t2 = PydanticTool(name="tool2", description="Second", source_code=parse_source_code(tool2), source_type="python")
|
||||
t2.json_schema = derive_openai_json_schema(source_code=t2.source_code, name=t2.name)
|
||||
t2.json_schema = generate_schema_for_tool_creation(t2)
|
||||
t2 = await server.tool_manager.create_or_update_tool_async(t2, actor=default_user)
|
||||
|
||||
t3 = PydanticTool(name="tool3", description="Third", source_code=parse_source_code(tool3), source_type="python")
|
||||
t3.json_schema = derive_openai_json_schema(source_code=t3.source_code, name=t3.name)
|
||||
t3.json_schema = generate_schema_for_tool_creation(t3)
|
||||
t3 = await server.tool_manager.create_or_update_tool_async(t3, actor=default_user)
|
||||
|
||||
# test filtering by single id
|
||||
@@ -910,19 +913,19 @@ async def test_list_tools_with_search(server: SyncServer, default_user):
|
||||
calc_add = PydanticTool(
|
||||
name="calculator_add", description="Add numbers", source_code=parse_source_code(calculator_add), source_type="python"
|
||||
)
|
||||
calc_add.json_schema = derive_openai_json_schema(source_code=calc_add.source_code, name=calc_add.name)
|
||||
calc_add.json_schema = generate_schema_for_tool_creation(calc_add)
|
||||
calc_add = await server.tool_manager.create_or_update_tool_async(calc_add, actor=default_user)
|
||||
|
||||
calc_sub = PydanticTool(
|
||||
name="calculator_subtract", description="Subtract numbers", source_code=parse_source_code(calculator_subtract), source_type="python"
|
||||
)
|
||||
calc_sub.json_schema = derive_openai_json_schema(source_code=calc_sub.source_code, name=calc_sub.name)
|
||||
calc_sub.json_schema = generate_schema_for_tool_creation(calc_sub)
|
||||
calc_sub = await server.tool_manager.create_or_update_tool_async(calc_sub, actor=default_user)
|
||||
|
||||
weather = PydanticTool(
|
||||
name="weather_forecast", description="Weather", source_code=parse_source_code(weather_forecast), source_type="python"
|
||||
)
|
||||
weather.json_schema = derive_openai_json_schema(source_code=weather.source_code, name=weather.name)
|
||||
weather.json_schema = generate_schema_for_tool_creation(weather)
|
||||
weather = await server.tool_manager.create_or_update_tool_async(weather, actor=default_user)
|
||||
|
||||
# test searching for "calculator" (should find both calculator tools)
|
||||
@@ -965,7 +968,7 @@ async def test_list_tools_return_only_letta_tools(server: SyncServer, default_us
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
custom.json_schema = derive_openai_json_schema(source_code=custom.source_code, name=custom.name)
|
||||
custom.json_schema = generate_schema_for_tool_creation(custom)
|
||||
custom = await server.tool_manager.create_or_update_tool_async(custom, actor=default_user)
|
||||
|
||||
# test without filter (should get custom tool + all letta tools)
|
||||
@@ -1013,48 +1016,45 @@ async def test_list_tools_combined_filters(server: SyncServer, default_user):
|
||||
return "weather"
|
||||
|
||||
calc1 = PydanticTool(
|
||||
name="calculator_add", description="Add", source_code=parse_source_code(calc_add), source_type="python", tool_type=ToolType.CUSTOM
|
||||
name="calc_add", description="Add", source_code=parse_source_code(calc_add), source_type="python", tool_type=ToolType.CUSTOM
|
||||
)
|
||||
calc1.json_schema = derive_openai_json_schema(source_code=calc1.source_code, name=calc1.name)
|
||||
calc1.json_schema = generate_schema_for_tool_creation(calc1)
|
||||
calc1 = await server.tool_manager.create_or_update_tool_async(calc1, actor=default_user)
|
||||
|
||||
calc2 = PydanticTool(
|
||||
name="calculator_multiply",
|
||||
description="Multiply",
|
||||
source_code=parse_source_code(calc_multiply),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
calc2.json_schema = derive_openai_json_schema(source_code=calc2.source_code, name=calc2.name)
|
||||
calc2.json_schema = generate_schema_for_tool_creation(calc2)
|
||||
calc2 = await server.tool_manager.create_or_update_tool_async(calc2, actor=default_user)
|
||||
|
||||
weather = PydanticTool(
|
||||
name="weather_current",
|
||||
name="weather_tool",
|
||||
description="Weather",
|
||||
source_code=parse_source_code(weather_tool),
|
||||
source_type="python",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
)
|
||||
weather.json_schema = derive_openai_json_schema(source_code=weather.source_code, name=weather.name)
|
||||
weather.json_schema = generate_schema_for_tool_creation(weather)
|
||||
weather = await server.tool_manager.create_or_update_tool_async(weather, actor=default_user)
|
||||
|
||||
# combine search with tool_types
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, search="calculator", tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False
|
||||
actor=default_user, search="calc", tool_types=[ToolType.CUSTOM.value], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 2
|
||||
assert all("calculator" in t.name and t.tool_type == ToolType.CUSTOM for t in tools)
|
||||
assert all("calc" in t.name and t.tool_type == ToolType.CUSTOM for t in tools)
|
||||
|
||||
# combine names with tool_ids
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, names=["calculator_add"], tool_ids=[calc1.id], upsert_base_tools=False
|
||||
)
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, names=["calc_add"], tool_ids=[calc1.id], upsert_base_tools=False)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].id == calc1.id
|
||||
|
||||
# combine search with exclude_tool_types
|
||||
tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, search="calculator", exclude_tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False
|
||||
actor=default_user, search="cal", exclude_tool_types=[ToolType.EXTERNAL_MCP.value], upsert_base_tools=False
|
||||
)
|
||||
assert len(tools) == 2
|
||||
|
||||
@@ -1091,13 +1091,13 @@ async def test_count_tools_async(server: SyncServer, default_user):
|
||||
ta = PydanticTool(
|
||||
name="tool_a", description="A", source_code=parse_source_code(tool_a), source_type="python", tool_type=ToolType.CUSTOM
|
||||
)
|
||||
ta.json_schema = derive_openai_json_schema(source_code=ta.source_code, name=ta.name)
|
||||
ta.json_schema = generate_schema_for_tool_creation(ta)
|
||||
ta = await server.tool_manager.create_or_update_tool_async(ta, actor=default_user)
|
||||
|
||||
tb = PydanticTool(
|
||||
name="tool_b", description="B", source_code=parse_source_code(tool_b), source_type="python", tool_type=ToolType.CUSTOM
|
||||
)
|
||||
tb.json_schema = derive_openai_json_schema(source_code=tb.source_code, name=tb.name)
|
||||
tb.json_schema = generate_schema_for_tool_creation(tb)
|
||||
tb = await server.tool_manager.create_or_update_tool_async(tb, actor=default_user)
|
||||
|
||||
# upsert base tools to ensure we have Letta tools for counting
|
||||
@@ -1167,8 +1167,8 @@ async def test_update_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
assert updated_tool.tool_type == ToolType.EXTERNAL_MCP
|
||||
|
||||
|
||||
#@pytest.mark.asyncio
|
||||
#async def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, print_tool, default_user):
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, print_tool, default_user):
|
||||
# def counter_tool(counter: int):
|
||||
# """
|
||||
# Args:
|
||||
@@ -1205,8 +1205,8 @@ async def test_update_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
# assert updated_tool.tool_type == ToolType.CUSTOM
|
||||
|
||||
|
||||
#@pytest.mark.asyncio
|
||||
#async def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print_tool, default_user):
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print_tool, default_user):
|
||||
# def counter_tool(counter: int):
|
||||
# """
|
||||
# Args:
|
||||
@@ -1701,7 +1701,7 @@ async def test_create_tool_with_pip_requirements(server: SyncServer, default_use
|
||||
tool = PydanticTool(
|
||||
description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs
|
||||
)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
derived_json_schema = generate_schema_for_tool_creation(tool)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
@@ -1767,7 +1767,7 @@ async def test_update_tool_clear_pip_requirements(server: SyncServer, default_us
|
||||
tool = PydanticTool(
|
||||
description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs
|
||||
)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
derived_json_schema = generate_schema_for_tool_creation(tool)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
@@ -1817,7 +1817,7 @@ async def test_pip_requirements_roundtrip(server: SyncServer, default_user, defa
|
||||
tool = PydanticTool(
|
||||
description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs
|
||||
)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
derived_json_schema = generate_schema_for_tool_creation(tool)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
@@ -1859,3 +1859,360 @@ async def test_update_default_requires_approval(server: SyncServer, bash_tool, d
|
||||
|
||||
# Assertions
|
||||
assert updated_tool.default_requires_approval == True
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# ToolManager Schema tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
async def test_create_tool_with_json_schema(server: SyncServer, default_user, default_organization):
|
||||
"""Test that json_schema is used when provided at creation."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
source_code = """
|
||||
def test_function(arg1: str) -> str:
|
||||
return arg1
|
||||
"""
|
||||
|
||||
json_schema = {
|
||||
"name": "test_function",
|
||||
"description": "A test function",
|
||||
"parameters": {"type": "object", "properties": {"arg1": {"type": "string"}}, "required": ["arg1"]},
|
||||
}
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_function",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code=source_code,
|
||||
json_schema=json_schema,
|
||||
)
|
||||
|
||||
created_tool = await tool_manager.create_tool_async(tool, default_user)
|
||||
|
||||
assert created_tool.json_schema == json_schema
|
||||
assert created_tool.name == "test_function"
|
||||
assert created_tool.description == "A test function"
|
||||
|
||||
|
||||
async def test_create_tool_with_args_json_schema(server: SyncServer, default_user, default_organization):
|
||||
"""Test that schema is generated from args_json_schema at creation."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
source_code = """
|
||||
def test_function(arg1: str, arg2: int) -> str:
|
||||
'''This is a test function'''
|
||||
return f"{arg1} {arg2}"
|
||||
"""
|
||||
|
||||
args_json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arg1": {"type": "string", "description": "First argument"},
|
||||
"arg2": {"type": "integer", "description": "Second argument"},
|
||||
},
|
||||
"required": ["arg1", "arg2"],
|
||||
}
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_function",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code=source_code,
|
||||
args_json_schema=args_json_schema,
|
||||
)
|
||||
|
||||
created_tool = await tool_manager.create_or_update_tool_async(tool, default_user)
|
||||
|
||||
assert created_tool.json_schema is not None
|
||||
assert created_tool.json_schema["name"] == "test_function"
|
||||
assert created_tool.json_schema["description"] == "This is a test function"
|
||||
assert "parameters" in created_tool.json_schema
|
||||
assert created_tool.json_schema["parameters"]["properties"]["arg1"]["type"] == "string"
|
||||
assert created_tool.json_schema["parameters"]["properties"]["arg2"]["type"] == "integer"
|
||||
|
||||
|
||||
async def test_create_tool_with_docstring_no_schema(server: SyncServer, default_user, default_organization):
|
||||
"""Test that schema is generated from docstring when no schema provided."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
source_code = """
|
||||
def test_function(arg1: str, arg2: int = 5) -> str:
|
||||
'''
|
||||
This is a test function
|
||||
|
||||
Args:
|
||||
arg1: First argument
|
||||
arg2: Second argument
|
||||
|
||||
Returns:
|
||||
A string result
|
||||
'''
|
||||
return f"{arg1} {arg2}"
|
||||
"""
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_function",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code=source_code,
|
||||
)
|
||||
|
||||
created_tool = await tool_manager.create_or_update_tool_async(tool, default_user)
|
||||
|
||||
assert created_tool.json_schema is not None
|
||||
assert created_tool.json_schema["name"] == "test_function"
|
||||
assert "This is a test function" in created_tool.json_schema["description"]
|
||||
assert "parameters" in created_tool.json_schema
|
||||
|
||||
|
||||
async def test_create_tool_with_docstring_and_args_schema(server: SyncServer, default_user, default_organization):
|
||||
"""Test that args_json_schema takes precedence over docstring."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
source_code = """
|
||||
def test_function(old_arg: str) -> str:
|
||||
'''Old docstring that should be overridden'''
|
||||
return old_arg
|
||||
"""
|
||||
|
||||
args_json_schema = {
|
||||
"type": "object",
|
||||
"properties": {"new_arg": {"type": "string", "description": "New argument from schema"}},
|
||||
"required": ["new_arg"],
|
||||
}
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_function",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code=source_code,
|
||||
args_json_schema=args_json_schema,
|
||||
)
|
||||
|
||||
created_tool = await tool_manager.create_or_update_tool_async(tool, default_user)
|
||||
|
||||
assert created_tool.json_schema is not None
|
||||
assert created_tool.json_schema["name"] == "test_function"
|
||||
# The description should come from the docstring
|
||||
assert created_tool.json_schema["description"] == "Old docstring that should be overridden"
|
||||
# But parameters should come from args_json_schema
|
||||
assert "new_arg" in created_tool.json_schema["parameters"]["properties"]
|
||||
assert "old_arg" not in created_tool.json_schema["parameters"]["properties"]
|
||||
|
||||
|
||||
async def test_error_no_docstring_or_schema(server: SyncServer, default_user, default_organization):
|
||||
"""Test error when no docstring or schema provided (minimal function)."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
# Minimal function with no docstring - should still derive basic schema
|
||||
source_code = """
|
||||
def test_function():
|
||||
pass
|
||||
"""
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_function",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code=source_code,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
created_tool = await tool_manager.create_or_update_tool_async(tool, default_user)
|
||||
|
||||
|
||||
async def test_error_on_create_tool_with_name_conflict(server: SyncServer, default_user, default_organization):
|
||||
"""Test error when json_schema name conflicts with function name."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
source_code = """
|
||||
def test_function(arg1: str) -> str:
|
||||
return arg1
|
||||
"""
|
||||
|
||||
# JSON schema with conflicting name
|
||||
json_schema = {
|
||||
"name": "different_name",
|
||||
"description": "A test function",
|
||||
"parameters": {"type": "object", "properties": {"arg1": {"type": "string"}}, "required": ["arg1"]},
|
||||
}
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_function",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code=source_code,
|
||||
json_schema=json_schema,
|
||||
)
|
||||
|
||||
# This should succeed at creation - the tool name takes precedence
|
||||
created_tool = await tool_manager.create_tool_async(tool, default_user)
|
||||
assert created_tool.name == "test_function"
|
||||
|
||||
|
||||
async def test_update_tool_with_json_schema(server: SyncServer, default_user, default_organization):
|
||||
"""Test update with a new json_schema."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
# Create initial tool
|
||||
source_code = """
|
||||
def test_function() -> str:
|
||||
return "hello"
|
||||
"""
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_update_json_schema",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code=source_code,
|
||||
json_schema={"name": "test_update_json_schema", "description": "Original"},
|
||||
)
|
||||
|
||||
created_tool = await tool_manager.create_tool_async(tool, default_user)
|
||||
|
||||
# Update with new json_schema
|
||||
new_schema = {
|
||||
"name": "test_update_json_schema",
|
||||
"description": "Updated description",
|
||||
"parameters": {"type": "object", "properties": {"new_arg": {"type": "string"}}, "required": ["new_arg"]},
|
||||
}
|
||||
|
||||
update = ToolUpdate(json_schema=new_schema)
|
||||
updated_tool = await tool_manager.update_tool_by_id_async(created_tool.id, update, default_user)
|
||||
|
||||
assert updated_tool.json_schema == new_schema
|
||||
assert updated_tool.json_schema["description"] == "Updated description"
|
||||
|
||||
|
||||
async def test_update_tool_with_args_json_schema(server: SyncServer, default_user, default_organization):
|
||||
"""Test update with args_json_schema."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
# Create initial tool
|
||||
source_code = """
|
||||
def test_function() -> str:
|
||||
'''Original function'''
|
||||
return "hello"
|
||||
"""
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_function",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code=source_code,
|
||||
)
|
||||
|
||||
created_tool = await tool_manager.create_or_update_tool_async(tool, default_user)
|
||||
|
||||
# Update with args_json_schema
|
||||
new_source_code = """
|
||||
def test_function(new_arg: str) -> str:
|
||||
'''Updated function'''
|
||||
return new_arg
|
||||
"""
|
||||
|
||||
args_json_schema = {
|
||||
"type": "object",
|
||||
"properties": {"new_arg": {"type": "string", "description": "New argument"}},
|
||||
"required": ["new_arg"],
|
||||
}
|
||||
|
||||
update = ToolUpdate(source_code=new_source_code, args_json_schema=args_json_schema)
|
||||
updated_tool = await tool_manager.update_tool_by_id_async(created_tool.id, update, default_user)
|
||||
|
||||
assert updated_tool.json_schema is not None
|
||||
assert updated_tool.json_schema["description"] == "Updated function"
|
||||
assert "new_arg" in updated_tool.json_schema["parameters"]["properties"]
|
||||
|
||||
|
||||
async def test_update_tool_with_no_schema(server: SyncServer, default_user, default_organization):
|
||||
"""Test update with no schema changes."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
# Create initial tool
|
||||
original_schema = {
|
||||
"name": "test_no_schema_update",
|
||||
"description": "Original description",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
}
|
||||
|
||||
tool = PydanticTool(
|
||||
name="test_no_schema_update",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code="def test_function(): pass",
|
||||
json_schema=original_schema,
|
||||
)
|
||||
|
||||
created_tool = await tool_manager.create_tool_async(tool, default_user)
|
||||
|
||||
# Update with only description (no schema change)
|
||||
update = ToolUpdate(description="New description")
|
||||
updated_tool = await tool_manager.update_tool_by_id_async(created_tool.id, update, default_user)
|
||||
|
||||
# Schema should remain unchanged
|
||||
assert updated_tool.json_schema == original_schema
|
||||
assert updated_tool.description == "New description"
|
||||
|
||||
|
||||
async def test_update_tool_name(server: SyncServer, default_user, default_organization):
|
||||
"""Test various name update scenarios."""
|
||||
tool_manager = server.tool_manager
|
||||
|
||||
# Create initial tool
|
||||
original_schema = {"name": "original_name", "description": "Test", "parameters": {"type": "object", "properties": {}}}
|
||||
|
||||
tool = PydanticTool(
|
||||
name="original_name",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code="def original_name(): pass",
|
||||
json_schema=original_schema,
|
||||
)
|
||||
|
||||
created_tool = await tool_manager.create_or_update_tool_async(tool, default_user)
|
||||
assert created_tool.name == "original_name"
|
||||
assert created_tool.json_schema["name"] == "original_name"
|
||||
|
||||
matching_schema = {"name": "matched_name", "description": "Test", "parameters": {"type": "object", "properties": {}}}
|
||||
update = ToolUpdate(json_schema=matching_schema)
|
||||
updated_tool3 = await tool_manager.update_tool_by_id_async(created_tool.id, update, default_user)
|
||||
assert updated_tool3.name == "matched_name"
|
||||
assert updated_tool3.json_schema["name"] == "matched_name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_with_corrupted_tool(server: SyncServer, default_user, print_tool):
|
||||
"""Test that list_tools still works even if there's a corrupted tool (missing json_schema) in the database."""
|
||||
|
||||
# First, verify we have a normal tool
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, upsert_base_tools=False)
|
||||
initial_tool_count = len(tools)
|
||||
assert any(t.id == print_tool.id for t in tools)
|
||||
|
||||
# Now insert a corrupted tool directly into the database (bypassing normal validation)
|
||||
# This simulates a tool that somehow got corrupted in the database
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Create a tool with no json_schema (corrupted state)
|
||||
corrupted_tool = ToolModel(
|
||||
id=f"tool-corrupted-{uuid.uuid4()}",
|
||||
name="corrupted_tool",
|
||||
description="This tool has no json_schema",
|
||||
tool_type=ToolType.CUSTOM,
|
||||
source_code="def corrupted_tool(): pass",
|
||||
json_schema=None, # Explicitly set to None to simulate corruption
|
||||
organization_id=default_user.organization_id,
|
||||
created_by_id=default_user.id,
|
||||
last_updated_by_id=default_user.id,
|
||||
tags=["corrupted"],
|
||||
)
|
||||
|
||||
session.add(corrupted_tool)
|
||||
await session.commit()
|
||||
corrupted_tool_id = corrupted_tool.id
|
||||
|
||||
# Now try to list tools - it should still work and not include the corrupted tool
|
||||
# The corrupted tool should be automatically excluded from results
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, upsert_base_tools=False)
|
||||
|
||||
# Verify listing still works
|
||||
assert len(tools) == initial_tool_count # Corrupted tool should not be in the results
|
||||
assert any(t.id == print_tool.id for t in tools) # Normal tool should still be there
|
||||
assert not any(t.id == corrupted_tool_id for t in tools) # Corrupted tool should not be there
|
||||
|
||||
# Verify the corrupted tool's name is not in the results
|
||||
assert not any(t.name == "corrupted_tool" for t in tools)
|
||||
|
||||
@@ -50,7 +50,8 @@ def run_server():
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
# @pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope="function")
|
||||
def client() -> LettaSDKClient:
|
||||
# Get URL from environment or start server
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
|
||||
@@ -65,8 +66,8 @@ def client() -> LettaSDKClient:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
@pytest.fixture(scope="function")
|
||||
async def server():
|
||||
"""
|
||||
Creates a SyncServer instance for testing.
|
||||
|
||||
@@ -74,7 +75,9 @@ def server():
|
||||
"""
|
||||
config = LettaConfig.load()
|
||||
config.save()
|
||||
return SyncServer()
|
||||
server = SyncServer()
|
||||
await server.init_async()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
@@ -277,42 +277,6 @@ def _run_composio_test(action_name, openai_model, structured_output):
|
||||
return (action_name, False, str(e)) # Failure with error message
|
||||
|
||||
|
||||
@pytest.mark.parametrize("openai_model", ["gpt-4o-mini"])
|
||||
@pytest.mark.parametrize("structured_output", [True])
|
||||
def test_composio_tool_schema_generation(openai_model: str, structured_output: bool):
|
||||
"""Test that we can generate the schemas for some Composio tools."""
|
||||
|
||||
if not os.getenv("COMPOSIO_API_KEY"):
|
||||
pytest.skip("COMPOSIO_API_KEY not set")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
action_names = [
|
||||
"GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER", # Simple
|
||||
"CAL_GET_AVAILABLE_SLOTS_INFO", # has an array arg, needs to be converted properly
|
||||
"SALESFORCE_RETRIEVE_LEAD_BY_ID", # has an array arg, needs to be converted properly
|
||||
"FIRECRAWL_SEARCH", # has an optional array arg, needs to be converted properly
|
||||
]
|
||||
|
||||
# Create a pool of processes
|
||||
pool = mp.Pool(processes=min(mp.cpu_count(), len(action_names)))
|
||||
|
||||
# Map the work to the pool
|
||||
func = partial(_run_composio_test, openai_model=openai_model, structured_output=structured_output)
|
||||
results = pool.map(func, action_names)
|
||||
|
||||
# Check results
|
||||
for action_name, success, error_message in results:
|
||||
print(f"Test for {action_name}: {'SUCCESS' if success else 'FAILED - ' + error_message}")
|
||||
assert success, f"Test for {action_name} failed: {error_message}"
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
end_time = time.time()
|
||||
print(f"Total execution time: {end_time - start_time:.2f} seconds")
|
||||
|
||||
|
||||
# Helper function for pydantic args schema test
|
||||
def _run_pydantic_args_test(filename, openai_model, structured_output):
|
||||
"""Run a single pydantic args schema test case"""
|
||||
@@ -342,6 +306,9 @@ def _run_pydantic_args_test(filename, openai_model, structured_output):
|
||||
source_code=last_function_source,
|
||||
args_json_schema=args_schema,
|
||||
)
|
||||
from letta.services.tool_schema_generator import generate_schema_for_tool_creation
|
||||
|
||||
tool.json_schema = generate_schema_for_tool_creation(tool)
|
||||
schema = tool.json_schema
|
||||
|
||||
# We expect this to fail for all_python_complex with structured_output=True
|
||||
|
||||
Reference in New Issue
Block a user