diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 7c63bf52..6b877db4 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -41,18 +41,24 @@ class ToolManager: @enforce_types def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" - tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor) - if tool: + tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor) + if tool_id: # Put to dict and remove fields that should not be reset update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True) # If there's anything to update if update_data: - tool = self.update_tool_by_id(tool.id, ToolUpdate(**update_data), actor) + # In case we want to update the tool type + # Useful if we are shuffling around base tools + updated_tool_type = None + if "tool_type" in update_data: + updated_tool_type = update_data.get("tool_type") + tool = self.update_tool_by_id(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type) else: printd( f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update." ) + tool = self.get_tool_by_id(tool_id, actor=actor) else: tool = self.create_tool(pydantic_tool, actor=actor) @@ -114,6 +120,16 @@ class ToolManager: except NoResultFound: return None + @enforce_types + def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]: + """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" + try: + with self.session_maker() as session: + tool = ToolModel.read(db_session=session, name=tool_name, actor=actor) + return tool.id + except NoResultFound: + return None + @enforce_types def list_tools(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: """List all tools with optional pagination.""" @@ -156,7 +172,9 @@ class ToolManager: return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET) @enforce_types - def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool: + def update_tool_by_id( + self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None + ) -> PydanticTool: """Update a tool by its ID with the given ToolUpdate object.""" with self.session_maker() as session: # Fetch the tool by ID @@ -175,6 +193,9 @@ class ToolManager: tool.json_schema = new_schema tool.name = new_schema["name"] + if updated_tool_type: + tool.tool_type = updated_tool_type + # Save the updated tool to the database return tool.update(db_session=session, actor=actor).to_pydantic() @@ -248,5 +269,4 @@ class ToolManager: ) # TODO: Delete any base tools that are stale - return tools diff --git a/tests/test_managers.py b/tests/test_managers.py index fefa1657..fa69b802 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2191,6 +2191,12 @@ def test_update_tool_by_id(server: SyncServer, print_tool, default_user): # Assertions to check if the update was successful assert updated_tool.description == updated_description assert updated_tool.return_char_limit == return_char_limit + assert updated_tool.tool_type == ToolType.CUSTOM + + # Dangerous: we bypass safety to give it another tool type + server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user, updated_tool_type=ToolType.EXTERNAL_LANGCHAIN) + updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user) + assert updated_tool.tool_type == ToolType.EXTERNAL_LANGCHAIN def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, print_tool, default_user):