feat: disallow creation of tools with the same name (#1285)

This commit is contained in:
Sarah Wooders
2024-04-21 17:57:58 -07:00
committed by GitHub
parent b35de11d96
commit 132c789ec5
4 changed files with 27 additions and 7 deletions

View File

@@ -545,7 +545,8 @@ class MetadataStore:
return [r.to_record() for r in results]
@enforce_types
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self) -> List[ToolModel]:
with self.session_maker() as session:
available_functions = load_all_function_sets()
results = [
@@ -622,6 +623,16 @@ class MetadataStore:
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def get_tool(self, tool_name: str) -> Optional[ToolModel]:
# TODO: add user_id when tools can eventually be added by users
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.name == tool_name).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]
# agent source metadata
@enforce_types
def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):

View File

@@ -93,7 +93,9 @@ app.include_router(setup_agents_message_router(server, interface, password), pre
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(
setup_tools_index_router(server, interface, password), prefix=API_PREFIX, dependencies=[Depends(verify_password)]
) # admin only
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)

View File

@@ -33,26 +33,31 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools(
user_id: uuid.UUID = Depends(get_current_user_with_server),
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
):
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
tools = server.ms.list_tools(user_id=user_id)
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
tools = server.ms.list_tools()
return ListToolsResponse(tools=tools)
@router.post("/tools", tags=["tools"], response_model=CreateToolResponse)
async def create_tool(
request: CreateToolRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
):
"""
Create a new tool (dummy route)
"""
from memgpt.functions.functions import write_function
# check if function already exists
if server.ms.get_tool(request.name):
raise ValueError(f"Tool with name {request.name} already exists.")
# write function to ~/.memgt/functions directory
write_function(request.name, request.name, request.source_code)

View File

@@ -697,7 +697,8 @@ class SyncServer(LockingServer):
# TODO remove (https://github.com/cpacker/MemGPT/issues/1138)
if function_names is not None:
preset_override = True
available_tools = self.ms.list_tools(user_id=user_id)
# available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
available_tools = self.ms.list_tools()
available_tools_names = [t.name for t in available_tools]
assert all([f_name in available_tools_names for f_name in function_names])
preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names]
@@ -819,7 +820,8 @@ class SyncServer(LockingServer):
# TODO add a get_message_obj_from_message_id(...) function
# this would allow grabbing Message.created_by without having to load the agent object
all_available_tools = self.ms.list_tools(user_id=user_id)
# all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
all_available_tools = self.ms.list_tools()
for agent_state, return_dict in zip(agents_states, agents_states_dicts):