diff --git a/memgpt/agent.py b/memgpt/agent.py index 8e5c17a8..e78e2b48 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -1,23 +1,18 @@ -import inspect import datetime import glob -import math import os -import requests import json import traceback from memgpt.persistence_manager import LocalStateManager from memgpt.config import AgentConfig -from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages +from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages from .memory import CoreMemory as Memory, summarize_messages from .openai_tools import completions_with_backoff as create -from .utils import get_local_time, parse_json, united_diff, printd, count_tokens +from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff from .constants import ( FIRST_MESSAGE_ATTEMPTS, MAX_PAUSE_HEARTBEATS, - MESSAGE_CHATGPT_FUNCTION_MODEL, - MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MESSAGE_SUMMARY_WARNING_FRAC, MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC, MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, @@ -25,6 +20,7 @@ from .constants import ( CORE_MEMORY_PERSONA_CHAR_LIMIT, ) from .errors import LLMError +from .functions.functions import load_all_function_sets def initialize_memory(ai_notes, human_notes): @@ -136,7 +132,7 @@ class Agent(object): config, model, system, - functions, + functions, # list of [{'schema': 'x', 'python_function': function_pointer}, ...] 
interface, persistence_manager, persona_notes, @@ -151,8 +147,18 @@ class Agent(object): self.model = model # Store the system instructions (used to rebuild memory) self.system = system - # Store the functions spec - self.functions = functions + + # Available functions is a mapping from: + # function_name -> { + # json_schema: schema + # python_function: function + # } + # Store the functions schemas (this is passed as an argument to ChatCompletion) + functions_schema = [f_dict["json_schema"] for f_name, f_dict in functions.items()] + self.functions = functions_schema + # Store references to the python objects + self.functions_python = {f_name: f_dict["python_function"] for f_name, f_dict in functions.items()} + # Initialize the memory object self.memory = initialize_memory(persona_notes, human_notes) # Once the memory object is initialize, use it to "bake" the system message @@ -196,34 +202,6 @@ class Agent(object): # When the summarizer is run, set this back to False (to reset) self.agent_alerted_about_memory_pressure = False - self.init_avail_functions() - - def init_avail_functions(self): - """ - Allows subclasses to overwrite this dictionary with overriden methods. 
- """ - self.available_functions = { - # These functions aren't all visible to the LLM - # To see what functions the LLM sees, check self.functions - "send_message": self.send_ai_message, - "edit_memory": self.edit_memory, - "edit_memory_append": self.edit_memory_append, - "edit_memory_replace": self.edit_memory_replace, - "pause_heartbeats": self.pause_heartbeats, - "core_memory_append": self.edit_memory_append, - "core_memory_replace": self.edit_memory_replace, - "recall_memory_search": self.recall_memory_search, - "recall_memory_search_date": self.recall_memory_search_date, - "conversation_search": self.recall_memory_search, - "conversation_search_date": self.recall_memory_search_date, - "archival_memory_insert": self.archival_memory_insert, - "archival_memory_search": self.archival_memory_search, - # extras - "read_from_text_file": self.read_from_text_file, - "append_to_text_file": self.append_to_text_file, - "http_request": self.http_request, - } - @property def messages(self): return self._messages @@ -331,7 +309,7 @@ class Agent(object): json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory. if not json_files: print(f"/load error: no .json checkpoint files found") - raise ValueError(f"Cannot load {agent_name}") + raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}") # Sort files based on modified timestamp, with the latest file being the first. 
filename = max(json_files, key=os.path.getmtime) @@ -343,12 +321,54 @@ class Agent(object): printd(f"Loading persistence manager from {os.path.join(directory, filename)}") persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config) + # need to dynamically link the functions + # the saved agent.functions will just have the schemas, but we need to + # go through the functions library and pull the respective python functions + + # Available functions is a mapping from: + # function_name -> { + # json_schema: schema + # python_function: function + # } + # agent.functions is a list of schemas (OpenAI kwarg functions style, see: https://platform.openai.com/docs/api-reference/chat/create) + # [{'name': ..., 'description': ...}, {...}] + available_functions = load_all_function_sets() + linked_function_set = {} + for f_schema in state["functions"]: + # Attempt to find the function in the existing function library + f_name = f_schema.get("name") + if f_name is None: + raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{f_schema}") + linked_function = available_functions.get(f_name) + if linked_function is None: + raise ValueError( + f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}" + ) + # Once we find a matching function, make sure the schema is identical + if json.dumps(f_schema) != json.dumps(linked_function["json_schema"]): + # error_message = ( + # f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different." 
+ # + f"\n>>>agent.state.functions\n{json.dumps(f_schema, indent=2)}" + # + f"\n>>>function library\n{json.dumps(linked_function['json_schema'], indent=2)}" + # ) + schema_diff = get_schema_diff(f_schema, linked_function["json_schema"]) + error_message = ( + f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different.\n" + + "".join(schema_diff) + ) + + # NOTE to handle old configs, instead of erroring here let's just warn + # raise ValueError(error_message) + print(error_message) + linked_function_set[f_name] = linked_function + messages = state["messages"] agent = cls( config=agent_config, model=state["model"], system=state["system"], - functions=state["functions"], + # functions=state["functions"], + functions=linked_function_set, interface=interface, persistence_manager=persistence_manager, persistence_manager_init=False, @@ -479,7 +499,7 @@ class Agent(object): # Failure case 1: function name is wrong function_name = response_message["function_call"]["name"] try: - function_to_call = self.available_functions[function_name] + function_to_call = self.functions_python[function_name] except KeyError as e: error_msg = f"No function named {function_name}" function_response = package_function_response(False, error_msg) @@ -522,6 +542,7 @@ class Agent(object): # Failure case 3: function failed during execution self.interface.function_message(f"Running {function_name}({function_args})") try: + function_args["self"] = self # need to attach self to arg since it's dynamically linked function_response_string = function_to_call(**function_args) function_response = package_function_response(True, function_response_string) function_failed = False @@ -731,159 +752,6 @@ class Agent(object): printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}") - def send_ai_message(self, message): - """AI wanted to send a message""" - self.interface.assistant_message(message) - return None - - def edit_memory(self, 
name, content): - """Edit memory.name <= content""" - new_len = self.memory.edit(name, content) - self.rebuild_memory() - return None - - def edit_memory_append(self, name, content): - new_len = self.memory.edit_append(name, content) - self.rebuild_memory() - return None - - def edit_memory_replace(self, name, old_content, new_content): - new_len = self.memory.edit_replace(name, old_content, new_content) - self.rebuild_memory() - return None - - def recall_memory_search(self, query, count=5, page=0): - results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count) - num_pages = math.ceil(total / count) - 1 # 0 index - if len(results) == 0: - results_str = f"No results found." - else: - results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" - results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted)}" - return results_str - - def recall_memory_search_date(self, start_date, end_date, count=5, page=0): - results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count) - num_pages = math.ceil(total / count) - 1 # 0 index - if len(results) == 0: - results_str = f"No results found." 
- else: - results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" - results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted)}" - return results_str - - def archival_memory_insert(self, content): - self.persistence_manager.archival_memory.insert(content) - return None - - def archival_memory_search(self, query, count=5, page=0): - results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count) - num_pages = math.ceil(total / count) - 1 # 0 index - if len(results) == 0: - results_str = f"No results found." - else: - results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" - results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted)}" - return results_str - - def message_chatgpt(self, message): - """Base call to GPT API w/ functions""" - - message_sequence = [ - {"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE}, - {"role": "user", "content": str(message)}, - ] - response = create( - model=MESSAGE_CHATGPT_FUNCTION_MODEL, - messages=message_sequence, - # functions=functions, - # function_call=function_call, - ) - - reply = response.choices[0].message.content - return reply - - def read_from_text_file(self, filename, line_start, num_lines=1, max_chars=500, trunc_message=True): - if not os.path.exists(filename): - raise FileNotFoundError(f"The file '{filename}' does not exist.") - - if line_start < 1 or num_lines < 1: - raise ValueError("Both line_start and num_lines must be positive integers.") - - lines = [] - chars_read = 0 - with open(filename, "r") as file: - for current_line_number, line in enumerate(file, start=1): - if line_start <= current_line_number < line_start + num_lines: - chars_to_add = len(line) - if max_chars 
is not None and chars_read + chars_to_add > max_chars: - # If adding this line exceeds MAX_CHARS, truncate the line if needed and stop reading further. - excess_chars = (chars_read + chars_to_add) - max_chars - lines.append(line[:-excess_chars].rstrip("\n")) - if trunc_message: - lines.append(f"[SYSTEM ALERT - max chars ({max_chars}) reached during file read]") - break - else: - lines.append(line.rstrip("\n")) - chars_read += chars_to_add - if current_line_number >= line_start + num_lines - 1: - break - - return "\n".join(lines) - - def append_to_text_file(self, filename, content): - if not os.path.exists(filename): - raise FileNotFoundError(f"The file '{filename}' does not exist.") - - with open(filename, "a") as file: - file.write(content + "\n") - - def http_request(self, method, url, payload_json=None): - """ - Makes an HTTP request based on the specified method, URL, and JSON payload. - - Args: - method (str): The HTTP method (e.g., 'GET', 'POST'). - url (str): The URL for the request. - payload_json (str): A JSON string representing the request payload. - - Returns: - dict: The response from the HTTP request. 
- """ - try: - headers = {"Content-Type": "application/json"} - - # For GET requests, ignore the payload - if method.upper() == "GET": - print(f"[HTTP] launching GET request to {url}") - response = requests.get(url, headers=headers) - else: - # Validate and convert the payload for other types of requests - if payload_json: - payload = json.loads(payload_json) - else: - payload = {} - print(f"[HTTP] launching {method} request to {url}, payload=\n{json.dumps(payload, indent=2)}") - response = requests.request(method, url, json=payload, headers=headers) - - return {"status_code": response.status_code, "headers": dict(response.headers), "body": response.text} - except Exception as e: - return {"error": str(e)} - - def pause_heartbeats(self, minutes, max_pause=MAX_PAUSE_HEARTBEATS): - """Pause timed heartbeats for N minutes""" - minutes = min(max_pause, minutes) - - # Record the current time - self.pause_heartbeats_start = datetime.datetime.now() - # And record how long the pause should go for - self.pause_heartbeats_minutes = int(minutes) - - return f"Pausing timed heartbeats for {minutes} min" - def heartbeat_is_paused(self): """Check if there's a requested pause on timed heartbeats""" diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index c0bf57cd..f4c59ce3 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -7,7 +7,7 @@ from memgpt.autogen.interface import AutoGenInterface from memgpt.persistence_manager import LocalStateManager import memgpt.system as system import memgpt.constants as constants -import memgpt.presets as presets +import memgpt.presets.presets as presets from memgpt.personas import personas from memgpt.humans import humans from memgpt.config import AgentConfig diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 482ad4f6..43bcbeea 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -14,7 +14,7 @@ import memgpt.interface # for printing to terminal from memgpt.cli.cli_config import 
configure import memgpt.agent as agent import memgpt.system as system -import memgpt.presets as presets +import memgpt.presets.presets as presets import memgpt.constants as constants import memgpt.personas.personas as personas import memgpt.humans.humans as humans diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 132d0527..8053779b 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -23,7 +23,7 @@ app = typer.Typer() def configure(): """Updates default MemGPT configurations""" - from memgpt.presets import DEFAULT_PRESET, preset_options + from memgpt.presets.presets import DEFAULT_PRESET, preset_options MemGPTConfig.create_config_dir() diff --git a/memgpt/config.py b/memgpt/config.py index fdc00f74..c57ed4b6 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -1,4 +1,5 @@ import glob +import inspect import random import string import json @@ -23,7 +24,7 @@ from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS import memgpt.constants as constants import memgpt.personas.personas as personas import memgpt.humans.humans as humans -from memgpt.presets import DEFAULT_PRESET, preset_options +from memgpt.presets.presets import DEFAULT_PRESET, preset_options model_choices = [ @@ -243,7 +244,7 @@ class MemGPTConfig: if not os.path.exists(MEMGPT_DIR): os.makedirs(MEMGPT_DIR, exist_ok=True) - folders = ["personas", "humans", "archival", "agents"] + folders = ["personas", "humans", "archival", "agents", "functions", "system_prompts", "presets"] for folder in folders: if not os.path.exists(os.path.join(MEMGPT_DIR, folder)): os.makedirs(os.path.join(MEMGPT_DIR, folder)) @@ -339,6 +340,15 @@ class AgentConfig: assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}" with open(agent_config_path, "r") as f: agent_config = json.load(f) + + # allow compatibility accross versions + class_args = inspect.getargspec(cls.__init__).args + agent_fields = list(agent_config.keys()) + for key in 
agent_fields: + if key not in class_args: + utils.printd(f"Removing missing argument {key} from agent config") + del agent_config[key] + return cls(**agent_config) diff --git a/memgpt/constants.py b/memgpt/constants.py index fc345404..6fb35a62 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -60,4 +60,7 @@ MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep you REQ_HEARTBEAT_MESSAGE = "request_heartbeat == true" FUNC_FAILED_HEARTBEAT_MESSAGE = "Function call failed" +FUNCTION_PARAM_NAME_REQ_HEARTBEAT = "request_heartbeat" +FUNCTION_PARAM_TYPE_REQ_HEARTBEAT = "boolean" FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT = "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function." +RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5 diff --git a/memgpt/functions/__init__.py b/memgpt/functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/memgpt/functions/function_sets/base.py b/memgpt/functions/function_sets/base.py new file mode 100644 index 00000000..424017d0 --- /dev/null +++ b/memgpt/functions/function_sets/base.py @@ -0,0 +1,168 @@ +from typing import Optional +import datetime +import os +import json +import math + +from ...constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + +### Functions / tools the agent can use +# All functions should return a response string (or None) +# If the function fails, throw an exception + + +def send_message(self, message: str): + """ + Sends a message to the human user. + + Args: + message (str): Message contents. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + self.interface.assistant_message(message) + return None + + +# Construct the docstring dynamically (since it should use the external constants) +pause_heartbeats_docstring = f""" +Temporarily ignore timed heartbeats. 
You may still receive messages from manual heartbeats and other events. + +Args: + minutes (int): Number of minutes to ignore heartbeats for. Max value of {MAX_PAUSE_HEARTBEATS} minutes ({MAX_PAUSE_HEARTBEATS // 60} hours). + +Returns: + str: Function status response +""" + + +def pause_heartbeats(self, minutes: int): + minutes = min(MAX_PAUSE_HEARTBEATS, minutes) + + # Record the current time + self.pause_heartbeats_start = datetime.datetime.now() + # And record how long the pause should go for + self.pause_heartbeats_minutes = int(minutes) + + return f"Pausing timed heartbeats for {minutes} min" + + +pause_heartbeats.__doc__ = pause_heartbeats_docstring + + +def core_memory_append(self, name: str, content: str): + """ + Append to the contents of core memory. + + Args: + name (str): Section of the memory to be edited (persona or human). + content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + new_len = self.memory.edit_append(name, content) + self.rebuild_memory() + return None + + +def core_memory_replace(self, name: str, old_content: str, new_content: str): + """ + Replace to the contents of core memory. To delete memories, use an empty string for new_content. + + Args: + name (str): Section of the memory to be edited (persona or human). + old_content (str): String to replace. Must be an exact match. + new_content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + new_len = self.memory.edit_replace(name, old_content, new_content) + self.rebuild_memory() + return None + + +def conversation_search(self, query: str, page: Optional[int] = 0): + """ + Search prior conversation history using case-insensitive string matching. + + Args: + query (str): String to search for. 
+ page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). + + Returns: + str: Query result string + """ + count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count) + num_pages = math.ceil(total / count) - 1 # 0 index + if len(results) == 0: + results_str = f"No results found." + else: + results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" + results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] + results_str = f"{results_pref} {json.dumps(results_formatted)}" + return results_str + + +def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0): + """ + Search prior conversation history using a date range. + + Args: + start_date (str): The start of the date range to search, in the format 'YYYY-MM-DD'. + end_date (str): The end of the date range to search, in the format 'YYYY-MM-DD'. + page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). + + Returns: + str: Query result string + """ + count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count) + num_pages = math.ceil(total / count) - 1 # 0 index + if len(results) == 0: + results_str = f"No results found." + else: + results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" + results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] + results_str = f"{results_pref} {json.dumps(results_formatted)}" + return results_str + + +def archival_memory_insert(self, content: str): + """ + Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. 
+ + Args: + content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + self.persistence_manager.archival_memory.insert(content) + return None + + +def archival_memory_search(self, query: str, page: Optional[int] = 0): + """ + Search archival memory using semantic (embedding-based) search. + + Args: + query (str): String to search for. + page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). + + Returns: + str: Query result string + """ + count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count) + num_pages = math.ceil(total / count) - 1 # 0 index + if len(results) == 0: + results_str = f"No results found." + else: + results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" + results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results] + results_str = f"{results_pref} {json.dumps(results_formatted)}" + return results_str diff --git a/memgpt/functions/function_sets/extras.py b/memgpt/functions/function_sets/extras.py new file mode 100644 index 00000000..86883e3e --- /dev/null +++ b/memgpt/functions/function_sets/extras.py @@ -0,0 +1,126 @@ +from typing import Optional +import os +import json +import requests + + +from ...constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS +from ...openai_tools import completions_with_backoff as create + + +def message_chatgpt(self, message: str): + """ + Send a message to a more basic AI, ChatGPT. A useful resource for asking questions. ChatGPT does not retain memory of previous interactions. + + Args: + message (str): Message to send ChatGPT. Phrase your message as a full English sentence. 
+ + Returns: + str: Reply message from ChatGPT + """ + message_sequence = [ + {"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE}, + {"role": "user", "content": str(message)}, + ] + response = create( + model=MESSAGE_CHATGPT_FUNCTION_MODEL, + messages=message_sequence, + # functions=functions, + # function_call=function_call, + ) + + reply = response.choices[0].message.content + return reply + + +def read_from_text_file(self, filename: str, line_start: int, num_lines: Optional[int] = 1): + """ + Read lines from a text file. + + Args: + filename (str): The name of the file to read. + line_start (int): Line to start reading from. + num_lines (Optional[int]): How many lines to read (defaults to 1). + + Returns: + str: Text read from the file + """ + max_chars = 500 + trunc_message = True + if not os.path.exists(filename): + raise FileNotFoundError(f"The file '{filename}' does not exist.") + + if line_start < 1 or num_lines < 1: + raise ValueError("Both line_start and num_lines must be positive integers.") + + lines = [] + chars_read = 0 + with open(filename, "r") as file: + for current_line_number, line in enumerate(file, start=1): + if line_start <= current_line_number < line_start + num_lines: + chars_to_add = len(line) + if max_chars is not None and chars_read + chars_to_add > max_chars: + # If adding this line exceeds MAX_CHARS, truncate the line if needed and stop reading further. + excess_chars = (chars_read + chars_to_add) - max_chars + lines.append(line[:-excess_chars].rstrip("\n")) + if trunc_message: + lines.append(f"[SYSTEM ALERT - max chars ({max_chars}) reached during file read]") + break + else: + lines.append(line.rstrip("\n")) + chars_read += chars_to_add + if current_line_number >= line_start + num_lines - 1: + break + + return "\n".join(lines) + + +def append_to_text_file(self, filename: str, content: str): + """ + Append to a text file. + + Args: + filename (str): The name of the file to append to. 
+ content (str): Content to append to the file. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + if not os.path.exists(filename): + raise FileNotFoundError(f"The file '{filename}' does not exist.") + + with open(filename, "a") as file: + file.write(content + "\n") + + +def http_request(self, method: str, url: str, payload_json: Optional[str] = None): + """ + Generates an HTTP request and returns the response. + + Args: + method (str): The HTTP method (e.g., 'GET', 'POST'). + url (str): The URL for the request. + payload_json (Optional[str]): A JSON string representing the request payload. + + Returns: + dict: The response from the HTTP request. + """ + try: + headers = {"Content-Type": "application/json"} + + # For GET requests, ignore the payload + if method.upper() == "GET": + print(f"[HTTP] launching GET request to {url}") + response = requests.get(url, headers=headers) + else: + # Validate and convert the payload for other types of requests + if payload_json: + payload = json.loads(payload_json) + else: + payload = {} + print(f"[HTTP] launching {method} request to {url}, payload=\n{json.dumps(payload, indent=2)}") + response = requests.request(method, url, json=payload, headers=headers) + + return {"status_code": response.status_code, "headers": dict(response.headers), "body": response.text} + except Exception as e: + return {"error": str(e)} diff --git a/memgpt/functions/functions.py b/memgpt/functions/functions.py new file mode 100644 index 00000000..f5ab8317 --- /dev/null +++ b/memgpt/functions/functions.py @@ -0,0 +1,77 @@ +import importlib +import inspect +import os + + +from memgpt.functions.schema_generator import generate_schema +from memgpt.constants import MEMGPT_DIR + + +def load_function_set(set_name): + """Load the functions and generate schema for them""" + function_dict = {} + + module_name = f"memgpt.functions.function_sets.{set_name}" + base_functions = 
importlib.import_module(module_name) + + for attr_name in dir(base_functions): + # Get the attribute + attr = getattr(base_functions, attr_name) + + # Check if it's a callable function and not a built-in or special method + if inspect.isfunction(attr) and attr.__module__ == base_functions.__name__: + if attr_name in function_dict: + raise ValueError(f"Found a duplicate of function name '{attr_name}'") + + generated_schema = generate_schema(attr) + function_dict[attr_name] = { + "python_function": attr, + "json_schema": generated_schema, + } + + if len(function_dict) == 0: + raise ValueError(f"No functions found in module {module_name}") + return function_dict + + +def load_all_function_sets(merge=True): + # functions/examples/*.py + scripts_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script + function_sets_dir = os.path.join(scripts_dir, "function_sets") # Path to the function_sets directory + # List all .py files in the directory (excluding __init__.py) + example_module_files = [f for f in os.listdir(function_sets_dir) if f.endswith(".py") and f != "__init__.py"] + + # ~/.memgpt/functions/*.py + user_scripts_dir = os.path.join(MEMGPT_DIR, "functions") + # create if missing + if not os.path.exists(user_scripts_dir): + os.makedirs(user_scripts_dir) + user_module_files = [f for f in os.listdir(user_scripts_dir) if f.endswith(".py") and f != "__init__.py"] + + # combine them both (pull from both examples and user-provided) + all_module_files = example_module_files + user_module_files + + schemas_and_functions = {} + for file in all_module_files: + # Convert filename to module name + module_name = f"memgpt.functions.function_sets.{file[:-3]}" # Remove '.py' from filename + + try: + # Load the function set + function_set = load_function_set(file[:-3]) # Pass the module part of the name + schemas_and_functions[module_name] = function_set + except ValueError as e: + print(f"Error loading function set '{module_name}': {e}") + + if 
# Functions whose schema must NOT get the extra "request heartbeat" parameter.
NO_HEARTBEAT_FUNCTIONS = ["send_message", "pause_heartbeats"]

