feat: move tool functions to user (#1487)

This commit is contained in:
Sarah Wooders
2024-06-29 16:49:25 -07:00
committed by GitHub
parent 61364742a0
commit 26bfa0b6d8
8 changed files with 217 additions and 67 deletions

View File

@@ -6,7 +6,11 @@ from requests import HTTPError
from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema
from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.admin.tools import (
CreateToolRequest,
ListToolsResponse,
ToolModel,
)
from memgpt.server.rest_api.admin.users import (
CreateAPIKeyResponse,
CreateUserResponse,
@@ -15,7 +19,6 @@ from memgpt.server.rest_api.admin.users import (
GetAllUsersResponse,
GetAPIKeysResponse,
)
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
class Admin:
@@ -86,6 +89,7 @@ class Admin:
self.delete_key(key)
self.delete_user(user["user_id"])
# tools
def create_tool(
self,
func,
@@ -94,12 +98,10 @@ class Admin:
tags: Optional[List[str]] = None,
):
"""Create a tool
Args:
func (callable): The function to create a tool for.
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
update (bool, optional): Update the tool if it already exists. Defaults to True.
Returns:
Tool object
"""
@@ -110,11 +112,11 @@ class Admin:
source_code = parse_source_code(func)
json_schema = generate_schema(func, name)
source_type = "python"
tool_name = json_schema["name"]
json_schema["name"]
# create data
data = {"name": tool_name, "source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
CreateToolRequest(**data) # validate data:w
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
CreateToolRequest(**data) # validate
# make REST request
response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers)

View File

@@ -28,8 +28,6 @@ from memgpt.models.pydantic_models import (
SourceModel,
ToolModel,
)
# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.agents.command import CommandResponse
from memgpt.server.rest_api.agents.config import GetAgentResponse
from memgpt.server.rest_api.agents.index import CreateAgentResponse, ListAgentsResponse
@@ -54,6 +52,9 @@ from memgpt.server.rest_api.presets.index import (
ListPresetsResponse,
)
from memgpt.server.rest_api.sources.index import ListSourcesResponse
# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
from memgpt.server.server import SyncServer
@@ -235,8 +236,6 @@ class RESTClient(AbstractClient):
self.base_url = base_url
self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
# agents
def list_agents(self):
response = requests.get(f"{self.base_url}/api/agents", headers=self.headers)
return ListAgentsResponse(**response.json())
@@ -610,6 +609,67 @@ class RESTClient(AbstractClient):
response = requests.get(f"{self.base_url}/api/config", headers=self.headers)
return ConfigResponse(**response.json())
# tools
def create_tool(
self,
func,
name: Optional[str] = None,
update: Optional[bool] = True, # TODO: actually use this
tags: Optional[List[str]] = None,
):
"""Create a tool
Args:
func (callable): The function to create a tool for.
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
update (bool, optional): Update the tool if it already exists. Defaults to True.
Returns:
Tool object
"""
# TODO: check if tool already exists
# TODO: how to load modules?
# parse source code/schema
source_code = parse_source_code(func)
json_schema = generate_schema(func, name)
source_type = "python"
json_schema["name"]
# create data
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
try:
CreateToolRequest(**data) # validate data
except Exception as e:
raise ValueError(f"Failed to create tool: {e}, invalid input {data}")
# make REST request
response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
return ToolModel(**response.json())
def list_tools(self) -> ListToolsResponse:
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list tools: {response.text}")
return ListToolsResponse(**response.json()).tools
def delete_tool(self, name: str):
response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to delete tool: {response.text}")
return response.json()
def get_tool(self, name: str):
response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
raise ValueError(f"Failed to get tool: {response.text}")
return ToolModel(**response.json())
class LocalClient(AbstractClient):
def __init__(
@@ -820,7 +880,7 @@ class LocalClient(AbstractClient):
tool_name = json_schema["name"]
# check if already exists:
existing_tool = self.server.ms.get_tool(tool_name)
existing_tool = self.server.ms.get_tool(tool_name, self.user_id)
if existing_tool:
if update:
# update existing tool
@@ -829,13 +889,15 @@ class LocalClient(AbstractClient):
existing_tool.tags = tags
existing_tool.json_schema = json_schema
self.server.ms.update_tool(existing_tool)
return self.server.ms.get_tool(tool_name)
return self.server.ms.get_tool(tool_name, self.user_id)
else:
raise ValueError(f"Tool {name} already exists and update=False")
tool = ToolModel(name=tool_name, source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema)
tool = ToolModel(
name=tool_name, source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema, user_id=self.user_id
)
self.server.ms.add_tool(tool)
return self.server.ms.get_tool(tool_name)
return self.server.ms.get_tool(tool_name, self.user_id)
def list_tools(self):
"""List available tools.
@@ -844,7 +906,13 @@ class LocalClient(AbstractClient):
tools (List[ToolModel]): A list of available tools.
"""
return self.server.ms.list_tools()
return self.server.ms.list_tools(user_id=self.user_id)
def get_tool(self, name: str):
return self.server.ms.get_tool(name, user_id=self.user_id)
def delete_tool(self, name: str):
return self.server.ms.delete_tool(name, user_id=self.user_id)
# data sources

View File

@@ -604,9 +604,11 @@ class MetadataStore:
@enforce_types
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self) -> List[ToolModel]:
def list_tools(self, user_id: Optional[uuid.UUID] = None) -> List[ToolModel]:
with self.session_maker() as session:
results = session.query(ToolModel).all()
results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
return results
@enforce_types
@@ -677,10 +679,13 @@ class MetadataStore:
return results[0].to_record()
@enforce_types
def get_tool(self, tool_name: str) -> Optional[ToolModel]:
def get_tool(self, tool_name: str, user_id: Optional[uuid.UUID] = None) -> Optional[ToolModel]:
# TODO: add user_id when tools can eventually be added by users
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.name == tool_name).all()
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
@@ -752,6 +757,8 @@ class MetadataStore:
@enforce_types
def add_tool(self, tool: ToolModel):
with self.session_maker() as session:
if self.get_tool(tool.name, tool.user_id):
raise ValueError(f"Tool with name {tool.name} already exists for user_id {tool.user_id}")
session.add(tool)
session.commit()
@@ -811,9 +818,9 @@ class MetadataStore:
session.commit()
@enforce_types
def delete_tool(self, name: str):
def delete_tool(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.name == name).delete()
session.query(ToolModel).filter(ToolModel.name == name).filter(ToolModel.user_id == user_id).delete()
session.commit()
# job related functions

View File

@@ -56,8 +56,8 @@ class PresetModel(BaseModel):
class ToolModel(SQLModel, table=True):
# TODO move into database
name: str = Field(..., description="The name of the function.", primary_key=True)
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.")
name: str = Field(..., description="The name of the function.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True)
tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")
@@ -65,6 +65,9 @@ class ToolModel(SQLModel, table=True):
json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.")
# optional: user_id (user-specific tools)
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.")
# Needed for Column(JSON)
class Config:
arbitrary_types_allowed = True

View File

@@ -39,7 +39,7 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
# Clear the interface
interface.clear()
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
server.ms.delete_tool(name=tool_name)
server.ms.delete_tool(name=tool_name, user_id=None)
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
async def get_tool(tool_name: str):
@@ -49,29 +49,26 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
# Clear the interface
interface.clear()
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
tool = server.ms.get_tool(tool_name=tool_name)
tool = server.ms.get_tool(tool_name=tool_name, user_id=None)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
return tool
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools(
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
):
async def list_all_tools():
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
tools = server.ms.list_tools()
tools = server.ms.list_tools(user_id=None)
return ListToolsResponse(tools=tools)
@router.post("/tools", tags=["tools"], response_model=ToolModel)
async def create_tool(
request: CreateToolRequest = Body(...),
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
):
"""
Create a new tool

View File

@@ -2,7 +2,7 @@ import uuid
from functools import partial
from typing import List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import ToolModel
@@ -18,7 +18,7 @@ class ListToolsResponse(BaseModel):
class CreateToolRequest(BaseModel):
name: str = Field(..., description="The name of the function.")
json_schema: dict = Field(..., description="JSON schema of the tool.")
source_code: str = Field(..., description="The source code of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
@@ -31,18 +31,30 @@ class CreateToolResponse(BaseModel):
def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.delete("/tools/{tool_name}", tags=["tools"])
async def delete_tool(
tool_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete a tool by name
"""
# Clear the interface
interface.clear()
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
server.ms.delete_tool(name=tool_name, user_id=user_id)
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
async def get_tool(
tool_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a tool by name
"""
# Clear the interface
interface.clear()
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
tool = server.ms.get_tool(tool_name=tool_name)
tool = server.ms.get_tool(tool_name=tool_name, user_id=user_id)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
@@ -50,15 +62,33 @@ def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterfac
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools(
user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
tools = server.ms.list_tools()
tools = server.ms.list_tools(user_id=user_id)
return ListToolsResponse(tools=tools)
@router.post("/tools", tags=["tools"], response_model=ToolModel)
async def create_tool(
request: CreateToolRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Create a new tool
"""
try:
return server.create_tool(
json_schema=request.json_schema,
source_code=request.source_code,
source_type=request.source_type,
tags=request.tags,
user_id=user_id,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
return router

View File

@@ -336,7 +336,7 @@ class SyncServer(LockingServer):
# Instantiate an agent object using the state retrieved
logger.info(f"Creating an agent object")
tool_objs = [self.ms.get_tool(name) for name in agent_state.tools] # get tool objects
tool_objs = [self.ms.get_tool(name, user_id) for name in agent_state.tools] # get tool objects
memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
# Add the agent to the in-memory store and return its reference
@@ -763,7 +763,7 @@ class SyncServer(LockingServer):
# get tools
tool_objs = []
for tool_name in tools:
tool_obj = self.ms.get_tool(tool_name)
tool_obj = self.ms.get_tool(tool_name, user_id=user_id)
assert tool_obj is not None, f"Tool {tool_name} does not exist"
tool_objs.append(tool_obj)
@@ -1487,7 +1487,13 @@ class SyncServer(LockingServer):
return sources_with_metadata
def create_tool(
self, json_schema: dict, source_code: str, source_type: str, tags: Optional[List[str]] = None, exists_ok: Optional[bool] = True
self,
json_schema: dict,
source_code: str,
source_type: str,
tags: Optional[List[str]] = None,
exists_ok: Optional[bool] = True,
user_id: Optional[uuid.UUID] = None,
) -> ToolModel: # TODO: add other fields
"""Create a new tool
@@ -1511,10 +1517,12 @@ class SyncServer(LockingServer):
raise ValueError(f"Tool with name {name} already exists.")
else:
# create new tool
tool = ToolModel(name=name, json_schema=json_schema, tags=tags, source_code=source_code, source_type=source_type)
tool = ToolModel(
name=name, json_schema=json_schema, tags=tags, source_code=source_code, source_type=source_type, user_id=user_id
)
self.ms.add_tool(tool)
return self.ms.get_tool(name)
return self.ms.get_tool(name, user_id=user_id)
def delete_tool(self, name: str):
"""Delete a tool"""

View File

@@ -68,11 +68,12 @@ def run_server():
# Fixture to create clients with different configurations
@pytest.fixture(
# params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented
params=[{"server": True}], # whether to use REST API server # TODO: add when implemented
params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented
# params=[{"server": True}], # whether to use REST API server # TODO: add when implemented
scope="module",
)
def client(request):
def admin_client(request):
if request.param["server"]:
# get URL from enviornment
server_url = os.getenv("MEMGPT_SERVER_URL")
@@ -92,11 +93,19 @@ def client(request):
admin._reset_server()
else:
print("Testing local client")
# use local client (no server)
token = None
server_url = None
client = create_client(base_url=server_url, token=token) # This yields control back to the test function
yield None
@pytest.fixture(scope="module")
def client(admin_client):
if admin_client:
# create user via admin client
response = admin_client.create_user()
print("Created user", response.user_id, response.api_key)
client = create_client(base_url=admin_client.base_url, token=response.api_key)
yield client
else:
client = create_client()
yield client
@@ -124,6 +133,39 @@ def test_create_tool(client):
assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}"
print(f"Updated tools {[t.name for t in tools]}")
# check tool id
tool = client.get_tool(tool.name)
def test_create_agent_tool_admin(admin_client):
if admin_client is None:
return
def print_tool(message: str):
"""
Args:
message (str): The message to print.
Returns:
str: The message that was printed.
"""
print(message)
return message
tools = admin_client.list_tools()
print(f"Original tools {[t.name for t in tools]}")
tool = admin_client.create_tool(print_tool, tags=["extras"])
tools = admin_client.list_tools()
assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}"
print(f"Updated tools {[t.name for t in tools]}")
# check tool id
tool = admin_client.get_tool(tool.name)
assert tool.user_id is None, f"Expected {tool.user_id} to be None"
def test_create_agent_tool(client):
"""Test creation of a agent tool"""
@@ -144,22 +186,15 @@ def test_create_agent_tool(client):
return None
# TODO: test attaching and using function on agent
tool = client.create_tool(core_memory_clear, tags=["extras"])
tool = client.create_tool(core_memory_clear, tags=["extras"], update=True)
print(f"Created tool", tool.name)
if isinstance(client, Admin):
# conver to user client type
response = client.create_user()
print("Created user", response.user_id, response.api_key)
user_client = create_client(base_url=client.base_url, token=response.api_key)
else:
user_client = client
agent = user_client.create_agent(
name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you"
)
# create agent with tool
agent = client.create_agent(name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you")
assert str(tool.user_id) == str(agent.user_id), f"Expected {tool.user_id} to be {agent.user_id}"
# initial memory
initial_memory = user_client.get_agent_memory(agent.id)
initial_memory = client.get_agent_memory(agent.id)
print("initial memory", initial_memory)
human = initial_memory.core_memory.human
persona = initial_memory.core_memory.persona
@@ -168,11 +203,11 @@ def test_create_agent_tool(client):
assert len(persona) > 0, "Expected persona memory to be non-empty"
# test agent tool
response = user_client.send_message(role="user", agent_id=agent.id, message="clear your memory with the core_memory_clear tool")
response = client.send_message(role="user", agent_id=agent.id, message="clear your memory with the core_memory_clear tool")
print(response)
# updated memory
updated_memory = user_client.get_agent_memory(agent.id)
updated_memory = client.get_agent_memory(agent.id)
human = updated_memory.core_memory.human
persona = updated_memory.core_memory.persona
print("Updated memory:", human, persona)