feat: move tool functions to user (#1487)
This commit is contained in:
@@ -6,7 +6,11 @@ from requests import HTTPError
|
||||
|
||||
from memgpt.functions.functions import parse_source_code
|
||||
from memgpt.functions.schema_generator import generate_schema
|
||||
from memgpt.models.pydantic_models import ToolModel
|
||||
from memgpt.server.rest_api.admin.tools import (
|
||||
CreateToolRequest,
|
||||
ListToolsResponse,
|
||||
ToolModel,
|
||||
)
|
||||
from memgpt.server.rest_api.admin.users import (
|
||||
CreateAPIKeyResponse,
|
||||
CreateUserResponse,
|
||||
@@ -15,7 +19,6 @@ from memgpt.server.rest_api.admin.users import (
|
||||
GetAllUsersResponse,
|
||||
GetAPIKeysResponse,
|
||||
)
|
||||
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
|
||||
|
||||
|
||||
class Admin:
|
||||
@@ -86,6 +89,7 @@ class Admin:
|
||||
self.delete_key(key)
|
||||
self.delete_user(user["user_id"])
|
||||
|
||||
# tools
|
||||
def create_tool(
|
||||
self,
|
||||
func,
|
||||
@@ -94,12 +98,10 @@ class Admin:
|
||||
tags: Optional[List[str]] = None,
|
||||
):
|
||||
"""Create a tool
|
||||
|
||||
Args:
|
||||
func (callable): The function to create a tool for.
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
update (bool, optional): Update the tool if it already exists. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Tool object
|
||||
"""
|
||||
@@ -110,11 +112,11 @@ class Admin:
|
||||
source_code = parse_source_code(func)
|
||||
json_schema = generate_schema(func, name)
|
||||
source_type = "python"
|
||||
tool_name = json_schema["name"]
|
||||
json_schema["name"]
|
||||
|
||||
# create data
|
||||
data = {"name": tool_name, "source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
|
||||
CreateToolRequest(**data) # validate data:w
|
||||
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
|
||||
CreateToolRequest(**data) # validate
|
||||
|
||||
# make REST request
|
||||
response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers)
|
||||
|
||||
@@ -28,8 +28,6 @@ from memgpt.models.pydantic_models import (
|
||||
SourceModel,
|
||||
ToolModel,
|
||||
)
|
||||
|
||||
# import pydantic response objects from memgpt.server.rest_api
|
||||
from memgpt.server.rest_api.agents.command import CommandResponse
|
||||
from memgpt.server.rest_api.agents.config import GetAgentResponse
|
||||
from memgpt.server.rest_api.agents.index import CreateAgentResponse, ListAgentsResponse
|
||||
@@ -54,6 +52,9 @@ from memgpt.server.rest_api.presets.index import (
|
||||
ListPresetsResponse,
|
||||
)
|
||||
from memgpt.server.rest_api.sources.index import ListSourcesResponse
|
||||
|
||||
# import pydantic response objects from memgpt.server.rest_api
|
||||
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
|
||||
@@ -235,8 +236,6 @@ class RESTClient(AbstractClient):
|
||||
self.base_url = base_url
|
||||
self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"}
|
||||
|
||||
# agents
|
||||
|
||||
def list_agents(self):
|
||||
response = requests.get(f"{self.base_url}/api/agents", headers=self.headers)
|
||||
return ListAgentsResponse(**response.json())
|
||||
@@ -610,6 +609,67 @@ class RESTClient(AbstractClient):
|
||||
response = requests.get(f"{self.base_url}/api/config", headers=self.headers)
|
||||
return ConfigResponse(**response.json())
|
||||
|
||||
# tools
|
||||
|
||||
def create_tool(
|
||||
self,
|
||||
func,
|
||||
name: Optional[str] = None,
|
||||
update: Optional[bool] = True, # TODO: actually use this
|
||||
tags: Optional[List[str]] = None,
|
||||
):
|
||||
"""Create a tool
|
||||
|
||||
Args:
|
||||
func (callable): The function to create a tool for.
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
update (bool, optional): Update the tool if it already exists. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Tool object
|
||||
"""
|
||||
|
||||
# TODO: check if tool already exists
|
||||
# TODO: how to load modules?
|
||||
# parse source code/schema
|
||||
source_code = parse_source_code(func)
|
||||
json_schema = generate_schema(func, name)
|
||||
source_type = "python"
|
||||
json_schema["name"]
|
||||
|
||||
# create data
|
||||
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
|
||||
try:
|
||||
CreateToolRequest(**data) # validate data
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create tool: {e}, invalid input {data}")
|
||||
|
||||
# make REST request
|
||||
response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
return ToolModel(**response.json())
|
||||
|
||||
def list_tools(self) -> ListToolsResponse:
|
||||
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list tools: {response.text}")
|
||||
return ListToolsResponse(**response.json()).tools
|
||||
|
||||
def delete_tool(self, name: str):
|
||||
response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to delete tool: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def get_tool(self, name: str):
|
||||
response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to get tool: {response.text}")
|
||||
return ToolModel(**response.json())
|
||||
|
||||
|
||||
class LocalClient(AbstractClient):
|
||||
def __init__(
|
||||
@@ -820,7 +880,7 @@ class LocalClient(AbstractClient):
|
||||
tool_name = json_schema["name"]
|
||||
|
||||
# check if already exists:
|
||||
existing_tool = self.server.ms.get_tool(tool_name)
|
||||
existing_tool = self.server.ms.get_tool(tool_name, self.user_id)
|
||||
if existing_tool:
|
||||
if update:
|
||||
# update existing tool
|
||||
@@ -829,13 +889,15 @@ class LocalClient(AbstractClient):
|
||||
existing_tool.tags = tags
|
||||
existing_tool.json_schema = json_schema
|
||||
self.server.ms.update_tool(existing_tool)
|
||||
return self.server.ms.get_tool(tool_name)
|
||||
return self.server.ms.get_tool(tool_name, self.user_id)
|
||||
else:
|
||||
raise ValueError(f"Tool {name} already exists and update=False")
|
||||
|
||||
tool = ToolModel(name=tool_name, source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema)
|
||||
tool = ToolModel(
|
||||
name=tool_name, source_code=source_code, source_type=source_type, tags=tags, json_schema=json_schema, user_id=self.user_id
|
||||
)
|
||||
self.server.ms.add_tool(tool)
|
||||
return self.server.ms.get_tool(tool_name)
|
||||
return self.server.ms.get_tool(tool_name, self.user_id)
|
||||
|
||||
def list_tools(self):
|
||||
"""List available tools.
|
||||
@@ -844,7 +906,13 @@ class LocalClient(AbstractClient):
|
||||
tools (List[ToolModel]): A list of available tools.
|
||||
|
||||
"""
|
||||
return self.server.ms.list_tools()
|
||||
return self.server.ms.list_tools(user_id=self.user_id)
|
||||
|
||||
def get_tool(self, name: str):
|
||||
return self.server.ms.get_tool(name, user_id=self.user_id)
|
||||
|
||||
def delete_tool(self, name: str):
|
||||
return self.server.ms.delete_tool(name, user_id=self.user_id)
|
||||
|
||||
# data sources
|
||||
|
||||
|
||||
@@ -604,9 +604,11 @@ class MetadataStore:
|
||||
|
||||
@enforce_types
|
||||
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools
|
||||
def list_tools(self) -> List[ToolModel]:
|
||||
def list_tools(self, user_id: Optional[uuid.UUID] = None) -> List[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(ToolModel).all()
|
||||
results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
|
||||
if user_id:
|
||||
results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
@@ -677,10 +679,13 @@ class MetadataStore:
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_tool(self, tool_name: str) -> Optional[ToolModel]:
|
||||
def get_tool(self, tool_name: str, user_id: Optional[uuid.UUID] = None) -> Optional[ToolModel]:
|
||||
# TODO: add user_id when tools can eventually be added by users
|
||||
with self.session_maker() as session:
|
||||
results = session.query(ToolModel).filter(ToolModel.name == tool_name).all()
|
||||
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
|
||||
if user_id:
|
||||
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
|
||||
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
@@ -752,6 +757,8 @@ class MetadataStore:
|
||||
@enforce_types
|
||||
def add_tool(self, tool: ToolModel):
|
||||
with self.session_maker() as session:
|
||||
if self.get_tool(tool.name, tool.user_id):
|
||||
raise ValueError(f"Tool with name {tool.name} already exists for user_id {tool.user_id}")
|
||||
session.add(tool)
|
||||
session.commit()
|
||||
|
||||
@@ -811,9 +818,9 @@ class MetadataStore:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_tool(self, name: str):
|
||||
def delete_tool(self, name: str, user_id: uuid.UUID):
|
||||
with self.session_maker() as session:
|
||||
session.query(ToolModel).filter(ToolModel.name == name).delete()
|
||||
session.query(ToolModel).filter(ToolModel.name == name).filter(ToolModel.user_id == user_id).delete()
|
||||
session.commit()
|
||||
|
||||
# job related functions
|
||||
|
||||
@@ -56,8 +56,8 @@ class PresetModel(BaseModel):
|
||||
|
||||
class ToolModel(SQLModel, table=True):
|
||||
# TODO move into database
|
||||
name: str = Field(..., description="The name of the function.", primary_key=True)
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.")
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True)
|
||||
tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
source_code: Optional[str] = Field(..., description="The source code of the function.")
|
||||
@@ -65,6 +65,9 @@ class ToolModel(SQLModel, table=True):
|
||||
|
||||
json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.")
|
||||
|
||||
# optional: user_id (user-specific tools)
|
||||
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.")
|
||||
|
||||
# Needed for Column(JSON)
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -39,7 +39,7 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
|
||||
server.ms.delete_tool(name=tool_name)
|
||||
server.ms.delete_tool(name=tool_name, user_id=None)
|
||||
|
||||
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
|
||||
async def get_tool(tool_name: str):
|
||||
@@ -49,29 +49,26 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
|
||||
tool = server.ms.get_tool(tool_name=tool_name)
|
||||
tool = server.ms.get_tool(tool_name=tool_name, user_id=None)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
|
||||
return tool
|
||||
|
||||
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
|
||||
async def list_all_tools(
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
):
|
||||
async def list_all_tools():
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
|
||||
tools = server.ms.list_tools()
|
||||
tools = server.ms.list_tools(user_id=None)
|
||||
return ListToolsResponse(tools=tools)
|
||||
|
||||
@router.post("/tools", tags=["tools"], response_model=ToolModel)
|
||||
async def create_tool(
|
||||
request: CreateToolRequest = Body(...),
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
):
|
||||
"""
|
||||
Create a new tool
|
||||
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
from functools import partial
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.models.pydantic_models import ToolModel
|
||||
@@ -18,7 +18,7 @@ class ListToolsResponse(BaseModel):
|
||||
|
||||
|
||||
class CreateToolRequest(BaseModel):
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
json_schema: dict = Field(..., description="JSON schema of the tool.")
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
|
||||
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
|
||||
@@ -31,18 +31,30 @@ class CreateToolResponse(BaseModel):
|
||||
def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.delete("/tools/{tool_name}", tags=["tools"])
|
||||
async def delete_tool(
|
||||
tool_name: str,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
|
||||
server.ms.delete_tool(name=tool_name, user_id=user_id)
|
||||
|
||||
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
|
||||
async def get_tool(
|
||||
tool_name: str,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Get a tool by name
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
|
||||
tool = server.ms.get_tool(tool_name=tool_name)
|
||||
tool = server.ms.get_tool(tool_name=tool_name, user_id=user_id)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
|
||||
@@ -50,15 +62,33 @@ def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterfac
|
||||
|
||||
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
|
||||
async def list_all_tools(
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
|
||||
tools = server.ms.list_tools()
|
||||
tools = server.ms.list_tools(user_id=user_id)
|
||||
return ListToolsResponse(tools=tools)
|
||||
|
||||
@router.post("/tools", tags=["tools"], response_model=ToolModel)
|
||||
async def create_tool(
|
||||
request: CreateToolRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
try:
|
||||
return server.create_tool(
|
||||
json_schema=request.json_schema,
|
||||
source_code=request.source_code,
|
||||
source_type=request.source_type,
|
||||
tags=request.tags,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
|
||||
|
||||
return router
|
||||
|
||||
@@ -336,7 +336,7 @@ class SyncServer(LockingServer):
|
||||
|
||||
# Instantiate an agent object using the state retrieved
|
||||
logger.info(f"Creating an agent object")
|
||||
tool_objs = [self.ms.get_tool(name) for name in agent_state.tools] # get tool objects
|
||||
tool_objs = [self.ms.get_tool(name, user_id) for name in agent_state.tools] # get tool objects
|
||||
memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
|
||||
|
||||
# Add the agent to the in-memory store and return its reference
|
||||
@@ -763,7 +763,7 @@ class SyncServer(LockingServer):
|
||||
# get tools
|
||||
tool_objs = []
|
||||
for tool_name in tools:
|
||||
tool_obj = self.ms.get_tool(tool_name)
|
||||
tool_obj = self.ms.get_tool(tool_name, user_id=user_id)
|
||||
assert tool_obj is not None, f"Tool {tool_name} does not exist"
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
@@ -1487,7 +1487,13 @@ class SyncServer(LockingServer):
|
||||
return sources_with_metadata
|
||||
|
||||
def create_tool(
|
||||
self, json_schema: dict, source_code: str, source_type: str, tags: Optional[List[str]] = None, exists_ok: Optional[bool] = True
|
||||
self,
|
||||
json_schema: dict,
|
||||
source_code: str,
|
||||
source_type: str,
|
||||
tags: Optional[List[str]] = None,
|
||||
exists_ok: Optional[bool] = True,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
) -> ToolModel: # TODO: add other fields
|
||||
"""Create a new tool
|
||||
|
||||
@@ -1511,10 +1517,12 @@ class SyncServer(LockingServer):
|
||||
raise ValueError(f"Tool with name {name} already exists.")
|
||||
else:
|
||||
# create new tool
|
||||
tool = ToolModel(name=name, json_schema=json_schema, tags=tags, source_code=source_code, source_type=source_type)
|
||||
tool = ToolModel(
|
||||
name=name, json_schema=json_schema, tags=tags, source_code=source_code, source_type=source_type, user_id=user_id
|
||||
)
|
||||
self.ms.add_tool(tool)
|
||||
|
||||
return self.ms.get_tool(name)
|
||||
return self.ms.get_tool(name, user_id=user_id)
|
||||
|
||||
def delete_tool(self, name: str):
|
||||
"""Delete a tool"""
|
||||
|
||||
@@ -68,11 +68,12 @@ def run_server():
|
||||
|
||||
# Fixture to create clients with different configurations
|
||||
@pytest.fixture(
|
||||
# params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented
|
||||
params=[{"server": True}], # whether to use REST API server # TODO: add when implemented
|
||||
params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented
|
||||
# params=[{"server": True}], # whether to use REST API server # TODO: add when implemented
|
||||
scope="module",
|
||||
)
|
||||
def client(request):
|
||||
def admin_client(request):
|
||||
|
||||
if request.param["server"]:
|
||||
# get URL from enviornment
|
||||
server_url = os.getenv("MEMGPT_SERVER_URL")
|
||||
@@ -92,11 +93,19 @@ def client(request):
|
||||
|
||||
admin._reset_server()
|
||||
else:
|
||||
print("Testing local client")
|
||||
# use local client (no server)
|
||||
token = None
|
||||
server_url = None
|
||||
client = create_client(base_url=server_url, token=token) # This yields control back to the test function
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(admin_client):
|
||||
if admin_client:
|
||||
# create user via admin client
|
||||
response = admin_client.create_user()
|
||||
print("Created user", response.user_id, response.api_key)
|
||||
client = create_client(base_url=admin_client.base_url, token=response.api_key)
|
||||
yield client
|
||||
else:
|
||||
client = create_client()
|
||||
yield client
|
||||
|
||||
|
||||
@@ -124,6 +133,39 @@ def test_create_tool(client):
|
||||
assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}"
|
||||
print(f"Updated tools {[t.name for t in tools]}")
|
||||
|
||||
# check tool id
|
||||
tool = client.get_tool(tool.name)
|
||||
|
||||
|
||||
def test_create_agent_tool_admin(admin_client):
|
||||
if admin_client is None:
|
||||
return
|
||||
|
||||
def print_tool(message: str):
|
||||
"""
|
||||
Args:
|
||||
message (str): The message to print.
|
||||
|
||||
Returns:
|
||||
str: The message that was printed.
|
||||
|
||||
"""
|
||||
print(message)
|
||||
return message
|
||||
|
||||
tools = admin_client.list_tools()
|
||||
print(f"Original tools {[t.name for t in tools]}")
|
||||
|
||||
tool = admin_client.create_tool(print_tool, tags=["extras"])
|
||||
|
||||
tools = admin_client.list_tools()
|
||||
assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}"
|
||||
print(f"Updated tools {[t.name for t in tools]}")
|
||||
|
||||
# check tool id
|
||||
tool = admin_client.get_tool(tool.name)
|
||||
assert tool.user_id is None, f"Expected {tool.user_id} to be None"
|
||||
|
||||
|
||||
def test_create_agent_tool(client):
|
||||
"""Test creation of a agent tool"""
|
||||
@@ -144,22 +186,15 @@ def test_create_agent_tool(client):
|
||||
return None
|
||||
|
||||
# TODO: test attaching and using function on agent
|
||||
tool = client.create_tool(core_memory_clear, tags=["extras"])
|
||||
tool = client.create_tool(core_memory_clear, tags=["extras"], update=True)
|
||||
print(f"Created tool", tool.name)
|
||||
|
||||
if isinstance(client, Admin):
|
||||
# conver to user client type
|
||||
response = client.create_user()
|
||||
print("Created user", response.user_id, response.api_key)
|
||||
user_client = create_client(base_url=client.base_url, token=response.api_key)
|
||||
else:
|
||||
user_client = client
|
||||
|
||||
agent = user_client.create_agent(
|
||||
name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you"
|
||||
)
|
||||
# create agent with tool
|
||||
agent = client.create_agent(name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you")
|
||||
assert str(tool.user_id) == str(agent.user_id), f"Expected {tool.user_id} to be {agent.user_id}"
|
||||
|
||||
# initial memory
|
||||
initial_memory = user_client.get_agent_memory(agent.id)
|
||||
initial_memory = client.get_agent_memory(agent.id)
|
||||
print("initial memory", initial_memory)
|
||||
human = initial_memory.core_memory.human
|
||||
persona = initial_memory.core_memory.persona
|
||||
@@ -168,11 +203,11 @@ def test_create_agent_tool(client):
|
||||
assert len(persona) > 0, "Expected persona memory to be non-empty"
|
||||
|
||||
# test agent tool
|
||||
response = user_client.send_message(role="user", agent_id=agent.id, message="clear your memory with the core_memory_clear tool")
|
||||
response = client.send_message(role="user", agent_id=agent.id, message="clear your memory with the core_memory_clear tool")
|
||||
print(response)
|
||||
|
||||
# updated memory
|
||||
updated_memory = user_client.get_agent_memory(agent.id)
|
||||
updated_memory = client.get_agent_memory(agent.id)
|
||||
human = updated_memory.core_memory.human
|
||||
persona = updated_memory.core_memory.persona
|
||||
print("Updated memory:", human, persona)
|
||||
|
||||
Reference in New Issue
Block a user