feat: add_function and remove_function commands (#784)
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import datetime
|
||||
import uuid
|
||||
import glob
|
||||
import inspect
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import traceback
|
||||
|
||||
from memgpt.data_types import AgentState
|
||||
@@ -34,7 +36,7 @@ from memgpt.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
)
|
||||
from .errors import LLMError
|
||||
from .functions.functions import load_all_function_sets
|
||||
from .functions.functions import USER_FUNCTIONS_DIR, load_all_function_sets
|
||||
|
||||
|
||||
def link_functions(function_schemas):
|
||||
@@ -679,6 +681,46 @@ class Agent(object):
|
||||
|
||||
return agent_state
|
||||
|
||||
def add_function(self, function_name: str) -> str:
|
||||
if function_name in self.functions_python.keys():
|
||||
msg = f"Function {function_name} already loaded"
|
||||
printd(msg)
|
||||
return msg
|
||||
|
||||
available_functions = load_all_function_sets()
|
||||
if function_name not in available_functions.keys():
|
||||
raise ValueError(f"Function {function_name} not found in function library")
|
||||
|
||||
self.functions.append(available_functions[function_name]["json_schema"])
|
||||
self.functions_python[function_name] = available_functions[function_name]["python_function"]
|
||||
|
||||
msg = f"Added function {function_name}"
|
||||
self.save()
|
||||
printd(msg)
|
||||
return msg
|
||||
|
||||
def remove_function(self, function_name: str) -> str:
|
||||
if function_name not in self.functions_python.keys():
|
||||
msg = f"Function {function_name} not loaded, ignoring"
|
||||
printd(msg)
|
||||
return msg
|
||||
|
||||
# only allow removal of user defined functions
|
||||
user_func_path = Path(USER_FUNCTIONS_DIR)
|
||||
func_path = Path(inspect.getfile(self.functions_python[function_name]))
|
||||
is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts
|
||||
|
||||
if not is_subpath:
|
||||
raise ValueError(f"Function {function_name} is not user defined and cannot be removed")
|
||||
|
||||
self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name]
|
||||
self.functions_python.pop(function_name)
|
||||
|
||||
msg = f"Removed function {function_name}"
|
||||
self.save()
|
||||
printd(msg)
|
||||
return msg
|
||||
|
||||
def save(self):
|
||||
"""Save agent state locally"""
|
||||
|
||||
|
||||
@@ -7,7 +7,9 @@ import sys
|
||||
from memgpt.functions.schema_generator import generate_schema
|
||||
from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX
|
||||
|
||||
sys.path.append(os.path.join(MEMGPT_DIR, "functions"))
|
||||
USER_FUNCTIONS_DIR = os.path.join(MEMGPT_DIR, "functions")
|
||||
|
||||
sys.path.append(USER_FUNCTIONS_DIR)
|
||||
|
||||
|
||||
def load_function_set(module):
|
||||
@@ -42,24 +44,23 @@ def load_all_function_sets(merge=True):
|
||||
example_module_files = [f for f in os.listdir(function_sets_dir) if f.endswith(".py") and f != "__init__.py"]
|
||||
|
||||
# ~/.memgpt/functions/*.py
|
||||
user_scripts_dir = os.path.join(MEMGPT_DIR, "functions")
|
||||
# create if missing
|
||||
if not os.path.exists(user_scripts_dir):
|
||||
os.makedirs(user_scripts_dir)
|
||||
user_module_files = [f for f in os.listdir(user_scripts_dir) if f.endswith(".py") and f != "__init__.py"]
|
||||
if not os.path.exists(USER_FUNCTIONS_DIR):
|
||||
os.makedirs(USER_FUNCTIONS_DIR)
|
||||
user_module_files = [f for f in os.listdir(USER_FUNCTIONS_DIR) if f.endswith(".py") and f != "__init__.py"]
|
||||
|
||||
# combine them both (pull from both examples and user-provided)
|
||||
# all_module_files = example_module_files + user_module_files
|
||||
|
||||
# Add user_scripts_dir to sys.path
|
||||
if user_scripts_dir not in sys.path:
|
||||
sys.path.append(user_scripts_dir)
|
||||
if USER_FUNCTIONS_DIR not in sys.path:
|
||||
sys.path.append(USER_FUNCTIONS_DIR)
|
||||
|
||||
schemas_and_functions = {}
|
||||
for dir_path, module_files in [(function_sets_dir, example_module_files), (user_scripts_dir, user_module_files)]:
|
||||
for dir_path, module_files in [(function_sets_dir, example_module_files), (USER_FUNCTIONS_DIR, user_module_files)]:
|
||||
for file in module_files:
|
||||
module_name = file[:-3] # Remove '.py' from filename
|
||||
if dir_path == user_scripts_dir:
|
||||
if dir_path == USER_FUNCTIONS_DIR:
|
||||
# For user scripts, adjust the module name appropriately
|
||||
module_full_path = os.path.join(dir_path, file)
|
||||
try:
|
||||
|
||||
@@ -220,6 +220,45 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, no_verify=False, c
|
||||
)
|
||||
continue
|
||||
|
||||
elif user_input.lower().startswith("/add_function"):
|
||||
try:
|
||||
if len(user_input) < len("/add_function "):
|
||||
print("Missing function name after the command")
|
||||
continue
|
||||
function_name = user_input[len("/add_function ") :].strip()
|
||||
result = memgpt_agent.add_function(function_name)
|
||||
typer.secho(
|
||||
f"/add_function succeeded: {result}",
|
||||
fg=typer.colors.GREEN,
|
||||
bold=True,
|
||||
)
|
||||
except ValueError as e:
|
||||
typer.secho(
|
||||
f"/add_function failed:\n{e}",
|
||||
fg=typer.colors.RED,
|
||||
bold=True,
|
||||
)
|
||||
continue
|
||||
elif user_input.lower().startswith("/remove_function"):
|
||||
try:
|
||||
if len(user_input) < len("/remove_function "):
|
||||
print("Missing function name after the command")
|
||||
continue
|
||||
function_name = user_input[len("/remove_function ") :].strip()
|
||||
result = memgpt_agent.remove_function(function_name)
|
||||
typer.secho(
|
||||
f"/remove_function succeeded: {result}",
|
||||
fg=typer.colors.GREEN,
|
||||
bold=True,
|
||||
)
|
||||
except ValueError as e:
|
||||
typer.secho(
|
||||
f"/remove_function failed:\n{e}",
|
||||
fg=typer.colors.RED,
|
||||
bold=True,
|
||||
)
|
||||
continue
|
||||
|
||||
# No skip options
|
||||
elif user_input.lower() == "/wipe":
|
||||
memgpt_agent = agent.Agent(interface)
|
||||
|
||||
117
tests/test_agent_function_update.py
Normal file
117
tests/test_agent_function_update.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from collections import UserDict
|
||||
import json
|
||||
import os
|
||||
import inspect
|
||||
from memgpt import MemGPT
|
||||
from memgpt import constants
|
||||
import memgpt.functions.function_sets.base as base_functions
|
||||
from memgpt.functions.functions import USER_FUNCTIONS_DIR
|
||||
|
||||
from tests.utils import wipe_config
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def hello_world(self) -> str:
|
||||
"""Test function for agent to gain access to
|
||||
|
||||
Returns:
|
||||
str: A message for the world
|
||||
"""
|
||||
return "hello, world!"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent():
|
||||
"""Create a test agent that we can call functions on"""
|
||||
wipe_config()
|
||||
global client
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
client = MemGPT(quickstart="openai")
|
||||
else:
|
||||
client = MemGPT(quickstart="memgpt_hosted")
|
||||
|
||||
agent_state = client.create_agent(
|
||||
agent_config={
|
||||
# "name": test_agent_id,
|
||||
"persona": constants.DEFAULT_PERSONA,
|
||||
"human": constants.DEFAULT_HUMAN,
|
||||
}
|
||||
)
|
||||
|
||||
return client.server._get_or_load_agent(user_id="NULL", agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def hello_world_function():
|
||||
with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w") as f:
|
||||
f.write(inspect.getsource(hello_world))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ai_function_call():
|
||||
class AiFunctionCall(UserDict):
|
||||
def content(self):
|
||||
return self.data["content"]
|
||||
|
||||
return AiFunctionCall(
|
||||
{
|
||||
"content": "I will now call hello world",
|
||||
"function_call": {
|
||||
"name": "hello_world",
|
||||
"arguments": json.dumps({}),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_add_function_happy(agent, hello_world_function, ai_function_call):
|
||||
agent.add_function("hello_world")
|
||||
|
||||
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" in agent.functions_python.keys()
|
||||
|
||||
msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call)
|
||||
content = json.loads(msgs[-1]["content"])
|
||||
assert content["message"] == "hello, world!"
|
||||
assert content["status"] == "OK"
|
||||
assert not function_failed
|
||||
|
||||
|
||||
def test_add_function_already_loaded(agent, hello_world_function):
|
||||
agent.add_function("hello_world")
|
||||
# no exception for duplicate loading
|
||||
agent.add_function("hello_world")
|
||||
|
||||
|
||||
def test_add_function_not_exist(agent):
|
||||
# pytest assert exception
|
||||
with pytest.raises(ValueError):
|
||||
agent.add_function("non_existent")
|
||||
|
||||
|
||||
def test_remove_function_happy(agent, hello_world_function):
|
||||
agent.add_function("hello_world")
|
||||
|
||||
# ensure function is loaded
|
||||
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" in agent.functions_python.keys()
|
||||
|
||||
agent.remove_function("hello_world")
|
||||
|
||||
assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" not in agent.functions_python.keys()
|
||||
|
||||
|
||||
def test_remove_function_not_exist(agent):
|
||||
# do not raise error
|
||||
agent.remove_function("non_existent")
|
||||
|
||||
|
||||
def test_remove_base_function_fails(agent):
|
||||
with pytest.raises(ValueError):
|
||||
agent.remove_function("send_message")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-vv", os.path.abspath(__file__)])
|
||||
Reference in New Issue
Block a user