feat: disallow creation of tools with the same name (#1285)

This commit is contained in:
Sarah Wooders
2024-04-21 17:57:58 -07:00
committed by GitHub
parent b35de11d96
commit 132c789ec5
4 changed files with 27 additions and 7 deletions

View File

@@ -545,7 +545,8 @@ class MetadataStore:
return [r.to_record() for r in results] return [r.to_record() for r in results]
@enforce_types @enforce_types
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self) -> List[ToolModel]:
with self.session_maker() as session: with self.session_maker() as session:
available_functions = load_all_function_sets() available_functions = load_all_function_sets()
results = [ results = [
@@ -622,6 +623,16 @@ class MetadataStore:
assert len(results) == 1, f"Expected 1 result, got {len(results)}" assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record() return results[0].to_record()
@enforce_types
def get_tool(self, tool_name: str) -> Optional[ToolModel]:
# TODO: add user_id when tools can eventually be added by users
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.name == tool_name).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]
# agent source metadata # agent source metadata
@enforce_types @enforce_types
def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID): def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):

View File

@@ -93,7 +93,9 @@ app.include_router(setup_agents_message_router(server, interface, password), pre
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(
setup_tools_index_router(server, interface, password), prefix=API_PREFIX, dependencies=[Depends(verify_password)]
) # admin only
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)

View File

@@ -33,26 +33,31 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse) @router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools( async def list_all_tools(
user_id: uuid.UUID = Depends(get_current_user_with_server), # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
): ):
""" """
Get a list of all tools available to agents created by a user Get a list of all tools available to agents created by a user
""" """
# Clear the interface # Clear the interface
interface.clear() interface.clear()
tools = server.ms.list_tools(user_id=user_id) # tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
tools = server.ms.list_tools()
return ListToolsResponse(tools=tools) return ListToolsResponse(tools=tools)
@router.post("/tools", tags=["tools"], response_model=CreateToolResponse) @router.post("/tools", tags=["tools"], response_model=CreateToolResponse)
async def create_tool( async def create_tool(
request: CreateToolRequest = Body(...), request: CreateToolRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server), # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
): ):
""" """
Create a new tool (dummy route) Create a new tool (dummy route)
""" """
from memgpt.functions.functions import write_function from memgpt.functions.functions import write_function
# check if function already exists
if server.ms.get_tool(request.name):
raise ValueError(f"Tool with name {request.name} already exists.")
# write function to ~/.memgt/functions directory # write function to ~/.memgt/functions directory
write_function(request.name, request.name, request.source_code) write_function(request.name, request.name, request.source_code)

View File

@@ -697,7 +697,8 @@ class SyncServer(LockingServer):
# TODO remove (https://github.com/cpacker/MemGPT/issues/1138) # TODO remove (https://github.com/cpacker/MemGPT/issues/1138)
if function_names is not None: if function_names is not None:
preset_override = True preset_override = True
available_tools = self.ms.list_tools(user_id=user_id) # available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
available_tools = self.ms.list_tools()
available_tools_names = [t.name for t in available_tools] available_tools_names = [t.name for t in available_tools]
assert all([f_name in available_tools_names for f_name in function_names]) assert all([f_name in available_tools_names for f_name in function_names])
preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names] preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names]
@@ -819,7 +820,8 @@ class SyncServer(LockingServer):
# TODO add a get_message_obj_from_message_id(...) function # TODO add a get_message_obj_from_message_id(...) function
# this would allow grabbing Message.created_by without having to load the agent object # this would allow grabbing Message.created_by without having to load the agent object
all_available_tools = self.ms.list_tools(user_id=user_id) # all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
all_available_tools = self.ms.list_tools()
for agent_state, return_dict in zip(agents_states, agents_states_dicts): for agent_state, return_dict in zip(agents_states, agents_states_dicts):