feat: disallow creation of tools with the same name (#1285)
This commit is contained in:
@@ -545,7 +545,8 @@ class MetadataStore:
|
||||
return [r.to_record() for r in results]
|
||||
|
||||
@enforce_types
|
||||
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
|
||||
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools
|
||||
def list_tools(self) -> List[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
available_functions = load_all_function_sets()
|
||||
results = [
|
||||
@@ -622,6 +623,16 @@ class MetadataStore:
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_tool(self, tool_name: str) -> Optional[ToolModel]:
|
||||
# TODO: add user_id when tools can eventually be added by users
|
||||
with self.session_maker() as session:
|
||||
results = session.query(ToolModel).filter(ToolModel.name == tool_name).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0]
|
||||
|
||||
# agent source metadata
|
||||
@enforce_types
|
||||
def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):
|
||||
|
||||
@@ -93,7 +93,9 @@ app.include_router(setup_agents_message_router(server, interface, password), pre
|
||||
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(
|
||||
setup_tools_index_router(server, interface, password), prefix=API_PREFIX, dependencies=[Depends(verify_password)]
|
||||
) # admin only
|
||||
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
|
||||
@@ -33,26 +33,31 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
|
||||
|
||||
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
|
||||
async def list_all_tools(
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
tools = server.ms.list_tools(user_id=user_id)
|
||||
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
|
||||
tools = server.ms.list_tools()
|
||||
return ListToolsResponse(tools=tools)
|
||||
|
||||
@router.post("/tools", tags=["tools"], response_model=CreateToolResponse)
|
||||
async def create_tool(
|
||||
request: CreateToolRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
):
|
||||
"""
|
||||
Create a new tool (dummy route)
|
||||
"""
|
||||
from memgpt.functions.functions import write_function
|
||||
|
||||
# check if function already exists
|
||||
if server.ms.get_tool(request.name):
|
||||
raise ValueError(f"Tool with name {request.name} already exists.")
|
||||
|
||||
# write function to ~/.memgt/functions directory
|
||||
write_function(request.name, request.name, request.source_code)
|
||||
|
||||
|
||||
@@ -697,7 +697,8 @@ class SyncServer(LockingServer):
|
||||
# TODO remove (https://github.com/cpacker/MemGPT/issues/1138)
|
||||
if function_names is not None:
|
||||
preset_override = True
|
||||
available_tools = self.ms.list_tools(user_id=user_id)
|
||||
# available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
|
||||
available_tools = self.ms.list_tools()
|
||||
available_tools_names = [t.name for t in available_tools]
|
||||
assert all([f_name in available_tools_names for f_name in function_names])
|
||||
preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names]
|
||||
@@ -819,7 +820,8 @@ class SyncServer(LockingServer):
|
||||
|
||||
# TODO add a get_message_obj_from_message_id(...) function
|
||||
# this would allow grabbing Message.created_by without having to load the agent object
|
||||
all_available_tools = self.ms.list_tools(user_id=user_id)
|
||||
# all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
|
||||
all_available_tools = self.ms.list_tools()
|
||||
|
||||
for agent_state, return_dict in zip(agents_states, agents_states_dicts):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user