# JSON schema type names for the Python annotations the generator supports.
_JSON_SCHEMA_TYPES = {
    int: "integer",
    str: "string",
    bool: "boolean",
    float: "number",
    # Add more mappings as needed
}


def is_optional(annotation):
    """Return True if ``annotation`` is a ``typing.Union`` that includes ``None``.

    ``typing.Optional[X]`` is represented as ``Union[X, None]``, so it is covered.
    Anything that is not a ``typing.Union`` yields False.
    """
    if getattr(annotation, "__origin__", None) is typing.Union:
        return type(None) in annotation.__args__
    return False


def optional_length(annotation):
    """Return how many non-``None`` members an optional annotation has.

    Raises:
        ValueError: If ``annotation`` is not an optional type (see ``is_optional``).
    """
    if not is_optional(annotation):
        raise ValueError("The annotation is not an Optional type")
    # Subtract 1 to account for NoneType
    return len(annotation.__args__) - 1


def type_to_json_schema_type(py_type):
    """Map a Python type annotation to a JSON schema type name.

    Supports ``int``, ``str``, ``bool``, ``float`` and an optional wrapper around
    exactly one of them (``Optional[X]`` / ``Union[X, None]`` in either order).

    Args:
        py_type: The annotation to convert.

    Returns:
        str: One of "integer", "string", "boolean" or "number".

    Raises:
        ValueError: If the optional wraps more than one type, or the type has no
            JSON schema equivalent.
    """
    if is_optional(py_type):
        # Select the member that is not NoneType instead of assuming it comes first:
        # Union[None, int] keeps NoneType at index 0.
        inner_types = [arg for arg in get_args(py_type) if arg is not type(None)]
        if len(inner_types) != 1:
            # Explicit raise (not assert) so the check survives `python -O`.
            raise ValueError(f"Optional type must have exactly one type argument, but got {py_type}")
        return type_to_json_schema_type(inner_types[0])

    if py_type not in _JSON_SCHEMA_TYPES:
        raise ValueError(f"Python type {py_type} has no corresponding JSON schema type")

    return _JSON_SCHEMA_TYPES[py_type]


