feat: disallow creation of tools with the same name (#1285)
This commit is contained in:
@@ -545,7 +545,8 @@ class MetadataStore:
|
|||||||
return [r.to_record() for r in results]
|
return [r.to_record() for r in results]
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
|
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools
|
||||||
|
def list_tools(self) -> List[ToolModel]:
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
available_functions = load_all_function_sets()
|
available_functions = load_all_function_sets()
|
||||||
results = [
|
results = [
|
||||||
@@ -622,6 +623,16 @@ class MetadataStore:
|
|||||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||||
return results[0].to_record()
|
return results[0].to_record()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_tool(self, tool_name: str) -> Optional[ToolModel]:
|
||||||
|
# TODO: add user_id when tools can eventually be added by users
|
||||||
|
with self.session_maker() as session:
|
||||||
|
results = session.query(ToolModel).filter(ToolModel.name == tool_name).all()
|
||||||
|
if len(results) == 0:
|
||||||
|
return None
|
||||||
|
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||||
|
return results[0]
|
||||||
|
|
||||||
# agent source metadata
|
# agent source metadata
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):
|
def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):
|
||||||
|
|||||||
@@ -93,7 +93,9 @@ app.include_router(setup_agents_message_router(server, interface, password), pre
|
|||||||
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
|
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
|
||||||
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
|
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
|
||||||
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
|
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
|
||||||
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
|
app.include_router(
|
||||||
|
setup_tools_index_router(server, interface, password), prefix=API_PREFIX, dependencies=[Depends(verify_password)]
|
||||||
|
) # admin only
|
||||||
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
|
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
|
||||||
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)
|
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)
|
||||||
|
|
||||||
|
|||||||
@@ -33,26 +33,31 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
|
|||||||
|
|
||||||
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
|
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
|
||||||
async def list_all_tools(
|
async def list_all_tools(
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get a list of all tools available to agents created by a user
|
Get a list of all tools available to agents created by a user
|
||||||
"""
|
"""
|
||||||
# Clear the interface
|
# Clear the interface
|
||||||
interface.clear()
|
interface.clear()
|
||||||
tools = server.ms.list_tools(user_id=user_id)
|
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
|
||||||
|
tools = server.ms.list_tools()
|
||||||
return ListToolsResponse(tools=tools)
|
return ListToolsResponse(tools=tools)
|
||||||
|
|
||||||
@router.post("/tools", tags=["tools"], response_model=CreateToolResponse)
|
@router.post("/tools", tags=["tools"], response_model=CreateToolResponse)
|
||||||
async def create_tool(
|
async def create_tool(
|
||||||
request: CreateToolRequest = Body(...),
|
request: CreateToolRequest = Body(...),
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new tool (dummy route)
|
Create a new tool (dummy route)
|
||||||
"""
|
"""
|
||||||
from memgpt.functions.functions import write_function
|
from memgpt.functions.functions import write_function
|
||||||
|
|
||||||
|
# check if function already exists
|
||||||
|
if server.ms.get_tool(request.name):
|
||||||
|
raise ValueError(f"Tool with name {request.name} already exists.")
|
||||||
|
|
||||||
# write function to ~/.memgt/functions directory
|
# write function to ~/.memgt/functions directory
|
||||||
write_function(request.name, request.name, request.source_code)
|
write_function(request.name, request.name, request.source_code)
|
||||||
|
|
||||||
|
|||||||
@@ -697,7 +697,8 @@ class SyncServer(LockingServer):
|
|||||||
# TODO remove (https://github.com/cpacker/MemGPT/issues/1138)
|
# TODO remove (https://github.com/cpacker/MemGPT/issues/1138)
|
||||||
if function_names is not None:
|
if function_names is not None:
|
||||||
preset_override = True
|
preset_override = True
|
||||||
available_tools = self.ms.list_tools(user_id=user_id)
|
# available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
|
||||||
|
available_tools = self.ms.list_tools()
|
||||||
available_tools_names = [t.name for t in available_tools]
|
available_tools_names = [t.name for t in available_tools]
|
||||||
assert all([f_name in available_tools_names for f_name in function_names])
|
assert all([f_name in available_tools_names for f_name in function_names])
|
||||||
preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names]
|
preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names]
|
||||||
@@ -819,7 +820,8 @@ class SyncServer(LockingServer):
|
|||||||
|
|
||||||
# TODO add a get_message_obj_from_message_id(...) function
|
# TODO add a get_message_obj_from_message_id(...) function
|
||||||
# this would allow grabbing Message.created_by without having to load the agent object
|
# this would allow grabbing Message.created_by without having to load the agent object
|
||||||
all_available_tools = self.ms.list_tools(user_id=user_id)
|
# all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
|
||||||
|
all_available_tools = self.ms.list_tools()
|
||||||
|
|
||||||
for agent_state, return_dict in zip(agents_states, agents_states_dicts):
|
for agent_state, return_dict in zip(agents_states, agents_states_dicts):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user