feat: add_function and remove_function commands (#784)

This commit is contained in:
Charles Packer
2024-01-10 23:19:14 -08:00
committed by GitHub
4 changed files with 209 additions and 10 deletions

View File

@@ -1,8 +1,10 @@
import datetime
import uuid
import glob
import inspect
import os
import json
from pathlib import Path
import traceback
from memgpt.data_types import AgentState
@@ -34,7 +36,7 @@ from memgpt.constants import (
CLI_WARNING_PREFIX,
)
from .errors import LLMError
from .functions.functions import load_all_function_sets
from .functions.functions import USER_FUNCTIONS_DIR, load_all_function_sets
def link_functions(function_schemas):
@@ -679,6 +681,46 @@ class Agent(object):
return agent_state
def add_function(self, function_name: str) -> str:
if function_name in self.functions_python.keys():
msg = f"Function {function_name} already loaded"
printd(msg)
return msg
available_functions = load_all_function_sets()
if function_name not in available_functions.keys():
raise ValueError(f"Function {function_name} not found in function library")
self.functions.append(available_functions[function_name]["json_schema"])
self.functions_python[function_name] = available_functions[function_name]["python_function"]
msg = f"Added function {function_name}"
self.save()
printd(msg)
return msg
def remove_function(self, function_name: str) -> str:
if function_name not in self.functions_python.keys():
msg = f"Function {function_name} not loaded, ignoring"
printd(msg)
return msg
# only allow removal of user defined functions
user_func_path = Path(USER_FUNCTIONS_DIR)
func_path = Path(inspect.getfile(self.functions_python[function_name]))
is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts
if not is_subpath:
raise ValueError(f"Function {function_name} is not user defined and cannot be removed")
self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name]
self.functions_python.pop(function_name)
msg = f"Removed function {function_name}"
self.save()
printd(msg)
return msg
def save(self):
"""Save agent state locally"""

View File

@@ -7,7 +7,9 @@ import sys
from memgpt.functions.schema_generator import generate_schema
from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX
sys.path.append(os.path.join(MEMGPT_DIR, "functions"))
USER_FUNCTIONS_DIR = os.path.join(MEMGPT_DIR, "functions")
sys.path.append(USER_FUNCTIONS_DIR)
def load_function_set(module):
@@ -42,24 +44,23 @@ def load_all_function_sets(merge=True):
example_module_files = [f for f in os.listdir(function_sets_dir) if f.endswith(".py") and f != "__init__.py"]
# ~/.memgpt/functions/*.py
user_scripts_dir = os.path.join(MEMGPT_DIR, "functions")
# create if missing
if not os.path.exists(user_scripts_dir):
os.makedirs(user_scripts_dir)
user_module_files = [f for f in os.listdir(user_scripts_dir) if f.endswith(".py") and f != "__init__.py"]
if not os.path.exists(USER_FUNCTIONS_DIR):
os.makedirs(USER_FUNCTIONS_DIR)
user_module_files = [f for f in os.listdir(USER_FUNCTIONS_DIR) if f.endswith(".py") and f != "__init__.py"]
# combine them both (pull from both examples and user-provided)
# all_module_files = example_module_files + user_module_files
# Add user_scripts_dir to sys.path
if user_scripts_dir not in sys.path:
sys.path.append(user_scripts_dir)
if USER_FUNCTIONS_DIR not in sys.path:
sys.path.append(USER_FUNCTIONS_DIR)
schemas_and_functions = {}
for dir_path, module_files in [(function_sets_dir, example_module_files), (user_scripts_dir, user_module_files)]:
for dir_path, module_files in [(function_sets_dir, example_module_files), (USER_FUNCTIONS_DIR, user_module_files)]:
for file in module_files:
module_name = file[:-3] # Remove '.py' from filename
if dir_path == user_scripts_dir:
if dir_path == USER_FUNCTIONS_DIR:
# For user scripts, adjust the module name appropriately
module_full_path = os.path.join(dir_path, file)
try:

View File

@@ -220,6 +220,45 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, no_verify=False, c
)
continue
elif user_input.lower().startswith("/add_function"):
try:
if len(user_input) < len("/add_function "):
print("Missing function name after the command")
continue
function_name = user_input[len("/add_function ") :].strip()
result = memgpt_agent.add_function(function_name)
typer.secho(
f"/add_function succeeded: {result}",
fg=typer.colors.GREEN,
bold=True,
)
except ValueError as e:
typer.secho(
f"/add_function failed:\n{e}",
fg=typer.colors.RED,
bold=True,
)
continue
elif user_input.lower().startswith("/remove_function"):
try:
if len(user_input) < len("/remove_function "):
print("Missing function name after the command")
continue
function_name = user_input[len("/remove_function ") :].strip()
result = memgpt_agent.remove_function(function_name)
typer.secho(
f"/remove_function succeeded: {result}",
fg=typer.colors.GREEN,
bold=True,
)
except ValueError as e:
typer.secho(
f"/remove_function failed:\n{e}",
fg=typer.colors.RED,
bold=True,
)
continue
# No skip options
elif user_input.lower() == "/wipe":
memgpt_agent = agent.Agent(interface)

View File

@@ -0,0 +1,117 @@
from collections import UserDict
import json
import os
import inspect
from memgpt import MemGPT
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from memgpt.functions.functions import USER_FUNCTIONS_DIR
from tests.utils import wipe_config
import pytest
def hello_world(self) -> str:
"""Test function for agent to gain access to
Returns:
str: A message for the world
"""
return "hello, world!"
@pytest.fixture(scope="module")
def agent():
"""Create a test agent that we can call functions on"""
wipe_config()
global client
if os.getenv("OPENAI_API_KEY"):
client = MemGPT(quickstart="openai")
else:
client = MemGPT(quickstart="memgpt_hosted")
agent_state = client.create_agent(
agent_config={
# "name": test_agent_id,
"persona": constants.DEFAULT_PERSONA,
"human": constants.DEFAULT_HUMAN,
}
)
return client.server._get_or_load_agent(user_id="NULL", agent_id=agent_state.id)
@pytest.fixture(scope="module")
def hello_world_function():
with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w") as f:
f.write(inspect.getsource(hello_world))
@pytest.fixture(scope="module")
def ai_function_call():
class AiFunctionCall(UserDict):
def content(self):
return self.data["content"]
return AiFunctionCall(
{
"content": "I will now call hello world",
"function_call": {
"name": "hello_world",
"arguments": json.dumps({}),
},
}
)
def test_add_function_happy(agent, hello_world_function, ai_function_call):
agent.add_function("hello_world")
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
assert "hello_world" in agent.functions_python.keys()
msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call)
content = json.loads(msgs[-1]["content"])
assert content["message"] == "hello, world!"
assert content["status"] == "OK"
assert not function_failed
def test_add_function_already_loaded(agent, hello_world_function):
agent.add_function("hello_world")
# no exception for duplicate loading
agent.add_function("hello_world")
def test_add_function_not_exist(agent):
# pytest assert exception
with pytest.raises(ValueError):
agent.add_function("non_existent")
def test_remove_function_happy(agent, hello_world_function):
agent.add_function("hello_world")
# ensure function is loaded
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
assert "hello_world" in agent.functions_python.keys()
agent.remove_function("hello_world")
assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions]
assert "hello_world" not in agent.functions_python.keys()
def test_remove_function_not_exist(agent):
# do not raise error
agent.remove_function("non_existent")
def test_remove_base_function_fails(agent):
with pytest.raises(ValueError):
agent.remove_function("send_message")
if __name__ == "__main__":
pytest.main(["-vv", os.path.abspath(__file__)])