diff --git a/memgpt/agent.py b/memgpt/agent.py index 1039cc54..357b6f96 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -1,8 +1,10 @@ import datetime import uuid import glob +import inspect import os import json +from pathlib import Path import traceback from memgpt.data_types import AgentState @@ -34,7 +36,7 @@ from memgpt.constants import ( CLI_WARNING_PREFIX, ) from .errors import LLMError -from .functions.functions import load_all_function_sets +from .functions.functions import USER_FUNCTIONS_DIR, load_all_function_sets def link_functions(function_schemas): @@ -679,6 +681,46 @@ class Agent(object): return agent_state + def add_function(self, function_name: str) -> str: + if function_name in self.functions_python.keys(): + msg = f"Function {function_name} already loaded" + printd(msg) + return msg + + available_functions = load_all_function_sets() + if function_name not in available_functions.keys(): + raise ValueError(f"Function {function_name} not found in function library") + + self.functions.append(available_functions[function_name]["json_schema"]) + self.functions_python[function_name] = available_functions[function_name]["python_function"] + + msg = f"Added function {function_name}" + self.save() + printd(msg) + return msg + + def remove_function(self, function_name: str) -> str: + if function_name not in self.functions_python.keys(): + msg = f"Function {function_name} not loaded, ignoring" + printd(msg) + return msg + + # only allow removal of user defined functions + user_func_path = Path(USER_FUNCTIONS_DIR) + func_path = Path(inspect.getfile(self.functions_python[function_name])) + is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts + + if not is_subpath: + raise ValueError(f"Function {function_name} is not user defined and cannot be removed") + + self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name] + self.functions_python.pop(function_name) + + msg = f"Removed function {function_name}" + self.save() + printd(msg) + return msg + def save(self): """Save agent state locally""" diff --git a/memgpt/functions/functions.py b/memgpt/functions/functions.py index af1ced56..03a6bdff 100644 --- a/memgpt/functions/functions.py +++ b/memgpt/functions/functions.py @@ -7,7 +7,9 @@ import sys from memgpt.functions.schema_generator import generate_schema from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX -sys.path.append(os.path.join(MEMGPT_DIR, "functions")) +USER_FUNCTIONS_DIR = os.path.join(MEMGPT_DIR, "functions") + +sys.path.append(USER_FUNCTIONS_DIR) def load_function_set(module): @@ -42,24 +44,23 @@ def load_all_function_sets(merge=True): example_module_files = [f for f in os.listdir(function_sets_dir) if f.endswith(".py") and f != "__init__.py"] # ~/.memgpt/functions/*.py - user_scripts_dir = os.path.join(MEMGPT_DIR, "functions") # create if missing - if not os.path.exists(user_scripts_dir): - os.makedirs(user_scripts_dir) - user_module_files = [f for f in os.listdir(user_scripts_dir) if f.endswith(".py") and f != "__init__.py"] + if not os.path.exists(USER_FUNCTIONS_DIR): + os.makedirs(USER_FUNCTIONS_DIR) + user_module_files = [f for f in os.listdir(USER_FUNCTIONS_DIR) if f.endswith(".py") and f != "__init__.py"] # combine them both (pull from both examples and user-provided) # all_module_files = example_module_files + user_module_files # Add user_scripts_dir to sys.path - if user_scripts_dir not in sys.path: - sys.path.append(user_scripts_dir) + if USER_FUNCTIONS_DIR not in sys.path: + sys.path.append(USER_FUNCTIONS_DIR) schemas_and_functions = {} - for dir_path, module_files in [(function_sets_dir, example_module_files), (user_scripts_dir, user_module_files)]: + for dir_path, module_files in [(function_sets_dir, example_module_files), (USER_FUNCTIONS_DIR, user_module_files)]: for file in module_files: module_name = file[:-3] # Remove '.py' from filename - if dir_path == user_scripts_dir: + if dir_path == USER_FUNCTIONS_DIR: # For user scripts, adjust the module name appropriately module_full_path = os.path.join(dir_path, file) try: diff --git a/memgpt/main.py b/memgpt/main.py index dd0edca5..44dd1456 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -220,6 +220,45 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, no_verify=False, c ) continue + elif user_input.lower().startswith("/add_function"): + try: + if len(user_input) < len("/add_function "): + print("Missing function name after the command") + continue + function_name = user_input[len("/add_function ") :].strip() + result = memgpt_agent.add_function(function_name) + typer.secho( + f"/add_function succeeded: {result}", + fg=typer.colors.GREEN, + bold=True, + ) + except ValueError as e: + typer.secho( + f"/add_function failed:\n{e}", + fg=typer.colors.RED, + bold=True, + ) + continue + elif user_input.lower().startswith("/remove_function"): + try: + if len(user_input) < len("/remove_function "): + print("Missing function name after the command") + continue + function_name = user_input[len("/remove_function ") :].strip() + result = memgpt_agent.remove_function(function_name) + typer.secho( + f"/remove_function succeeded: {result}", + fg=typer.colors.GREEN, + bold=True, + ) + except ValueError as e: + typer.secho( + f"/remove_function failed:\n{e}", + fg=typer.colors.RED, + bold=True, + ) + continue + # No skip options elif user_input.lower() == "/wipe": memgpt_agent = agent.Agent(interface) diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py new file mode 100644 index 00000000..e42d7fc5 --- /dev/null +++ b/tests/test_agent_function_update.py @@ -0,0 +1,117 @@ +from collections import UserDict +import json +import os +import inspect +from memgpt import MemGPT +from memgpt import constants +import memgpt.functions.function_sets.base as base_functions +from memgpt.functions.functions import USER_FUNCTIONS_DIR + +from tests.utils import wipe_config + +import pytest + + +def hello_world(self) -> str: + """Test function for agent to gain access to + + Returns: + str: A message for the world + """ + return "hello, world!" + + +@pytest.fixture(scope="module") +def agent(): + """Create a test agent that we can call functions on""" + wipe_config() + global client + if os.getenv("OPENAI_API_KEY"): + client = MemGPT(quickstart="openai") + else: + client = MemGPT(quickstart="memgpt_hosted") + + agent_state = client.create_agent( + agent_config={ + # "name": test_agent_id, + "persona": constants.DEFAULT_PERSONA, + "human": constants.DEFAULT_HUMAN, + } + ) + + return client.server._get_or_load_agent(user_id="NULL", agent_id=agent_state.id) + + +@pytest.fixture(scope="module") +def hello_world_function(): + with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w") as f: + f.write(inspect.getsource(hello_world)) + + +@pytest.fixture(scope="module") +def ai_function_call(): + class AiFunctionCall(UserDict): + def content(self): + return self.data["content"] + + return AiFunctionCall( + { + "content": "I will now call hello world", + "function_call": { + "name": "hello_world", + "arguments": json.dumps({}), + }, + } + ) + + +def test_add_function_happy(agent, hello_world_function, ai_function_call): + agent.add_function("hello_world") + + assert "hello_world" in [f_schema["name"] for f_schema in agent.functions] + assert "hello_world" in agent.functions_python.keys() + + msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call) + content = json.loads(msgs[-1]["content"]) + assert content["message"] == "hello, world!" + assert content["status"] == "OK" + assert not function_failed + + +def test_add_function_already_loaded(agent, hello_world_function): + agent.add_function("hello_world") + # no exception for duplicate loading + agent.add_function("hello_world") + + +def test_add_function_not_exist(agent): + # pytest assert exception + with pytest.raises(ValueError): + agent.add_function("non_existent") + + +def test_remove_function_happy(agent, hello_world_function): + agent.add_function("hello_world") + + # ensure function is loaded + assert "hello_world" in [f_schema["name"] for f_schema in agent.functions] + assert "hello_world" in agent.functions_python.keys() + + agent.remove_function("hello_world") + + assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions] + assert "hello_world" not in agent.functions_python.keys() + + +def test_remove_function_not_exist(agent): + # do not raise error + agent.remove_function("non_existent") + + +def test_remove_base_function_fails(agent): + with pytest.raises(ValueError): + agent.remove_function("send_message") + + +if __name__ == "__main__": + pytest.main(["-vv", os.path.abspath(__file__)])