From 132c789ec5e0ce4ab34c130ad8c33d86def1e72c Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Sun, 21 Apr 2024 17:57:58 -0700 Subject: [PATCH] feat: disallow creation of tools with the same name (#1285) --- memgpt/metadata.py | 13 ++++++++++++- memgpt/server/rest_api/server.py | 4 +++- memgpt/server/rest_api/tools/index.py | 11 ++++++++--- memgpt/server/server.py | 6 ++++-- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/memgpt/metadata.py b/memgpt/metadata.py index a48bde51..f02f953f 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -545,7 +545,8 @@ class MetadataStore: return [r.to_record() for r in results] @enforce_types - def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: + # def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools + def list_tools(self) -> List[ToolModel]: with self.session_maker() as session: available_functions = load_all_function_sets() results = [ @@ -622,6 +623,16 @@ class MetadataStore: assert len(results) == 1, f"Expected 1 result, got {len(results)}" return results[0].to_record() + @enforce_types + def get_tool(self, tool_name: str) -> Optional[ToolModel]: + # TODO: add user_id when tools can eventually be added by users + with self.session_maker() as session: + results = session.query(ToolModel).filter(ToolModel.name == tool_name).all() + if len(results) == 0: + return None + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + return results[0] + # agent source metadata @enforce_types def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID): diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index fcc46c15..c1a3981e 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -93,7 +93,9 @@ app.include_router(setup_agents_message_router(server, interface, password), pre app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX) -app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX) +app.include_router( + setup_tools_index_router(server, interface, password), prefix=API_PREFIX, dependencies=[Depends(verify_password)] +) # admin only app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX) diff --git a/memgpt/server/rest_api/tools/index.py b/memgpt/server/rest_api/tools/index.py index 6b5a0672..3f0e2e17 100644 --- a/memgpt/server/rest_api/tools/index.py +++ b/memgpt/server/rest_api/tools/index.py @@ -33,26 +33,31 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa @router.get("/tools", tags=["tools"], response_model=ListToolsResponse) async def list_all_tools( - user_id: uuid.UUID = Depends(get_current_user_with_server), + # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific ): """ Get a list of all tools available to agents created by a user """ # Clear the interface interface.clear() - tools = server.ms.list_tools(user_id=user_id) + # tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific + tools = server.ms.list_tools() return ListToolsResponse(tools=tools) @router.post("/tools", tags=["tools"], response_model=CreateToolResponse) async def create_tool( request: CreateToolRequest = Body(...), - user_id: uuid.UUID = Depends(get_current_user_with_server), + # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific ): """ Create a new tool (dummy route) """ from memgpt.functions.functions import write_function + # check if function already exists + if server.ms.get_tool(request.name): + raise ValueError(f"Tool with name {request.name} already exists.") + # write function to ~/.memgt/functions directory write_function(request.name, request.name, request.source_code) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 443b4896..9bcb2404 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -697,7 +697,8 @@ class SyncServer(LockingServer): # TODO remove (https://github.com/cpacker/MemGPT/issues/1138) if function_names is not None: preset_override = True - available_tools = self.ms.list_tools(user_id=user_id) + # available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific + available_tools = self.ms.list_tools() available_tools_names = [t.name for t in available_tools] assert all([f_name in available_tools_names for f_name in function_names]) preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names] @@ -819,7 +820,8 @@ class SyncServer(LockingServer): # TODO add a get_message_obj_from_message_id(...) function # this would allow grabbing Message.created_by without having to load the agent object - all_available_tools = self.ms.list_tools(user_id=user_id) + # all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific + all_available_tools = self.ms.list_tools() for agent_state, return_dict in zip(agents_states, agents_states_dicts):