def generate_schema(function):
    """Generate a function-calling JSON schema from a Python function.

    The schema's name is the function's ``__name__`` and its description is the
    short description of the docstring. Every parameter except ``self`` becomes a
    property: its type comes from the annotation, its description from the matching
    docstring entry. Parameters without a default value are listed as required.
    Unless the function is named in ``NO_HEARTBEAT_FUNCTIONS``, a required heartbeat
    parameter is appended, built from the ``FUNCTION_PARAM_*_REQ_HEARTBEAT`` constants.

    Args:
        function: The function to describe. Its parameters must be annotated and
            documented in the docstring.

    Returns:
        dict: A mapping with the keys "name", "description" and "parameters".

    Raises:
        TypeError: If a parameter has no type annotation.
        ValueError: If a parameter has no description in the docstring, or its
            annotation cannot be mapped to a JSON schema type.
    """
    signature = inspect.signature(function)
    docstring = parse(function.__doc__)

    properties = {}
    required = []

    for param in signature.parameters.values():
        # Exclude 'self' parameter
        if param.name == "self":
            continue

        if param.annotation is inspect.Parameter.empty:
            raise TypeError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a type annotation")

        # Single lookup of the docstring entry for this parameter.
        param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
        if not param_doc or not param_doc.description:
            raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")

        properties[param.name] = {
            "type": type_to_json_schema_type(param.annotation),
            "description": param_doc.description,
        }
        if param.default is inspect.Parameter.empty:
            required.append(param.name)

    # Append the heartbeat parameter
    if function.__name__ not in NO_HEARTBEAT_FUNCTIONS:
        properties[FUNCTION_PARAM_NAME_REQ_HEARTBEAT] = {
            "type": FUNCTION_PARAM_TYPE_REQ_HEARTBEAT,
            "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
        }
        required.append(FUNCTION_PARAM_NAME_REQ_HEARTBEAT)

    return {
        "name": function.__name__,
        "description": docstring.short_description,
        "parameters": {"type": "object", "properties": properties, "required": required},
    }
