fix: Fix upsert base tools for rethink memory (#1940)

This commit is contained in:
Matthew Zhou
2025-04-29 16:33:18 -07:00
committed by GitHub
parent 3630c76814
commit 2548e0271d
2 changed files with 31 additions and 5 deletions

View File

@@ -41,18 +41,24 @@ class ToolManager:
@enforce_types
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
if tool:
tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor)
if tool_id:
# Put to dict and remove fields that should not be reset
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
# If there's anything to update
if update_data:
tool = self.update_tool_by_id(tool.id, ToolUpdate(**update_data), actor)
# In case we want to update the tool type
# Useful if we are shuffling around base tools
updated_tool_type = None
if "tool_type" in update_data:
updated_tool_type = update_data.get("tool_type")
tool = self.update_tool_by_id(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
else:
printd(
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
)
tool = self.get_tool_by_id(tool_id, actor=actor)
else:
tool = self.create_tool(pydantic_tool, actor=actor)
@@ -114,6 +120,16 @@ class ToolManager:
except NoResultFound:
return None
@enforce_types
def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
with self.session_maker() as session:
tool = ToolModel.read(db_session=session, name=tool_name, actor=actor)
return tool.id
except NoResultFound:
return None
@enforce_types
def list_tools(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
"""List all tools with optional pagination."""
@@ -156,7 +172,9 @@ class ToolManager:
return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET)
@enforce_types
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool:
def update_tool_by_id(
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
) -> PydanticTool:
"""Update a tool by its ID with the given ToolUpdate object."""
with self.session_maker() as session:
# Fetch the tool by ID
@@ -175,6 +193,9 @@ class ToolManager:
tool.json_schema = new_schema
tool.name = new_schema["name"]
if updated_tool_type:
tool.tool_type = updated_tool_type
# Save the updated tool to the database
return tool.update(db_session=session, actor=actor).to_pydantic()
@@ -248,5 +269,4 @@ class ToolManager:
)
# TODO: Delete any base tools that are stale
return tools

View File

@@ -2191,6 +2191,12 @@ def test_update_tool_by_id(server: SyncServer, print_tool, default_user):
# Assertions to check if the update was successful
assert updated_tool.description == updated_description
assert updated_tool.return_char_limit == return_char_limit
assert updated_tool.tool_type == ToolType.CUSTOM
# Dangerous: we bypass safety to give it another tool type
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user, updated_tool_type=ToolType.EXTERNAL_LANGCHAIN)
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
assert updated_tool.tool_type == ToolType.EXTERNAL_LANGCHAIN
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, print_tool, default_user):