From 47d8650f094f1ff1d7018c6110a14bd028a8b200 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 5 Aug 2025 11:36:23 -0700 Subject: [PATCH] fix: Fix tool renaming if json schema is passed in (#3745) --- letta/services/agent_manager.py | 2 +- letta/services/tool_manager.py | 84 ++++++++++++++++++++++++--------- tests/test_sdk_client.py | 82 ++++++++++++++++++++++++++++++++ 3 files changed, 145 insertions(+), 23 deletions(-) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 6c0be674..05232e27 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1479,7 +1479,7 @@ class AgentManager: ): pydantic_tool = existing_pydantic_tool else: - pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor) + pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor, bypass_name_check=True) pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor) diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 123d2a25..9c76a825 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -46,7 +46,7 @@ class ToolManager: # TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object @enforce_types @trace_method - def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: + def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor) if tool_id: @@ -60,7 +60,9 @@ class ToolManager: updated_tool_type = None if "tool_type" in update_data: updated_tool_type = update_data.get("tool_type") - tool = self.update_tool_by_id(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type) + tool = self.update_tool_by_id( + tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check + ) else: printd( f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update." @@ -73,7 +75,9 @@ class ToolManager: @enforce_types @trace_method - async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: + async def create_or_update_tool_async( + self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False + ) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor) if tool_id: @@ -88,7 +92,9 @@ class ToolManager: updated_tool_type = None if "tool_type" in update_data: updated_tool_type = update_data.get("tool_type") - tool = await self.update_tool_by_id_async(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type) + tool = await self.update_tool_by_id_async( + tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check + ) else: printd( f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update." @@ -387,7 +393,12 @@ class ToolManager: @enforce_types @trace_method def update_tool_by_id( - self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None + self, + tool_id: str, + tool_update: ToolUpdate, + actor: PydanticUser, + updated_tool_type: Optional[ToolType] = None, + bypass_name_check: bool = False, ) -> PydanticTool: """Update a tool by its ID with the given ToolUpdate object.""" # First, check if source code update would cause a name conflict @@ -395,17 +406,29 @@ class ToolManager: new_name = None new_schema = None - if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): - # Derive the new schema and name from the source code - new_schema = derive_openai_json_schema(source_code=update_data["source_code"]) - new_name = new_schema["name"] + # TODO: Consider this behavior...is this what we want? + # TODO: I feel like it's bad if json_schema strays from source code so + # if source code is provided, always derive the name from it + if "source_code" in update_data.keys() and not bypass_name_check: + # derive the schema from source code to get the function name + derived_schema = derive_openai_json_schema(source_code=update_data["source_code"]) + new_name = derived_schema["name"] - # Get current tool to check if name is changing + # if json_schema wasn't provided, use the derived schema + if "json_schema" not in update_data.keys(): + new_schema = derived_schema + else: + # if json_schema was provided, update only its name to match the source code + new_schema = update_data["json_schema"].copy() + new_schema["name"] = new_name + # update the json_schema in update_data so it gets applied in the loop + update_data["json_schema"] = new_schema + + # get current tool to check if name is changing current_tool = self.get_tool_by_id(tool_id=tool_id, actor=actor) - - # Check if the name is changing and if so, verify it doesn't conflict + # check if the name is changing and if so, verify it doesn't conflict if new_name != current_tool.name: - # Check if a tool with the new name already exists + # check if a tool with the new name already exists existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor) if existing_tool: raise LettaToolNameConflictError(tool_name=new_name) @@ -433,7 +456,12 @@ class ToolManager: @enforce_types @trace_method async def update_tool_by_id_async( - self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None + self, + tool_id: str, + tool_update: ToolUpdate, + actor: PydanticUser, + updated_tool_type: Optional[ToolType] = None, + bypass_name_check: bool = False, ) -> PydanticTool: """Update a tool by its ID with the given ToolUpdate object.""" # First, check if source code update would cause a name conflict @@ -441,17 +469,29 @@ class ToolManager: new_name = None new_schema = None - if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): - # Derive the new schema and name from the source code - new_schema = derive_openai_json_schema(source_code=update_data["source_code"]) - new_name = new_schema["name"] + # TODO: Consider this behavior...is this what we want? + # TODO: I feel like it's bad if json_schema strays from source code so + # if source code is provided, always derive the name from it + if "source_code" in update_data.keys() and not bypass_name_check: + # derive the schema from source code to get the function name + derived_schema = derive_openai_json_schema(source_code=update_data["source_code"]) + new_name = derived_schema["name"] - # Get current tool to check if name is changing + # if json_schema wasn't provided, use the derived schema + if "json_schema" not in update_data.keys(): + new_schema = derived_schema + else: + # if json_schema was provided, update only its name to match the source code + new_schema = update_data["json_schema"].copy() + new_schema["name"] = new_name + # update the json_schema in update_data so it gets applied in the loop + update_data["json_schema"] = new_schema + + # get current tool to check if name is changing current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor) - - # Check if the name is changing and if so, verify it doesn't conflict + # check if the name is changing and if so, verify it doesn't conflict if new_name != current_tool.name: - # Check if a tool with the new name already exists + # check if a tool with the new name already exists name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor) if name_exists: raise LettaToolNameConflictError(tool_name=new_name) diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index d3b15636..335141bf 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1440,3 +1440,85 @@ def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient): finally: # Clean up client.tools.delete(tool_id=tool.id) + + +def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient): + """Test that passing both new JSON schema AND source code still renames the tool based on source code""" + import textwrap + + # Create initial tool + def initial_tool(x: int) -> int: + """ + Multiply a number by 2 + + Args: + x: The input number + + Returns: + The input multiplied by 2 + """ + return x * 2 + + # Create the tool + tool = client.tools.upsert_from_function(func=initial_tool) + assert tool.name == "initial_tool" + + try: + # Define new function source code with different name + new_source_code = textwrap.dedent( + """ + def renamed_function(value: float, multiplier: float = 2.0) -> float: + ''' + Multiply a value by a multiplier + + Args: + value: The input value + multiplier: The multiplier to use (default 2.0) + + Returns: + The value multiplied by the multiplier + ''' + return value * multiplier + """ + ).strip() + + # Create a custom JSON schema that has a different name + custom_json_schema = { + "name": "custom_schema_name", + "description": "Custom description from JSON schema", + "parameters": { + "type": "object", + "properties": { + "value": {"type": "number", "description": "Input value from JSON schema"}, + "multiplier": {"type": "number", "description": "Multiplier from JSON schema", "default": 2.0}, + }, + "required": ["value"], + }, + } + + # Modify the tool with both new source code AND JSON schema + modified_tool = client.tools.modify(tool_id=tool.id, source_code=new_source_code, json_schema=custom_json_schema) + + # Verify the name comes from the source code function name, not the JSON schema + assert modified_tool.name == "renamed_function" + assert modified_tool.source_code == new_source_code + + # Verify the JSON schema was updated to match the function name from source code + assert modified_tool.json_schema is not None + assert modified_tool.json_schema["name"] == "renamed_function" + + # The description should come from the source code docstring, not the JSON schema + assert modified_tool.json_schema["description"] == "Multiply a value by a multiplier" + + # Verify parameters are from the source code, not the custom JSON schema + params = modified_tool.json_schema.get("parameters", {}) + properties = params.get("properties", {}) + assert "value" in properties + assert "multiplier" in properties + assert properties["value"]["type"] == "number" + assert properties["multiplier"]["type"] == "number" + assert params["required"] == ["value"] + + finally: + # Clean up + client.tools.delete(tool_id=tool.id)