fix: Fix tool renaming if json schema is passed in (#3745)

This commit is contained in:
Matthew Zhou
2025-08-05 11:36:23 -07:00
committed by GitHub
parent 4a9c51d2f8
commit 47d8650f09
3 changed files with 145 additions and 23 deletions

View File

@@ -1479,7 +1479,7 @@ class AgentManager:
):
pydantic_tool = existing_pydantic_tool
else:
pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor)
pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor, bypass_name_check=True)
pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor)

View File

@@ -46,7 +46,7 @@ class ToolManager:
# TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
@enforce_types
@trace_method
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor)
if tool_id:
@@ -60,7 +60,9 @@ class ToolManager:
updated_tool_type = None
if "tool_type" in update_data:
updated_tool_type = update_data.get("tool_type")
tool = self.update_tool_by_id(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
tool = self.update_tool_by_id(
tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
)
else:
printd(
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
@@ -73,7 +75,9 @@ class ToolManager:
@enforce_types
@trace_method
async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
async def create_or_update_tool_async(
self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False
) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor)
if tool_id:
@@ -88,7 +92,9 @@ class ToolManager:
updated_tool_type = None
if "tool_type" in update_data:
updated_tool_type = update_data.get("tool_type")
tool = await self.update_tool_by_id_async(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
tool = await self.update_tool_by_id_async(
tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
)
else:
printd(
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
@@ -387,7 +393,12 @@ class ToolManager:
@enforce_types
@trace_method
def update_tool_by_id(
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
self,
tool_id: str,
tool_update: ToolUpdate,
actor: PydanticUser,
updated_tool_type: Optional[ToolType] = None,
bypass_name_check: bool = False,
) -> PydanticTool:
"""Update a tool by its ID with the given ToolUpdate object."""
# First, check if source code update would cause a name conflict
@@ -395,17 +406,29 @@ class ToolManager:
new_name = None
new_schema = None
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
# Derive the new schema and name from the source code
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
new_name = new_schema["name"]
# TODO: Consider this behavior...is this what we want?
# TODO: I feel like it's bad if json_schema strays from source code so
# if source code is provided, always derive the name from it
if "source_code" in update_data.keys() and not bypass_name_check:
# derive the schema from source code to get the function name
derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
new_name = derived_schema["name"]
# Get current tool to check if name is changing
# if json_schema wasn't provided, use the derived schema
if "json_schema" not in update_data.keys():
new_schema = derived_schema
else:
# if json_schema was provided, update only its name to match the source code
new_schema = update_data["json_schema"].copy()
new_schema["name"] = new_name
# update the json_schema in update_data so it gets applied in the loop
update_data["json_schema"] = new_schema
# get current tool to check if name is changing
current_tool = self.get_tool_by_id(tool_id=tool_id, actor=actor)
# Check if the name is changing and if so, verify it doesn't conflict
# check if the name is changing and if so, verify it doesn't conflict
if new_name != current_tool.name:
# Check if a tool with the new name already exists
# check if a tool with the new name already exists
existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor)
if existing_tool:
raise LettaToolNameConflictError(tool_name=new_name)
@@ -433,7 +456,12 @@ class ToolManager:
@enforce_types
@trace_method
async def update_tool_by_id_async(
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
self,
tool_id: str,
tool_update: ToolUpdate,
actor: PydanticUser,
updated_tool_type: Optional[ToolType] = None,
bypass_name_check: bool = False,
) -> PydanticTool:
"""Update a tool by its ID with the given ToolUpdate object."""
# First, check if source code update would cause a name conflict
@@ -441,17 +469,29 @@ class ToolManager:
new_name = None
new_schema = None
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
# Derive the new schema and name from the source code
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
new_name = new_schema["name"]
# TODO: Consider this behavior...is this what we want?
# TODO: I feel like it's bad if json_schema strays from source code so
# if source code is provided, always derive the name from it
if "source_code" in update_data.keys() and not bypass_name_check:
# derive the schema from source code to get the function name
derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
new_name = derived_schema["name"]
# Get current tool to check if name is changing
# if json_schema wasn't provided, use the derived schema
if "json_schema" not in update_data.keys():
new_schema = derived_schema
else:
# if json_schema was provided, update only its name to match the source code
new_schema = update_data["json_schema"].copy()
new_schema["name"] = new_name
# update the json_schema in update_data so it gets applied in the loop
update_data["json_schema"] = new_schema
# get current tool to check if name is changing
current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor)
# Check if the name is changing and if so, verify it doesn't conflict
# check if the name is changing and if so, verify it doesn't conflict
if new_name != current_tool.name:
# Check if a tool with the new name already exists
# check if a tool with the new name already exists
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
if name_exists:
raise LettaToolNameConflictError(tool_name=new_name)

View File

@@ -1440,3 +1440,85 @@ def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
finally:
# Clean up
client.tools.delete(tool_id=tool.id)
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
"""Test that passing both new JSON schema AND source code still renames the tool based on source code"""
import textwrap
# Create initial tool
def initial_tool(x: int) -> int:
"""
Multiply a number by 2
Args:
x: The input number
Returns:
The input multiplied by 2
"""
return x * 2
# Create the tool
tool = client.tools.upsert_from_function(func=initial_tool)
assert tool.name == "initial_tool"
try:
# Define new function source code with different name
new_source_code = textwrap.dedent(
"""
def renamed_function(value: float, multiplier: float = 2.0) -> float:
'''
Multiply a value by a multiplier
Args:
value: The input value
multiplier: The multiplier to use (default 2.0)
Returns:
The value multiplied by the multiplier
'''
return value * multiplier
"""
).strip()
# Create a custom JSON schema that has a different name
custom_json_schema = {
"name": "custom_schema_name",
"description": "Custom description from JSON schema",
"parameters": {
"type": "object",
"properties": {
"value": {"type": "number", "description": "Input value from JSON schema"},
"multiplier": {"type": "number", "description": "Multiplier from JSON schema", "default": 2.0},
},
"required": ["value"],
},
}
# Modify the tool with both new source code AND JSON schema
modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema)
# Verify the name comes from the source code function name, not the JSON schema
assert modified_tool.name == "renamed_function"
assert modified_tool.source_code == new_source_code
# Verify the JSON schema was updated to match the function name from source code
assert modified_tool.json_schema is not None
assert modified_tool.json_schema["name"] == "renamed_function"
# The description should come from the source code docstring, not the JSON schema
assert modified_tool.json_schema["description"] == "Multiply a value by a multiplier"
# Verify parameters are from the source code, not the custom JSON schema
params = modified_tool.json_schema.get("parameters", {})
properties = params.get("properties", {})
assert "value" in properties
assert "multiplier" in properties
assert properties["value"]["type"] == "number"
assert properties["multiplier"]["type"] == "number"
assert params["required"] == ["value"]
finally:
# Clean up
client.tools.delete(tool_id=tool.id)