gpt_functions.FUNCTIONS_CHAINING.items() if k in functions] - printd(f"Available functions:\n", [x["name"] for x in available_functions]) - assert len(functions) == len(available_functions) - - if "gpt-3.5" in model: - # use a different system message for gpt-3.5 - preset_name = "memgpt_gpt35_extralong" - - return Agent( - config=agent_config, - model=model, - system=gpt_system.get_system_text(preset_name), - functions=available_functions, - interface=interface, - persistence_manager=persistence_manager, - persona_notes=persona, - human_notes=human, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if "gpt-4" in model else False, - ) - - elif preset_name == "memgpt_extras": - functions = [ - "send_message", - "pause_heartbeats", - "core_memory_append", - "core_memory_replace", - "conversation_search", - "conversation_search_date", - "archival_memory_insert", - "archival_memory_search", - # extra for read/write to files - "read_from_text_file", - "append_to_text_file", - # internet access - "http_request", - ] - available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions] - printd(f"Available functions:\n", [x["name"] for x in available_functions]) - assert len(functions) == len(available_functions) - - if "gpt-3.5" in model: - # use a different system message for gpt-3.5 - preset_name = "memgpt_gpt35_extralong" - - return Agent( - config=agent_config, - model=model, - system=gpt_system.get_system_text("memgpt_chat"), - functions=available_functions, - interface=interface, - persistence_manager=persistence_manager, - persona_notes=persona, - human_notes=human, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if "gpt-4" in model else False, - ) - - else: - raise ValueError(preset_name) diff --git a/memgpt/presets/examples/memgpt_chat.yaml b/memgpt/presets/examples/memgpt_chat.yaml new file mode 100644 index 
00000000..4cbd1c93 --- /dev/null +++ b/memgpt/presets/examples/memgpt_chat.yaml @@ -0,0 +1,10 @@ +system_prompt: "memgpt_chat" +functions: + - "send_message" + - "pause_heartbeats" + - "core_memory_append" + - "core_memory_replace" + - "conversation_search" + - "conversation_search_date" + - "archival_memory_insert" + - "archival_memory_search" diff --git a/memgpt/presets/examples/memgpt_docs.yaml b/memgpt/presets/examples/memgpt_docs.yaml new file mode 100644 index 00000000..0ee1ccb7 --- /dev/null +++ b/memgpt/presets/examples/memgpt_docs.yaml @@ -0,0 +1,10 @@ +system_prompt: "memgpt_doc" +functions: + - "send_message" + - "pause_heartbeats" + - "core_memory_append" + - "core_memory_replace" + - "conversation_search" + - "conversation_search_date" + - "archival_memory_insert" + - "archival_memory_search" diff --git a/memgpt/presets/examples/memgpt_extras.yaml b/memgpt/presets/examples/memgpt_extras.yaml new file mode 100644 index 00000000..d28072cb --- /dev/null +++ b/memgpt/presets/examples/memgpt_extras.yaml @@ -0,0 +1,15 @@ +system_prompt: "memgpt_chat" +functions: + - "send_message" + - "pause_heartbeats" + - "core_memory_append" + - "core_memory_replace" + - "conversation_search" + - "conversation_search_date" + - "archival_memory_insert" + - "archival_memory_search" + # extras for read/write to files + - "read_from_text_file" + - "append_to_text_file" + # internet access + - "http_request" diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py new file mode 100644 index 00000000..8745aee5 --- /dev/null +++ b/memgpt/presets/presets.py @@ -0,0 +1,61 @@ +from .utils import load_all_presets, is_valid_yaml_format +from ..prompts import gpt_functions +from ..prompts import gpt_system +from ..functions.functions import load_all_function_sets + +DEFAULT_PRESET = "memgpt_chat" + +available_presets = load_all_presets() +preset_options = list(available_presets.keys()) + + +def use_preset(preset_name, agent_config, model, persona, human, interface, 
persistence_manager): + """Storing combinations of SYSTEM + FUNCTION prompts""" + + from memgpt.agent import Agent + from memgpt.utils import printd + + # Available functions is a mapping from: + # function_name -> { + # json_schema: schema + # python_function: function + # } + available_functions = load_all_function_sets() + + available_presets = load_all_presets() + if preset_name not in available_presets: + raise ValueError(f"Preset '{preset_name}.yaml' not found") + + preset = available_presets[preset_name] + if not is_valid_yaml_format(preset, list(available_functions.keys())): + raise ValueError(f"Preset '{preset_name}.yaml' is not valid") + + preset_system_prompt = preset["system_prompt"] + preset_function_set_names = preset["functions"] + + # Filter down the function set based on what the preset requested + preset_function_set = {} + for f_name in preset_function_set_names: + if f_name not in available_functions: + raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}") + preset_function_set[f_name] = available_functions[f_name] + assert len(preset_function_set_names) == len(preset_function_set) + printd(f"Available functions:\n", list(preset_function_set.keys())) + + # preset_function_set = {f_name: f_dict for f_name, f_dict in available_functions.items() if f_name in preset_function_set_names} + # printd(f"Available functions:\n", [f_name for f_name, f_dict in preset_function_set.items()]) + # Make sure that every function the preset wanted is inside the available functions + # assert len(preset_function_set_names) == len(preset_function_set) + + return Agent( + config=agent_config, + model=model, + system=gpt_system.get_system_text(preset_system_prompt), + functions=preset_function_set, + interface=interface, + persistence_manager=persistence_manager, + persona_notes=persona, + human_notes=human, + # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now + 
def is_valid_yaml_format(yaml_data, function_set):
    """
    Check if the given YAML data follows the preset format and if all functions in the yaml are part of the function_set.
    Raises ValueError if any check fails.

    :param yaml_data: The data loaded from a YAML file (expected to be a mapping).
    :param function_set: A collection of valid function names.
    :return: True when every check passes.
    """
    # An empty YAML file loads as None and a scalar file loads as str/int; reject both up front
    if not isinstance(yaml_data, dict):
        raise ValueError("YAML data should be a mapping with the keys 'system_prompt' and 'functions'.")

    # Check for required keys
    if not all(key in yaml_data for key in ["system_prompt", "functions"]):
        raise ValueError("YAML data is missing one or more required keys: 'system_prompt', 'functions'.")

    # Check if 'functions' is a list of strings. The container type is checked explicitly
    # because a bare string is itself an iterable of one-character strings.
    functions = yaml_data["functions"]
    if not isinstance(functions, (list, tuple)) or not all(isinstance(item, str) for item in functions):
        raise ValueError("'functions' should be a list of strings.")

    # Check if all functions in YAML are part of function_set
    if not set(functions).issubset(function_set):
        raise ValueError("Some functions in YAML are not part of the provided function set.")

    # If all checks pass
    return True


