fix: Fix tool renaming if json schema is passed in (#3745)
This commit is contained in:
@@ -1479,7 +1479,7 @@ class AgentManager:
|
||||
):
|
||||
pydantic_tool = existing_pydantic_tool
|
||||
else:
|
||||
pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor)
|
||||
pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor, bypass_name_check=True)
|
||||
|
||||
pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor)
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class ToolManager:
|
||||
# TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool_id:
|
||||
@@ -60,7 +60,9 @@ class ToolManager:
|
||||
updated_tool_type = None
|
||||
if "tool_type" in update_data:
|
||||
updated_tool_type = update_data.get("tool_type")
|
||||
tool = self.update_tool_by_id(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
|
||||
tool = self.update_tool_by_id(
|
||||
tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
|
||||
)
|
||||
else:
|
||||
printd(
|
||||
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
||||
@@ -73,7 +75,9 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
async def create_or_update_tool_async(
|
||||
self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False
|
||||
) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool_id:
|
||||
@@ -88,7 +92,9 @@ class ToolManager:
|
||||
updated_tool_type = None
|
||||
if "tool_type" in update_data:
|
||||
updated_tool_type = update_data.get("tool_type")
|
||||
tool = await self.update_tool_by_id_async(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
|
||||
tool = await self.update_tool_by_id_async(
|
||||
tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
|
||||
)
|
||||
else:
|
||||
printd(
|
||||
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
||||
@@ -387,7 +393,12 @@ class ToolManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_tool_by_id(
|
||||
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
|
||||
self,
|
||||
tool_id: str,
|
||||
tool_update: ToolUpdate,
|
||||
actor: PydanticUser,
|
||||
updated_tool_type: Optional[ToolType] = None,
|
||||
bypass_name_check: bool = False,
|
||||
) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
# First, check if source code update would cause a name conflict
|
||||
@@ -395,17 +406,29 @@ class ToolManager:
|
||||
new_name = None
|
||||
new_schema = None
|
||||
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
# Derive the new schema and name from the source code
|
||||
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
new_name = new_schema["name"]
|
||||
# TODO: Consider this behavior...is this what we want?
|
||||
# TODO: I feel like it's bad if json_schema strays from source code so
|
||||
# if source code is provided, always derive the name from it
|
||||
if "source_code" in update_data.keys() and not bypass_name_check:
|
||||
# derive the schema from source code to get the function name
|
||||
derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
new_name = derived_schema["name"]
|
||||
|
||||
# Get current tool to check if name is changing
|
||||
# if json_schema wasn't provided, use the derived schema
|
||||
if "json_schema" not in update_data.keys():
|
||||
new_schema = derived_schema
|
||||
else:
|
||||
# if json_schema was provided, update only its name to match the source code
|
||||
new_schema = update_data["json_schema"].copy()
|
||||
new_schema["name"] = new_name
|
||||
# update the json_schema in update_data so it gets applied in the loop
|
||||
update_data["json_schema"] = new_schema
|
||||
|
||||
# get current tool to check if name is changing
|
||||
current_tool = self.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
|
||||
# Check if the name is changing and if so, verify it doesn't conflict
|
||||
# check if the name is changing and if so, verify it doesn't conflict
|
||||
if new_name != current_tool.name:
|
||||
# Check if a tool with the new name already exists
|
||||
# check if a tool with the new name already exists
|
||||
existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor)
|
||||
if existing_tool:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
@@ -433,7 +456,12 @@ class ToolManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_tool_by_id_async(
|
||||
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
|
||||
self,
|
||||
tool_id: str,
|
||||
tool_update: ToolUpdate,
|
||||
actor: PydanticUser,
|
||||
updated_tool_type: Optional[ToolType] = None,
|
||||
bypass_name_check: bool = False,
|
||||
) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
# First, check if source code update would cause a name conflict
|
||||
@@ -441,17 +469,29 @@ class ToolManager:
|
||||
new_name = None
|
||||
new_schema = None
|
||||
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
# Derive the new schema and name from the source code
|
||||
new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
new_name = new_schema["name"]
|
||||
# TODO: Consider this behavior...is this what we want?
|
||||
# TODO: I feel like it's bad if json_schema strays from source code so
|
||||
# if source code is provided, always derive the name from it
|
||||
if "source_code" in update_data.keys() and not bypass_name_check:
|
||||
# derive the schema from source code to get the function name
|
||||
derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
||||
new_name = derived_schema["name"]
|
||||
|
||||
# Get current tool to check if name is changing
|
||||
# if json_schema wasn't provided, use the derived schema
|
||||
if "json_schema" not in update_data.keys():
|
||||
new_schema = derived_schema
|
||||
else:
|
||||
# if json_schema was provided, update only its name to match the source code
|
||||
new_schema = update_data["json_schema"].copy()
|
||||
new_schema["name"] = new_name
|
||||
# update the json_schema in update_data so it gets applied in the loop
|
||||
update_data["json_schema"] = new_schema
|
||||
|
||||
# get current tool to check if name is changing
|
||||
current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor)
|
||||
|
||||
# Check if the name is changing and if so, verify it doesn't conflict
|
||||
# check if the name is changing and if so, verify it doesn't conflict
|
||||
if new_name != current_tool.name:
|
||||
# Check if a tool with the new name already exists
|
||||
# check if a tool with the new name already exists
|
||||
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
|
||||
if name_exists:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
|
||||
@@ -1440,3 +1440,85 @@ def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
|
||||
finally:
|
||||
# Clean up
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
"""Test that passing both new JSON schema AND source code still renames the tool based on source code"""
|
||||
import textwrap
|
||||
|
||||
# Create initial tool
|
||||
def initial_tool(x: int) -> int:
|
||||
"""
|
||||
Multiply a number by 2
|
||||
|
||||
Args:
|
||||
x: The input number
|
||||
|
||||
Returns:
|
||||
The input multiplied by 2
|
||||
"""
|
||||
return x * 2
|
||||
|
||||
# Create the tool
|
||||
tool = client.tools.upsert_from_function(func=initial_tool)
|
||||
assert tool.name == "initial_tool"
|
||||
|
||||
try:
|
||||
# Define new function source code with different name
|
||||
new_source_code = textwrap.dedent(
|
||||
"""
|
||||
def renamed_function(value: float, multiplier: float = 2.0) -> float:
|
||||
'''
|
||||
Multiply a value by a multiplier
|
||||
|
||||
Args:
|
||||
value: The input value
|
||||
multiplier: The multiplier to use (default 2.0)
|
||||
|
||||
Returns:
|
||||
The value multiplied by the multiplier
|
||||
'''
|
||||
return value * multiplier
|
||||
"""
|
||||
).strip()
|
||||
|
||||
# Create a custom JSON schema that has a different name
|
||||
custom_json_schema = {
|
||||
"name": "custom_schema_name",
|
||||
"description": "Custom description from JSON schema",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {"type": "number", "description": "Input value from JSON schema"},
|
||||
"multiplier": {"type": "number", "description": "Multiplier from JSON schema", "default": 2.0},
|
||||
},
|
||||
"required": ["value"],
|
||||
},
|
||||
}
|
||||
|
||||
# Modify the tool with both new source code AND JSON schema
|
||||
modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema)
|
||||
|
||||
# Verify the name comes from the source code function name, not the JSON schema
|
||||
assert modified_tool.name == "renamed_function"
|
||||
assert modified_tool.source_code == new_source_code
|
||||
|
||||
# Verify the JSON schema was updated to match the function name from source code
|
||||
assert modified_tool.json_schema is not None
|
||||
assert modified_tool.json_schema["name"] == "renamed_function"
|
||||
|
||||
# The description should come from the source code docstring, not the JSON schema
|
||||
assert modified_tool.json_schema["description"] == "Multiply a value by a multiplier"
|
||||
|
||||
# Verify parameters are from the source code, not the custom JSON schema
|
||||
params = modified_tool.json_schema.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
assert "value" in properties
|
||||
assert "multiplier" in properties
|
||||
assert properties["value"]["type"] == "number"
|
||||
assert properties["multiplier"]["type"] == "number"
|
||||
assert params["required"] == ["value"]
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
|
||||
Reference in New Issue
Block a user