fix: Fix upsert base tools for rethink memory (#1940)
This commit is contained in:
@@ -41,18 +41,24 @@ class ToolManager:
|
||||
@enforce_types
|
||||
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool:
|
||||
tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool_id:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
|
||||
|
||||
# If there's anything to update
|
||||
if update_data:
|
||||
tool = self.update_tool_by_id(tool.id, ToolUpdate(**update_data), actor)
|
||||
# In case we want to update the tool type
|
||||
# Useful if we are shuffling around base tools
|
||||
updated_tool_type = None
|
||||
if "tool_type" in update_data:
|
||||
updated_tool_type = update_data.get("tool_type")
|
||||
tool = self.update_tool_by_id(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
|
||||
else:
|
||||
printd(
|
||||
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
||||
)
|
||||
tool = self.get_tool_by_id(tool_id, actor=actor)
|
||||
else:
|
||||
tool = self.create_tool(pydantic_tool, actor=actor)
|
||||
|
||||
@@ -114,6 +120,16 @@ class ToolManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
|
||||
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
|
||||
try:
|
||||
with self.session_maker() as session:
|
||||
tool = ToolModel.read(db_session=session, name=tool_name, actor=actor)
|
||||
return tool.id
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def list_tools(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
|
||||
"""List all tools with optional pagination."""
|
||||
@@ -156,7 +172,9 @@ class ToolManager:
|
||||
return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET)
|
||||
|
||||
@enforce_types
|
||||
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool:
|
||||
def update_tool_by_id(
|
||||
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
|
||||
) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
with self.session_maker() as session:
|
||||
# Fetch the tool by ID
|
||||
@@ -175,6 +193,9 @@ class ToolManager:
|
||||
tool.json_schema = new_schema
|
||||
tool.name = new_schema["name"]
|
||||
|
||||
if updated_tool_type:
|
||||
tool.tool_type = updated_tool_type
|
||||
|
||||
# Save the updated tool to the database
|
||||
return tool.update(db_session=session, actor=actor).to_pydantic()
|
||||
|
||||
@@ -248,5 +269,4 @@ class ToolManager:
|
||||
)
|
||||
|
||||
# TODO: Delete any base tools that are stale
|
||||
|
||||
return tools
|
||||
|
||||
@@ -2191,6 +2191,12 @@ def test_update_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
# Assertions to check if the update was successful
|
||||
assert updated_tool.description == updated_description
|
||||
assert updated_tool.return_char_limit == return_char_limit
|
||||
assert updated_tool.tool_type == ToolType.CUSTOM
|
||||
|
||||
# Dangerous: we bypass safety to give it another tool type
|
||||
server.tool_manager.update_tool_by_id(print_tool.id, tool_update, actor=default_user, updated_tool_type=ToolType.EXTERNAL_LANGCHAIN)
|
||||
updated_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
|
||||
assert updated_tool.tool_type == ToolType.EXTERNAL_LANGCHAIN
|
||||
|
||||
|
||||
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, print_tool, default_user):
|
||||
|
||||
Reference in New Issue
Block a user