def load_yaml_file(file_path):
    """
    Load a YAML file and return the data.

    :param file_path: Path to the YAML file.
    :return: Data from the YAML file.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale encoding
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def load_all_presets():
    """
    Load all the preset configs from the bundled examples directory and the user presets directory.

    The user presets directory (``<MEMGPT_DIR>/presets``) is created if it does not exist.
    User files are loaded after the examples, so a user preset with the same base name
    replaces the bundled one.

    :return: Mapping from preset name (file name without '.yaml') to the loaded YAML data.
    """
    ## Load the examples
    # Directory in which this module is located
    script_directory = os.path.dirname(os.path.abspath(__file__))
    example_yaml_files = glob.glob(os.path.join(script_directory, "examples", "*.yaml"))

    ## Load the user-provided presets
    # ~/.memgpt/presets/*.yaml
    user_presets_dir = os.path.join(MEMGPT_DIR, "presets")
    # exist_ok avoids the race between checking for the directory and creating it
    os.makedirs(user_presets_dir, exist_ok=True)
    user_yaml_files = glob.glob(os.path.join(user_presets_dir, "*.yaml"))

    # Pull from both examples and user-provided presets
    all_yaml_data = {}
    for file_path in example_yaml_files + user_yaml_files:
        # Base file name without the '.yaml' extension
        base_name = os.path.splitext(os.path.basename(file_path))[0]
        all_yaml_data[base_name] = load_yaml_file(file_path)

    return all_yaml_data
"properties": { @@ -26,7 +26,7 @@ FUNCTIONS_CHAINING = { # https://json-schema.org/understanding-json-schema/reference/array.html "minutes": { "type": "integer", - "description": "Number of minutes to ignore heartbeats for. Max value of 360 minutes (6 hours).", + "description": f"Number of minutes to ignore heartbeats for. Max value of {MAX_PAUSE_HEARTBEATS} minutes ({MAX_PAUSE_HEARTBEATS//60} hours).", }, }, "required": ["minutes"], @@ -45,7 +45,7 @@ FUNCTIONS_CHAINING = { }, "request_heartbeat": { "type": "boolean", - "description": "Request an immediate heartbeat after function execution, use to chain multiple functions.", + "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, }, }, "required": ["message", "request_heartbeat"], @@ -67,7 +67,7 @@ FUNCTIONS_CHAINING = { }, "request_heartbeat": { "type": "boolean", - "description": "Request an immediate heartbeat after function execution, use to chain multiple functions.", + "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, }, }, "required": ["name", "content", "request_heartbeat"], @@ -93,7 +93,7 @@ FUNCTIONS_CHAINING = { }, "request_heartbeat": { "type": "boolean", - "description": "Request an immediate heartbeat after function execution, use to chain multiple functions.", + "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, }, }, "required": ["name", "old_content", "new_content", "request_heartbeat"], @@ -140,7 +140,7 @@ FUNCTIONS_CHAINING = { "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, }, }, - "required": ["query", "page", "request_heartbeat"], + "required": ["query", "request_heartbeat"], }, }, "recall_memory_search_date": { @@ -192,7 +192,7 @@ FUNCTIONS_CHAINING = { "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, }, }, - "required": ["start_date", "end_date", "page", "request_heartbeat"], + "required": ["start_date", "end_date", "request_heartbeat"], }, }, "archival_memory_insert": { @@ -232,7 +232,7 @@ FUNCTIONS_CHAINING = { "description": 
def _read_stripped(file_path):
    """Return the contents of the text file at file_path with surrounding whitespace removed."""
    with open(file_path, "r") as file:
        return file.read().strip()


def get_system_text(key):
    """Return the system prompt text stored under ``key``.

    ``<key>.txt`` is looked up first in the ``system`` directory next to this module
    and, if it is not there, in ``<MEMGPT_DIR>/system_prompts``. The user directory is
    created when the fallback is reached and it does not exist yet.

    Args:
        key: Name of the system prompt, without the ``.txt`` extension.

    Returns:
        str: The prompt text, stripped of leading and trailing whitespace.

    Raises:
        FileNotFoundError: If neither location contains ``<key>.txt``.
    """
    filename = f"{key}.txt"

    # first look in prompts/system/*.txt
    file_path = os.path.join(os.path.dirname(__file__), "system", filename)
    if os.path.exists(file_path):
        return _read_stripped(file_path)

    # try looking in ~/.memgpt/system_prompts/*.txt
    user_system_prompts_dir = os.path.join(MEMGPT_DIR, "system_prompts")
    # exist_ok avoids the race between checking for the directory and creating it
    os.makedirs(user_system_prompts_dir, exist_ok=True)

    file_path = os.path.join(user_system_prompts_dir, filename)
    if os.path.exists(file_path):
        return _read_stripped(file_path)

    raise FileNotFoundError(f"No file found for key {key}, path={file_path}")
def get_schema_diff(schema_a, schema_b):
    """Return the lines that differ between two JSON-serializable schemas.

    Both schemas are pretty-printed as JSON (indent of 2) and compared line by line
    with ``difflib.ndiff``. Only the lines removed from ``schema_a`` ("- " prefix) and
    the lines added in ``schema_b`` ("+ " prefix) are kept; unchanged lines and the
    "? " hint lines are dropped. The result is an empty string when the schemas
    serialize identically.
    """

    def as_json_lines(schema):
        # Keep line endings so the surviving lines can be joined back verbatim
        return json.dumps(schema, indent=2).splitlines(keepends=True)

    lines_a = as_json_lines(schema_a)
    lines_b = as_json_lines(schema_b)

    changed_lines = (
        entry
        for entry in difflib.ndiff(lines_a, lines_b)
        if entry.startswith(("+ ", "- "))
    )
    return "".join(changed_lines)
"PyMuPDF-1.23.6-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:5cd05700c8f18c9dafef63ac2ed3b1099ca06017ca0c32deea13093cea1b8671"}, {file = "PyMuPDF-1.23.6-cp310-none-win32.whl", hash = "sha256:951d280c1daafac2fd6a664b031f7f98b27eb2def55d39c92a19087bd8041c5d"}, {file = "PyMuPDF-1.23.6-cp310-none-win_amd64.whl", hash = "sha256:19d1711d5908c4527ad2deef5af2d066649f3f9a12950faf30be5f7251d18abc"}, {file = "PyMuPDF-1.23.6-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:3f0f9b76bc4f039e7587003cbd40684d93a98441549dd033cab38ca07d61988d"}, {file = "PyMuPDF-1.23.6-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e047571d799b30459ad7ee0bc6e68900a7f6b928876f956c976f279808814e72"}, + {file = "PyMuPDF-1.23.6-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:1cbcf05c06f314fdf3042ceee674e9a0ac7fae598347d5442e2138c6046d4e82"}, {file = "PyMuPDF-1.23.6-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:e33f8ec5ba7265fe78b30332840b8f454184addfa79f9c27f160f19789aa5ffd"}, {file = "PyMuPDF-1.23.6-cp311-none-win32.whl", hash = "sha256:2c141f33e2733e48de8524dfd2de56d889feef0c7773b20a8cd216c03ab24793"}, {file = "PyMuPDF-1.23.6-cp311-none-win_amd64.whl", hash = "sha256:8fd9c4ee1dd4744a515b9190d8ba9133348b0d94c362293ed77726aa1c13b0a6"}, {file = "PyMuPDF-1.23.6-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:4d06751d5cd213e96f84f2faaa71a51cf4d641851e07579247ca1190121f173b"}, {file = "PyMuPDF-1.23.6-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:526b26a5207e923aab65877ad305644402851823a352cb92d362053426899354"}, + {file = "PyMuPDF-1.23.6-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:0f852d125defc26716878b1796f4d68870e9065041d00cf46bde317fd8d30e68"}, {file = "PyMuPDF-1.23.6-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:5bdf7020b90987412381acc42427dd1b7a03d771ee9ec273de003e570164ec1a"}, {file = "PyMuPDF-1.23.6-cp312-none-win32.whl", hash = "sha256:e2d64799c6d9a3735be9e162a5d11061c0b7fbcb1e5fc7446e0993d0f815a93a"}, {file = 
"PyMuPDF-1.23.6-cp312-none-win_amd64.whl", hash = "sha256:c8ea81964c1433ea163ad4b53c56053a87a9ef6e1bd7a879d4d368a3988b60d1"}, {file = "PyMuPDF-1.23.6-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:761501a4965264e81acdd8f2224f993020bf24474e9b34fcdb5805a6826eda1c"}, {file = "PyMuPDF-1.23.6-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:fd8388e82b6045807d19addf310d8119d32908e89f76cc8bbf8cf1ec36fce947"}, + {file = "PyMuPDF-1.23.6-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:4ac9673a6d6ee7e80cb242dacb43f9ca097b502d9c5e44687dbdffc2bce7961a"}, {file = "PyMuPDF-1.23.6-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:6e319c1f49476e07b9a12017c2d031687617713f8a46b7adcec03c636ed04607"}, {file = "PyMuPDF-1.23.6-cp38-none-win32.whl", hash = "sha256:1103eea4ab727e32b9cb93347b35f71562033018c333a7f3a17d115e980fea4a"}, {file = "PyMuPDF-1.23.6-cp38-none-win_amd64.whl", hash = "sha256:991a37e1cba43775ce094da87cf0bf72172a5532a09644003276bc8bfdfe9f1a"}, {file = "PyMuPDF-1.23.6-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:57725e15872f7ab67a9fb3e06e5384d1047b2121e85755c93a6d4266d3ca8983"}, {file = "PyMuPDF-1.23.6-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:224c341fe254adda97c8f06a4c5838cdbcf609fa89e70b1fb179752533378f2f"}, + {file = "PyMuPDF-1.23.6-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:271bdf6059bb8347f9c9c6b721329bd353a933681b1fc62f43241b410e7ab7ae"}, {file = "PyMuPDF-1.23.6-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:57e22bea69690450197b34dcde16bd9fe0265ac4425b4033535ccc5c044246fb"}, {file = "PyMuPDF-1.23.6-cp39-none-win32.whl", hash = "sha256:2885a26220a32fb45ea443443b72194bb7107d6862d8d546b59e4ad0c8a1f2c9"}, {file = "PyMuPDF-1.23.6-cp39-none-win_amd64.whl", hash = "sha256:361cab1be45481bd3dc4e00ec82628ebc189b4f4b6fd9bd78a00cfeed54e0034"}, @@ -2171,6 +2187,7 @@ python-versions = ">=3.8" files = [ {file = "PyMuPDFb-1.23.6-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e5af77580aad3d1103aeec57009d156bfca429cecda14a17c573fcbe97bafb30"}, 
{file = "PyMuPDFb-1.23.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9925816cbe3e05e920f9be925e5752c2eef42b793885b62075bb0f6a69178598"}, + {file = "PyMuPDFb-1.23.6-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:009e2cff166059e13bf71f93919e688f46b8fc11d122433574cfb0cc9134690e"}, {file = "PyMuPDFb-1.23.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7132b30e6ad6ff2013344e3a481b2287fe0be3710d80694807dd6e0d8635f085"}, {file = "PyMuPDFb-1.23.6-py3-none-win32.whl", hash = "sha256:9d24ddadc204e895bee5000ddc7507c801643548e59f5a56aad6d32981d17eeb"}, {file = "PyMuPDFb-1.23.6-py3-none-win_amd64.whl", hash = "sha256:7bef75988e6979b10ca804cf9487f817aae43b0fff1c6e315b3b9ee0cf1cc32f"}, @@ -3510,4 +3527,4 @@ postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary" [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.9" -content-hash = "32cc1809f381627327c0d8c2334bdee73c3653437fa2b8138df34953d3d2a200" +content-hash = "24e6c3cea1895441e07d362a5a2f9a07a045b92b5364531b8b6e3571904199fe" diff --git a/pyproject.toml b/pyproject.toml index 495fd2c4..2b65902c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ transformers = { version = "4.34.1", optional = true } pre-commit = {version = "^3.5.0", optional = true } pg8000 = {version = "^1.30.3", optional = true} torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true} +docstring-parser = "^0.15" [tool.poetry.extras] legacy = ["faiss-cpu", "numpy"] diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py new file mode 100644 index 00000000..a68d241b --- /dev/null +++ b/tests/test_schema_generator.py @@ -0,0 +1,109 @@ +import inspect + +import memgpt.functions.function_sets.base as base_functions +import memgpt.functions.function_sets.extras as extras_functions +from memgpt.prompts.gpt_functions import FUNCTIONS_CHAINING +from memgpt.functions.schema_generator import generate_schema + + +def send_message(self, 
def test_schema_generator():
    """Check schema generation for a well-formed function and rejection of malformed ones."""
    # Check that a basic function schema converts correctly
    correct_schema = {
        "name": "send_message",
        "description": "Sends a message to the human user.",
        "parameters": {
            "type": "object",
            "properties": {"message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}},
            "required": ["message"],
        },
    }
    generated_schema = generate_schema(send_message)
    print(f"\n\nreference_schema={correct_schema}")
    print(f"\n\ngenerated_schema={generated_schema}")
    assert correct_schema == generated_schema

    # Check that missing types / a missing docstring result in an error.
    # Only the generator's own error types are caught: a bare `except:` would also
    # swallow the AssertionError that signals "no error was raised", making the check vacuous.
    for bad_function in (send_message_missing_types, send_message_missing_docstring):
        try:
            generate_schema(bad_function)
        except (TypeError, ValueError):
            continue
        raise AssertionError(f"generate_schema did not reject '{bad_function.__name__}'")


def _assert_generated_schemas_match_reference(function_module, skip=()):
    """Compare the generated schema of every function in function_module with FUNCTIONS_CHAINING.

    Names listed in ``skip`` are ignored.
    """
    for attr_name in dir(function_module):
        attr = getattr(function_module, attr_name)

        # Only plain functions are checked; other attributes and skipped names are ignored
        if not inspect.isfunction(attr) or attr_name in skip:
            continue

        print("Function name:", attr)
        real_schema = FUNCTIONS_CHAINING[attr_name]
        generated_schema = generate_schema(attr)
        print(f"\n\nreference_schema={real_schema}")
        print(f"\n\ngenerated_schema={generated_schema}")
        assert real_schema == generated_schema


def test_schema_generator_with_old_function_set():
    """Check that generated schemas equal the hand-written ones in FUNCTIONS_CHAINING."""
    # Try all the base functions first
    _assert_generated_schemas_match_reference(base_functions)

    # Then try all the extras functions.
    # NOTE(review): "create" is skipped as in the original test; presumably it is an
    # imported helper rather than an agent function -- confirm against extras.py.
    _assert_generated_schemas_match_reference(extras_functions, skip=("create",))