Configurable presets to support easy extension of MemGPT's function set (#420)
* partial * working schema builder, tested that it matches the hand-written schemas * correct another schema diff * refactor * basic working test * refactored preset creation to use yaml files * added docstring-parser * add code for dynamic function linking in agent loading * pretty schema diff printer * support pulling from ~/.memgpt/functions/*.py * clean * allow looking for system prompts in ~/.memgpt/system_prompts * create ~/.memgpt/system_prompts if it doesn't exist * pull presets from ~/.memgpt/presets in addition to examples folder * add support for loading agent configs that have additional keys --------- Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
256
memgpt/agent.py
256
memgpt/agent.py
@@ -1,23 +1,18 @@
|
||||
import inspect
|
||||
import datetime
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
from memgpt.config import AgentConfig
|
||||
from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
||||
from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
||||
from .memory import CoreMemory as Memory, summarize_messages
|
||||
from .openai_tools import completions_with_backoff as create
|
||||
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
|
||||
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
|
||||
from .constants import (
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
MAX_PAUSE_HEARTBEATS,
|
||||
MESSAGE_CHATGPT_FUNCTION_MODEL,
|
||||
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
|
||||
MESSAGE_SUMMARY_WARNING_FRAC,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
@@ -25,6 +20,7 @@ from .constants import (
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
)
|
||||
from .errors import LLMError
|
||||
from .functions.functions import load_all_function_sets
|
||||
|
||||
|
||||
def initialize_memory(ai_notes, human_notes):
|
||||
@@ -136,7 +132,7 @@ class Agent(object):
|
||||
config,
|
||||
model,
|
||||
system,
|
||||
functions,
|
||||
functions, # list of [{'schema': 'x', 'python_function': function_pointer}, ...]
|
||||
interface,
|
||||
persistence_manager,
|
||||
persona_notes,
|
||||
@@ -151,8 +147,18 @@ class Agent(object):
|
||||
self.model = model
|
||||
# Store the system instructions (used to rebuild memory)
|
||||
self.system = system
|
||||
# Store the functions spec
|
||||
self.functions = functions
|
||||
|
||||
# Available functions is a mapping from:
|
||||
# function_name -> {
|
||||
# json_schema: schema
|
||||
# python_function: function
|
||||
# }
|
||||
# Store the functions schemas (this is passed as an argument to ChatCompletion)
|
||||
functions_schema = [f_dict["json_schema"] for f_name, f_dict in functions.items()]
|
||||
self.functions = functions_schema
|
||||
# Store references to the python objects
|
||||
self.functions_python = {f_name: f_dict["python_function"] for f_name, f_dict in functions.items()}
|
||||
|
||||
# Initialize the memory object
|
||||
self.memory = initialize_memory(persona_notes, human_notes)
|
||||
# Once the memory object is initialize, use it to "bake" the system message
|
||||
@@ -196,34 +202,6 @@ class Agent(object):
|
||||
# When the summarizer is run, set this back to False (to reset)
|
||||
self.agent_alerted_about_memory_pressure = False
|
||||
|
||||
self.init_avail_functions()
|
||||
|
||||
def init_avail_functions(self):
|
||||
"""
|
||||
Allows subclasses to overwrite this dictionary with overriden methods.
|
||||
"""
|
||||
self.available_functions = {
|
||||
# These functions aren't all visible to the LLM
|
||||
# To see what functions the LLM sees, check self.functions
|
||||
"send_message": self.send_ai_message,
|
||||
"edit_memory": self.edit_memory,
|
||||
"edit_memory_append": self.edit_memory_append,
|
||||
"edit_memory_replace": self.edit_memory_replace,
|
||||
"pause_heartbeats": self.pause_heartbeats,
|
||||
"core_memory_append": self.edit_memory_append,
|
||||
"core_memory_replace": self.edit_memory_replace,
|
||||
"recall_memory_search": self.recall_memory_search,
|
||||
"recall_memory_search_date": self.recall_memory_search_date,
|
||||
"conversation_search": self.recall_memory_search,
|
||||
"conversation_search_date": self.recall_memory_search_date,
|
||||
"archival_memory_insert": self.archival_memory_insert,
|
||||
"archival_memory_search": self.archival_memory_search,
|
||||
# extras
|
||||
"read_from_text_file": self.read_from_text_file,
|
||||
"append_to_text_file": self.append_to_text_file,
|
||||
"http_request": self.http_request,
|
||||
}
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
@@ -331,7 +309,7 @@ class Agent(object):
|
||||
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
|
||||
if not json_files:
|
||||
print(f"/load error: no .json checkpoint files found")
|
||||
raise ValueError(f"Cannot load {agent_name}")
|
||||
raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}")
|
||||
|
||||
# Sort files based on modified timestamp, with the latest file being the first.
|
||||
filename = max(json_files, key=os.path.getmtime)
|
||||
@@ -343,12 +321,54 @@ class Agent(object):
|
||||
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
|
||||
persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config)
|
||||
|
||||
# need to dynamically link the functions
|
||||
# the saved agent.functions will just have the schemas, but we need to
|
||||
# go through the functions library and pull the respective python functions
|
||||
|
||||
# Available functions is a mapping from:
|
||||
# function_name -> {
|
||||
# json_schema: schema
|
||||
# python_function: function
|
||||
# }
|
||||
# agent.functions is a list of schemas (OpenAI kwarg functions style, see: https://platform.openai.com/docs/api-reference/chat/create)
|
||||
# [{'name': ..., 'description': ...}, {...}]
|
||||
available_functions = load_all_function_sets()
|
||||
linked_function_set = {}
|
||||
for f_schema in state["functions"]:
|
||||
# Attempt to find the function in the existing function library
|
||||
f_name = f_schema.get("name")
|
||||
if f_name is None:
|
||||
raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{f_schema}")
|
||||
linked_function = available_functions.get(f_name)
|
||||
if linked_function is None:
|
||||
raise ValueError(
|
||||
f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}"
|
||||
)
|
||||
# Once we find a matching function, make sure the schema is identical
|
||||
if json.dumps(f_schema) != json.dumps(linked_function["json_schema"]):
|
||||
# error_message = (
|
||||
# f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different."
|
||||
# + f"\n>>>agent.state.functions\n{json.dumps(f_schema, indent=2)}"
|
||||
# + f"\n>>>function library\n{json.dumps(linked_function['json_schema'], indent=2)}"
|
||||
# )
|
||||
schema_diff = get_schema_diff(f_schema, linked_function["json_schema"])
|
||||
error_message = (
|
||||
f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different.\n"
|
||||
+ "".join(schema_diff)
|
||||
)
|
||||
|
||||
# NOTE to handle old configs, instead of erroring here let's just warn
|
||||
# raise ValueError(error_message)
|
||||
print(error_message)
|
||||
linked_function_set[f_name] = linked_function
|
||||
|
||||
messages = state["messages"]
|
||||
agent = cls(
|
||||
config=agent_config,
|
||||
model=state["model"],
|
||||
system=state["system"],
|
||||
functions=state["functions"],
|
||||
# functions=state["functions"],
|
||||
functions=linked_function_set,
|
||||
interface=interface,
|
||||
persistence_manager=persistence_manager,
|
||||
persistence_manager_init=False,
|
||||
@@ -479,7 +499,7 @@ class Agent(object):
|
||||
# Failure case 1: function name is wrong
|
||||
function_name = response_message["function_call"]["name"]
|
||||
try:
|
||||
function_to_call = self.available_functions[function_name]
|
||||
function_to_call = self.functions_python[function_name]
|
||||
except KeyError as e:
|
||||
error_msg = f"No function named {function_name}"
|
||||
function_response = package_function_response(False, error_msg)
|
||||
@@ -522,6 +542,7 @@ class Agent(object):
|
||||
# Failure case 3: function failed during execution
|
||||
self.interface.function_message(f"Running {function_name}({function_args})")
|
||||
try:
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
function_response_string = function_to_call(**function_args)
|
||||
function_response = package_function_response(True, function_response_string)
|
||||
function_failed = False
|
||||
@@ -731,159 +752,6 @@ class Agent(object):
|
||||
|
||||
printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}")
|
||||
|
||||
def send_ai_message(self, message):
|
||||
"""AI wanted to send a message"""
|
||||
self.interface.assistant_message(message)
|
||||
return None
|
||||
|
||||
def edit_memory(self, name, content):
|
||||
"""Edit memory.name <= content"""
|
||||
new_len = self.memory.edit(name, content)
|
||||
self.rebuild_memory()
|
||||
return None
|
||||
|
||||
def edit_memory_append(self, name, content):
|
||||
new_len = self.memory.edit_append(name, content)
|
||||
self.rebuild_memory()
|
||||
return None
|
||||
|
||||
def edit_memory_replace(self, name, old_content, new_content):
|
||||
new_len = self.memory.edit_replace(name, old_content, new_content)
|
||||
self.rebuild_memory()
|
||||
return None
|
||||
|
||||
def recall_memory_search(self, query, count=5, page=0):
|
||||
results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
|
||||
results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
def archival_memory_insert(self, content):
|
||||
self.persistence_manager.archival_memory.insert(content)
|
||||
return None
|
||||
|
||||
def archival_memory_search(self, query, count=5, page=0):
|
||||
results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
def message_chatgpt(self, message):
|
||||
"""Base call to GPT API w/ functions"""
|
||||
|
||||
message_sequence = [
|
||||
{"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
|
||||
{"role": "user", "content": str(message)},
|
||||
]
|
||||
response = create(
|
||||
model=MESSAGE_CHATGPT_FUNCTION_MODEL,
|
||||
messages=message_sequence,
|
||||
# functions=functions,
|
||||
# function_call=function_call,
|
||||
)
|
||||
|
||||
reply = response.choices[0].message.content
|
||||
return reply
|
||||
|
||||
def read_from_text_file(self, filename, line_start, num_lines=1, max_chars=500, trunc_message=True):
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError(f"The file '{filename}' does not exist.")
|
||||
|
||||
if line_start < 1 or num_lines < 1:
|
||||
raise ValueError("Both line_start and num_lines must be positive integers.")
|
||||
|
||||
lines = []
|
||||
chars_read = 0
|
||||
with open(filename, "r") as file:
|
||||
for current_line_number, line in enumerate(file, start=1):
|
||||
if line_start <= current_line_number < line_start + num_lines:
|
||||
chars_to_add = len(line)
|
||||
if max_chars is not None and chars_read + chars_to_add > max_chars:
|
||||
# If adding this line exceeds MAX_CHARS, truncate the line if needed and stop reading further.
|
||||
excess_chars = (chars_read + chars_to_add) - max_chars
|
||||
lines.append(line[:-excess_chars].rstrip("\n"))
|
||||
if trunc_message:
|
||||
lines.append(f"[SYSTEM ALERT - max chars ({max_chars}) reached during file read]")
|
||||
break
|
||||
else:
|
||||
lines.append(line.rstrip("\n"))
|
||||
chars_read += chars_to_add
|
||||
if current_line_number >= line_start + num_lines - 1:
|
||||
break
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def append_to_text_file(self, filename, content):
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError(f"The file '{filename}' does not exist.")
|
||||
|
||||
with open(filename, "a") as file:
|
||||
file.write(content + "\n")
|
||||
|
||||
def http_request(self, method, url, payload_json=None):
|
||||
"""
|
||||
Makes an HTTP request based on the specified method, URL, and JSON payload.
|
||||
|
||||
Args:
|
||||
method (str): The HTTP method (e.g., 'GET', 'POST').
|
||||
url (str): The URL for the request.
|
||||
payload_json (str): A JSON string representing the request payload.
|
||||
|
||||
Returns:
|
||||
dict: The response from the HTTP request.
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# For GET requests, ignore the payload
|
||||
if method.upper() == "GET":
|
||||
print(f"[HTTP] launching GET request to {url}")
|
||||
response = requests.get(url, headers=headers)
|
||||
else:
|
||||
# Validate and convert the payload for other types of requests
|
||||
if payload_json:
|
||||
payload = json.loads(payload_json)
|
||||
else:
|
||||
payload = {}
|
||||
print(f"[HTTP] launching {method} request to {url}, payload=\n{json.dumps(payload, indent=2)}")
|
||||
response = requests.request(method, url, json=payload, headers=headers)
|
||||
|
||||
return {"status_code": response.status_code, "headers": dict(response.headers), "body": response.text}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def pause_heartbeats(self, minutes, max_pause=MAX_PAUSE_HEARTBEATS):
|
||||
"""Pause timed heartbeats for N minutes"""
|
||||
minutes = min(max_pause, minutes)
|
||||
|
||||
# Record the current time
|
||||
self.pause_heartbeats_start = datetime.datetime.now()
|
||||
# And record how long the pause should go for
|
||||
self.pause_heartbeats_minutes = int(minutes)
|
||||
|
||||
return f"Pausing timed heartbeats for {minutes} min"
|
||||
|
||||
def heartbeat_is_paused(self):
|
||||
"""Check if there's a requested pause on timed heartbeats"""
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from memgpt.autogen.interface import AutoGenInterface
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as constants
|
||||
import memgpt.presets as presets
|
||||
import memgpt.presets.presets as presets
|
||||
from memgpt.personas import personas
|
||||
from memgpt.humans import humans
|
||||
from memgpt.config import AgentConfig
|
||||
|
||||
@@ -14,7 +14,7 @@ import memgpt.interface # for printing to terminal
|
||||
from memgpt.cli.cli_config import configure
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.presets as presets
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.constants as constants
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
|
||||
@@ -23,7 +23,7 @@ app = typer.Typer()
|
||||
def configure():
|
||||
"""Updates default MemGPT configurations"""
|
||||
|
||||
from memgpt.presets import DEFAULT_PRESET, preset_options
|
||||
from memgpt.presets.presets import DEFAULT_PRESET, preset_options
|
||||
|
||||
MemGPTConfig.create_config_dir()
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import glob
|
||||
import inspect
|
||||
import random
|
||||
import string
|
||||
import json
|
||||
@@ -23,7 +24,7 @@ from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS
|
||||
import memgpt.constants as constants
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
from memgpt.presets import DEFAULT_PRESET, preset_options
|
||||
from memgpt.presets.presets import DEFAULT_PRESET, preset_options
|
||||
|
||||
|
||||
model_choices = [
|
||||
@@ -243,7 +244,7 @@ class MemGPTConfig:
|
||||
if not os.path.exists(MEMGPT_DIR):
|
||||
os.makedirs(MEMGPT_DIR, exist_ok=True)
|
||||
|
||||
folders = ["personas", "humans", "archival", "agents"]
|
||||
folders = ["personas", "humans", "archival", "agents", "functions", "system_prompts", "presets"]
|
||||
for folder in folders:
|
||||
if not os.path.exists(os.path.join(MEMGPT_DIR, folder)):
|
||||
os.makedirs(os.path.join(MEMGPT_DIR, folder))
|
||||
@@ -339,6 +340,15 @@ class AgentConfig:
|
||||
assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}"
|
||||
with open(agent_config_path, "r") as f:
|
||||
agent_config = json.load(f)
|
||||
|
||||
# allow compatibility accross versions
|
||||
class_args = inspect.getargspec(cls.__init__).args
|
||||
agent_fields = list(agent_config.keys())
|
||||
for key in agent_fields:
|
||||
if key not in class_args:
|
||||
utils.printd(f"Removing missing argument {key} from agent config")
|
||||
del agent_config[key]
|
||||
|
||||
return cls(**agent_config)
|
||||
|
||||
|
||||
|
||||
@@ -60,4 +60,7 @@ MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep you
|
||||
|
||||
REQ_HEARTBEAT_MESSAGE = "request_heartbeat == true"
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE = "Function call failed"
|
||||
FUNCTION_PARAM_NAME_REQ_HEARTBEAT = "request_heartbeat"
|
||||
FUNCTION_PARAM_TYPE_REQ_HEARTBEAT = "boolean"
|
||||
FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT = "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function."
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
|
||||
|
||||
0
memgpt/functions/__init__.py
Normal file
0
memgpt/functions/__init__.py
Normal file
168
memgpt/functions/function_sets/base.py
Normal file
168
memgpt/functions/function_sets/base.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
|
||||
from ...constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
|
||||
### Functions / tools the agent can use
|
||||
# All functions should return a response string (or None)
|
||||
# If the function fails, throw an exception
|
||||
|
||||
|
||||
def send_message(self, message: str):
|
||||
"""
|
||||
Sends a message to the human user.
|
||||
|
||||
Args:
|
||||
message (str): Message contents. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.interface.assistant_message(message)
|
||||
return None
|
||||
|
||||
|
||||
# Construct the docstring dynamically (since it should use the external constants)
|
||||
pause_heartbeats_docstring = f"""
|
||||
Temporarily ignore timed heartbeats. You may still receive messages from manual heartbeats and other events.
|
||||
|
||||
Args:
|
||||
minutes (int): Number of minutes to ignore heartbeats for. Max value of {MAX_PAUSE_HEARTBEATS} minutes ({MAX_PAUSE_HEARTBEATS // 60} hours).
|
||||
|
||||
Returns:
|
||||
str: Function status response
|
||||
"""
|
||||
|
||||
|
||||
def pause_heartbeats(self, minutes: int):
|
||||
minutes = min(MAX_PAUSE_HEARTBEATS, minutes)
|
||||
|
||||
# Record the current time
|
||||
self.pause_heartbeats_start = datetime.datetime.now()
|
||||
# And record how long the pause should go for
|
||||
self.pause_heartbeats_minutes = int(minutes)
|
||||
|
||||
return f"Pausing timed heartbeats for {minutes} min"
|
||||
|
||||
|
||||
pause_heartbeats.__doc__ = pause_heartbeats_docstring
|
||||
|
||||
|
||||
def core_memory_append(self, name: str, content: str):
|
||||
"""
|
||||
Append to the contents of core memory.
|
||||
|
||||
Args:
|
||||
name (str): Section of the memory to be edited (persona or human).
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
new_len = self.memory.edit_append(name, content)
|
||||
self.rebuild_memory()
|
||||
return None
|
||||
|
||||
|
||||
def core_memory_replace(self, name: str, old_content: str, new_content: str):
|
||||
"""
|
||||
Replace to the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
Args:
|
||||
name (str): Section of the memory to be edited (persona or human).
|
||||
old_content (str): String to replace. Must be an exact match.
|
||||
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
new_len = self.memory.edit_replace(name, old_content, new_content)
|
||||
self.rebuild_memory()
|
||||
return None
|
||||
|
||||
|
||||
def conversation_search(self, query: str, page: Optional[int] = 0):
|
||||
"""
|
||||
Search prior conversation history using case-insensitive string matching.
|
||||
|
||||
Args:
|
||||
query (str): String to search for.
|
||||
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
|
||||
|
||||
Returns:
|
||||
str: Query result string
|
||||
"""
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
|
||||
def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0):
|
||||
"""
|
||||
Search prior conversation history using a date range.
|
||||
|
||||
Args:
|
||||
start_date (str): The start of the date range to search, in the format 'YYYY-MM-DD'.
|
||||
end_date (str): The end of the date range to search, in the format 'YYYY-MM-DD'.
|
||||
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
|
||||
|
||||
Returns:
|
||||
str: Query result string
|
||||
"""
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
|
||||
def archival_memory_insert(self, content: str):
|
||||
"""
|
||||
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
|
||||
|
||||
Args:
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.persistence_manager.archival_memory.insert(content)
|
||||
return None
|
||||
|
||||
|
||||
def archival_memory_search(self, query: str, page: Optional[int] = 0):
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search.
|
||||
|
||||
Args:
|
||||
query (str): String to search for.
|
||||
page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
|
||||
|
||||
Returns:
|
||||
str: Query result string
|
||||
"""
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted)}"
|
||||
return results_str
|
||||
126
memgpt/functions/function_sets/extras.py
Normal file
126
memgpt/functions/function_sets/extras.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from typing import Optional
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
||||
|
||||
from ...constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
|
||||
from ...openai_tools import completions_with_backoff as create
|
||||
|
||||
|
||||
def message_chatgpt(self, message: str):
|
||||
"""
|
||||
Send a message to a more basic AI, ChatGPT. A useful resource for asking questions. ChatGPT does not retain memory of previous interactions.
|
||||
|
||||
Args:
|
||||
message (str): Message to send ChatGPT. Phrase your message as a full English sentence.
|
||||
|
||||
Returns:
|
||||
str: Reply message from ChatGPT
|
||||
"""
|
||||
message_sequence = [
|
||||
{"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
|
||||
{"role": "user", "content": str(message)},
|
||||
]
|
||||
response = create(
|
||||
model=MESSAGE_CHATGPT_FUNCTION_MODEL,
|
||||
messages=message_sequence,
|
||||
# functions=functions,
|
||||
# function_call=function_call,
|
||||
)
|
||||
|
||||
reply = response.choices[0].message.content
|
||||
return reply
|
||||
|
||||
|
||||
def read_from_text_file(self, filename: str, line_start: int, num_lines: Optional[int] = 1):
|
||||
"""
|
||||
Read lines from a text file.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to read.
|
||||
line_start (int): Line to start reading from.
|
||||
num_lines (Optional[int]): How many lines to read (defaults to 1).
|
||||
|
||||
Returns:
|
||||
str: Text read from the file
|
||||
"""
|
||||
max_chars = 500
|
||||
trunc_message = True
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError(f"The file '{filename}' does not exist.")
|
||||
|
||||
if line_start < 1 or num_lines < 1:
|
||||
raise ValueError("Both line_start and num_lines must be positive integers.")
|
||||
|
||||
lines = []
|
||||
chars_read = 0
|
||||
with open(filename, "r") as file:
|
||||
for current_line_number, line in enumerate(file, start=1):
|
||||
if line_start <= current_line_number < line_start + num_lines:
|
||||
chars_to_add = len(line)
|
||||
if max_chars is not None and chars_read + chars_to_add > max_chars:
|
||||
# If adding this line exceeds MAX_CHARS, truncate the line if needed and stop reading further.
|
||||
excess_chars = (chars_read + chars_to_add) - max_chars
|
||||
lines.append(line[:-excess_chars].rstrip("\n"))
|
||||
if trunc_message:
|
||||
lines.append(f"[SYSTEM ALERT - max chars ({max_chars}) reached during file read]")
|
||||
break
|
||||
else:
|
||||
lines.append(line.rstrip("\n"))
|
||||
chars_read += chars_to_add
|
||||
if current_line_number >= line_start + num_lines - 1:
|
||||
break
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def append_to_text_file(self, filename: str, content: str):
|
||||
"""
|
||||
Append to a text file.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to append to.
|
||||
content (str): Content to append to the file.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError(f"The file '{filename}' does not exist.")
|
||||
|
||||
with open(filename, "a") as file:
|
||||
file.write(content + "\n")
|
||||
|
||||
|
||||
def http_request(self, method: str, url: str, payload_json: Optional[str] = None):
|
||||
"""
|
||||
Generates an HTTP request and returns the response.
|
||||
|
||||
Args:
|
||||
method (str): The HTTP method (e.g., 'GET', 'POST').
|
||||
url (str): The URL for the request.
|
||||
payload_json (Optional[str]): A JSON string representing the request payload.
|
||||
|
||||
Returns:
|
||||
dict: The response from the HTTP request.
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# For GET requests, ignore the payload
|
||||
if method.upper() == "GET":
|
||||
print(f"[HTTP] launching GET request to {url}")
|
||||
response = requests.get(url, headers=headers)
|
||||
else:
|
||||
# Validate and convert the payload for other types of requests
|
||||
if payload_json:
|
||||
payload = json.loads(payload_json)
|
||||
else:
|
||||
payload = {}
|
||||
print(f"[HTTP] launching {method} request to {url}, payload=\n{json.dumps(payload, indent=2)}")
|
||||
response = requests.request(method, url, json=payload, headers=headers)
|
||||
|
||||
return {"status_code": response.status_code, "headers": dict(response.headers), "body": response.text}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
77
memgpt/functions/functions.py
Normal file
77
memgpt/functions/functions.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
|
||||
|
||||
from memgpt.functions.schema_generator import generate_schema
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
|
||||
|
||||
def load_function_set(set_name):
|
||||
"""Load the functions and generate schema for them"""
|
||||
function_dict = {}
|
||||
|
||||
module_name = f"memgpt.functions.function_sets.{set_name}"
|
||||
base_functions = importlib.import_module(module_name)
|
||||
|
||||
for attr_name in dir(base_functions):
|
||||
# Get the attribute
|
||||
attr = getattr(base_functions, attr_name)
|
||||
|
||||
# Check if it's a callable function and not a built-in or special method
|
||||
if inspect.isfunction(attr) and attr.__module__ == base_functions.__name__:
|
||||
if attr_name in function_dict:
|
||||
raise ValueError(f"Found a duplicate of function name '{attr_name}'")
|
||||
|
||||
generated_schema = generate_schema(attr)
|
||||
function_dict[attr_name] = {
|
||||
"python_function": attr,
|
||||
"json_schema": generated_schema,
|
||||
}
|
||||
|
||||
if len(function_dict) == 0:
|
||||
raise ValueError(f"No functions found in module {module_name}")
|
||||
return function_dict
|
||||
|
||||
|
||||
def load_all_function_sets(merge=True):
|
||||
# functions/examples/*.py
|
||||
scripts_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script
|
||||
function_sets_dir = os.path.join(scripts_dir, "function_sets") # Path to the function_sets directory
|
||||
# List all .py files in the directory (excluding __init__.py)
|
||||
example_module_files = [f for f in os.listdir(function_sets_dir) if f.endswith(".py") and f != "__init__.py"]
|
||||
|
||||
# ~/.memgpt/functions/*.py
|
||||
user_scripts_dir = os.path.join(MEMGPT_DIR, "functions")
|
||||
# create if missing
|
||||
if not os.path.exists(user_scripts_dir):
|
||||
os.makedirs(user_scripts_dir)
|
||||
user_module_files = [f for f in os.listdir(user_scripts_dir) if f.endswith(".py") and f != "__init__.py"]
|
||||
|
||||
# combine them both (pull from both examples and user-provided)
|
||||
all_module_files = example_module_files + user_module_files
|
||||
|
||||
schemas_and_functions = {}
|
||||
for file in all_module_files:
|
||||
# Convert filename to module name
|
||||
module_name = f"memgpt.functions.function_sets.{file[:-3]}" # Remove '.py' from filename
|
||||
|
||||
try:
|
||||
# Load the function set
|
||||
function_set = load_function_set(file[:-3]) # Pass the module part of the name
|
||||
schemas_and_functions[module_name] = function_set
|
||||
except ValueError as e:
|
||||
print(f"Error loading function set '{module_name}': {e}")
|
||||
|
||||
if merge:
|
||||
# Put all functions from all sets into the same level dict
|
||||
merged_functions = {}
|
||||
for set_name, function_set in schemas_and_functions.items():
|
||||
for function_name, function_info in function_set.items():
|
||||
if function_name in merged_functions:
|
||||
raise ValueError(f"Duplicate function name '{function_name}' found in function set '{set_name}'")
|
||||
merged_functions[function_name] = function_info
|
||||
return merged_functions
|
||||
else:
|
||||
# Nested dict where the top level is organized by the function set name
|
||||
return schemas_and_functions
|
||||
104
memgpt/functions/schema_generator.py
Normal file
104
memgpt/functions/schema_generator.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import inspect
|
||||
import typing
|
||||
from typing import get_args
|
||||
|
||||
from docstring_parser import parse
|
||||
|
||||
from memgpt.constants import FUNCTION_PARAM_NAME_REQ_HEARTBEAT, FUNCTION_PARAM_TYPE_REQ_HEARTBEAT, FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT
|
||||
|
||||
NO_HEARTBEAT_FUNCTIONS = ["send_message", "pause_heartbeats"]
|
||||
|
||||
|
||||
def is_optional(annotation):
|
||||
# Check if the annotation is a Union
|
||||
if getattr(annotation, "__origin__", None) is typing.Union:
|
||||
# Check if None is one of the options in the Union
|
||||
return type(None) in annotation.__args__
|
||||
return False
|
||||
|
||||
|
||||
def optional_length(annotation):
|
||||
if is_optional(annotation):
|
||||
# Subtract 1 to account for NoneType
|
||||
return len(annotation.__args__) - 1
|
||||
else:
|
||||
raise ValueError("The annotation is not an Optional type")
|
||||
|
||||
|
||||
def type_to_json_schema_type(py_type):
|
||||
"""
|
||||
Maps a Python type to a JSON schema type.
|
||||
Specifically handles typing.Optional and common Python types.
|
||||
"""
|
||||
# if get_origin(py_type) is typing.Optional:
|
||||
if is_optional(py_type):
|
||||
# Assert that Optional has only one type argument
|
||||
type_args = get_args(py_type)
|
||||
assert optional_length(py_type) == 1, f"Optional type must have exactly one type argument, but got {py_type}"
|
||||
|
||||
# Extract and map the inner type
|
||||
return type_to_json_schema_type(type_args[0])
|
||||
|
||||
# Mapping of Python types to JSON schema types
|
||||
type_map = {
|
||||
int: "integer",
|
||||
str: "string",
|
||||
bool: "boolean",
|
||||
float: "number",
|
||||
# Add more mappings as needed
|
||||
}
|
||||
if py_type not in type_map:
|
||||
raise ValueError(f"Python type {py_type} has no corresponding JSON schema type")
|
||||
|
||||
return type_map.get(py_type, "string") # Default to "string" if type not in map
|
||||
|
||||
|
||||
def generate_schema(function):
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function)
|
||||
|
||||
# Parse the docstring
|
||||
docstring = parse(function.__doc__)
|
||||
|
||||
# Prepare the schema dictionary
|
||||
schema = {
|
||||
"name": function.__name__,
|
||||
"description": docstring.short_description,
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
|
||||
for param in sig.parameters.values():
|
||||
# Exclude 'self' parameter
|
||||
if param.name == "self":
|
||||
continue
|
||||
|
||||
# Assert that the parameter has a type annotation
|
||||
if param.annotation == inspect.Parameter.empty:
|
||||
raise TypeError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a type annotation")
|
||||
|
||||
# Find the parameter's description in the docstring
|
||||
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
|
||||
|
||||
# Assert that the parameter has a description
|
||||
if not param_doc or not param_doc.description:
|
||||
raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")
|
||||
|
||||
# Add parameter details to the schema
|
||||
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
|
||||
schema["parameters"]["properties"][param.name] = {
|
||||
# "type": "string" if param.annotation == str else str(param.annotation),
|
||||
"type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
|
||||
"description": param_doc.description,
|
||||
}
|
||||
if param.default == inspect.Parameter.empty:
|
||||
schema["parameters"]["required"].append(param.name)
|
||||
|
||||
# append the heartbeat
|
||||
if function.__name__ not in NO_HEARTBEAT_FUNCTIONS:
|
||||
schema["parameters"]["properties"][FUNCTION_PARAM_NAME_REQ_HEARTBEAT] = {
|
||||
"type": FUNCTION_PARAM_TYPE_REQ_HEARTBEAT,
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
}
|
||||
schema["parameters"]["required"].append(FUNCTION_PARAM_NAME_REQ_HEARTBEAT)
|
||||
|
||||
return schema
|
||||
@@ -22,7 +22,7 @@ import memgpt.interface # for printing to terminal
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.utils as utils
|
||||
import memgpt.presets as presets
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.constants as constants
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
from .prompts import gpt_functions
|
||||
from .prompts import gpt_system
|
||||
|
||||
DEFAULT_PRESET = "memgpt_chat"
|
||||
preset_options = [DEFAULT_PRESET]
|
||||
|
||||
|
||||
def use_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
|
||||
"""Storing combinations of SYSTEM + FUNCTION prompts"""
|
||||
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.utils import printd
|
||||
|
||||
if preset_name == DEFAULT_PRESET:
|
||||
functions = [
|
||||
"send_message",
|
||||
"pause_heartbeats",
|
||||
"core_memory_append",
|
||||
"core_memory_replace",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
"archival_memory_search",
|
||||
]
|
||||
available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
|
||||
printd(f"Available functions:\n", [x["name"] for x in available_functions])
|
||||
assert len(functions) == len(available_functions)
|
||||
|
||||
if "gpt-3.5" in model:
|
||||
# use a different system message for gpt-3.5
|
||||
preset_name = "memgpt_gpt35_extralong"
|
||||
|
||||
return Agent(
|
||||
config=agent_config,
|
||||
model=model,
|
||||
system=gpt_system.get_system_text(preset_name),
|
||||
functions=available_functions,
|
||||
interface=interface,
|
||||
persistence_manager=persistence_manager,
|
||||
persona_notes=persona,
|
||||
human_notes=human,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if "gpt-4" in model else False,
|
||||
)
|
||||
|
||||
elif preset_name == "memgpt_extras":
|
||||
functions = [
|
||||
"send_message",
|
||||
"pause_heartbeats",
|
||||
"core_memory_append",
|
||||
"core_memory_replace",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
"archival_memory_search",
|
||||
# extra for read/write to files
|
||||
"read_from_text_file",
|
||||
"append_to_text_file",
|
||||
# internet access
|
||||
"http_request",
|
||||
]
|
||||
available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
|
||||
printd(f"Available functions:\n", [x["name"] for x in available_functions])
|
||||
assert len(functions) == len(available_functions)
|
||||
|
||||
if "gpt-3.5" in model:
|
||||
# use a different system message for gpt-3.5
|
||||
preset_name = "memgpt_gpt35_extralong"
|
||||
|
||||
return Agent(
|
||||
config=agent_config,
|
||||
model=model,
|
||||
system=gpt_system.get_system_text("memgpt_chat"),
|
||||
functions=available_functions,
|
||||
interface=interface,
|
||||
persistence_manager=persistence_manager,
|
||||
persona_notes=persona,
|
||||
human_notes=human,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if "gpt-4" in model else False,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(preset_name)
|
||||
10
memgpt/presets/examples/memgpt_chat.yaml
Normal file
10
memgpt/presets/examples/memgpt_chat.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
system_prompt: "memgpt_chat"
|
||||
functions:
|
||||
- "send_message"
|
||||
- "pause_heartbeats"
|
||||
- "core_memory_append"
|
||||
- "core_memory_replace"
|
||||
- "conversation_search"
|
||||
- "conversation_search_date"
|
||||
- "archival_memory_insert"
|
||||
- "archival_memory_search"
|
||||
10
memgpt/presets/examples/memgpt_docs.yaml
Normal file
10
memgpt/presets/examples/memgpt_docs.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
system_prompt: "memgpt_doc"
|
||||
functions:
|
||||
- "send_message"
|
||||
- "pause_heartbeats"
|
||||
- "core_memory_append"
|
||||
- "core_memory_replace"
|
||||
- "conversation_search"
|
||||
- "conversation_search_date"
|
||||
- "archival_memory_insert"
|
||||
- "archival_memory_search"
|
||||
15
memgpt/presets/examples/memgpt_extras.yaml
Normal file
15
memgpt/presets/examples/memgpt_extras.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
system_prompt: "memgpt_chat"
|
||||
functions:
|
||||
- "send_message"
|
||||
- "pause_heartbeats"
|
||||
- "core_memory_append"
|
||||
- "core_memory_replace"
|
||||
- "conversation_search"
|
||||
- "conversation_search_date"
|
||||
- "archival_memory_insert"
|
||||
- "archival_memory_search"
|
||||
# extras for read/write to files
|
||||
- "read_from_text_file"
|
||||
- "append_to_text_file"
|
||||
# internet access
|
||||
- "http_request"
|
||||
61
memgpt/presets/presets.py
Normal file
61
memgpt/presets/presets.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from .utils import load_all_presets, is_valid_yaml_format
|
||||
from ..prompts import gpt_functions
|
||||
from ..prompts import gpt_system
|
||||
from ..functions.functions import load_all_function_sets
|
||||
|
||||
DEFAULT_PRESET = "memgpt_chat"
|
||||
|
||||
available_presets = load_all_presets()
|
||||
preset_options = list(available_presets.keys())
|
||||
|
||||
|
||||
def use_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
|
||||
"""Storing combinations of SYSTEM + FUNCTION prompts"""
|
||||
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.utils import printd
|
||||
|
||||
# Available functions is a mapping from:
|
||||
# function_name -> {
|
||||
# json_schema: schema
|
||||
# python_function: function
|
||||
# }
|
||||
available_functions = load_all_function_sets()
|
||||
|
||||
available_presets = load_all_presets()
|
||||
if preset_name not in available_presets:
|
||||
raise ValueError(f"Preset '{preset_name}.yaml' not found")
|
||||
|
||||
preset = available_presets[preset_name]
|
||||
if not is_valid_yaml_format(preset, list(available_functions.keys())):
|
||||
raise ValueError(f"Preset '{preset_name}.yaml' is not valid")
|
||||
|
||||
preset_system_prompt = preset["system_prompt"]
|
||||
preset_function_set_names = preset["functions"]
|
||||
|
||||
# Filter down the function set based on what the preset requested
|
||||
preset_function_set = {}
|
||||
for f_name in preset_function_set_names:
|
||||
if f_name not in available_functions:
|
||||
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
|
||||
preset_function_set[f_name] = available_functions[f_name]
|
||||
assert len(preset_function_set_names) == len(preset_function_set)
|
||||
printd(f"Available functions:\n", list(preset_function_set.keys()))
|
||||
|
||||
# preset_function_set = {f_name: f_dict for f_name, f_dict in available_functions.items() if f_name in preset_function_set_names}
|
||||
# printd(f"Available functions:\n", [f_name for f_name, f_dict in preset_function_set.items()])
|
||||
# Make sure that every function the preset wanted is inside the available functions
|
||||
# assert len(preset_function_set_names) == len(preset_function_set)
|
||||
|
||||
return Agent(
|
||||
config=agent_config,
|
||||
model=model,
|
||||
system=gpt_system.get_system_text(preset_system_prompt),
|
||||
functions=preset_function_set,
|
||||
interface=interface,
|
||||
persistence_manager=persistence_manager,
|
||||
persona_notes=persona,
|
||||
human_notes=human,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if "gpt-4" in model else False,
|
||||
)
|
||||
76
memgpt/presets/utils.py
Normal file
76
memgpt/presets/utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
import glob
|
||||
import yaml
|
||||
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
|
||||
|
||||
def is_valid_yaml_format(yaml_data, function_set):
|
||||
"""
|
||||
Check if the given YAML data follows the specified format and if all functions in the yaml are part of the function_set.
|
||||
Raises ValueError if any check fails.
|
||||
|
||||
:param yaml_data: The data loaded from a YAML file.
|
||||
:param function_set: A set of valid function names.
|
||||
"""
|
||||
# Check for required keys
|
||||
if not all(key in yaml_data for key in ["system_prompt", "functions"]):
|
||||
raise ValueError("YAML data is missing one or more required keys: 'system_prompt', 'functions'.")
|
||||
|
||||
# Check if 'functions' is a list of strings
|
||||
if not all(isinstance(item, str) for item in yaml_data.get("functions", [])):
|
||||
raise ValueError("'functions' should be a list of strings.")
|
||||
|
||||
# Check if all functions in YAML are part of function_set
|
||||
if not set(yaml_data["functions"]).issubset(function_set):
|
||||
raise ValueError("Some functions in YAML are not part of the provided function set.")
|
||||
|
||||
# If all checks pass
|
||||
return True
|
||||
|
||||
|
||||
def load_yaml_file(file_path):
|
||||
"""
|
||||
Load a YAML file and return the data.
|
||||
|
||||
:param file_path: Path to the YAML file.
|
||||
:return: Data from the YAML file.
|
||||
"""
|
||||
with open(file_path, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
|
||||
def load_all_presets():
|
||||
"""Load all the preset configs in the examples directory"""
|
||||
|
||||
## Load the examples
|
||||
# Get the directory in which the script is located
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
# Construct the path pattern
|
||||
example_path_pattern = os.path.join(script_directory, "examples", "*.yaml")
|
||||
# Listing all YAML files
|
||||
example_yaml_files = glob.glob(example_path_pattern)
|
||||
|
||||
## Load the user-provided presets
|
||||
# ~/.memgpt/presets/*.yaml
|
||||
user_presets_dir = os.path.join(MEMGPT_DIR, "presets")
|
||||
# Create directory if it doesn't exist
|
||||
if not os.path.exists(user_presets_dir):
|
||||
os.makedirs(user_presets_dir)
|
||||
# Construct the path pattern
|
||||
user_path_pattern = os.path.join(user_presets_dir, "*.yaml")
|
||||
# Listing all YAML files
|
||||
user_yaml_files = glob.glob(user_path_pattern)
|
||||
|
||||
# Pull from both examplesa and user-provided
|
||||
all_yaml_files = example_yaml_files + user_yaml_files
|
||||
|
||||
# Loading and creating a mapping from file name to YAML data
|
||||
all_yaml_data = {}
|
||||
for file_path in all_yaml_files:
|
||||
# Extracting the base file name without the '.yaml' extension
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
data = load_yaml_file(file_path)
|
||||
all_yaml_data[base_name] = data
|
||||
|
||||
return all_yaml_data
|
||||
@@ -1,10 +1,10 @@
|
||||
from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT
|
||||
from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, MAX_PAUSE_HEARTBEATS
|
||||
|
||||
# FUNCTIONS_PROMPT_MULTISTEP_NO_HEARTBEATS = FUNCTIONS_PROMPT_MULTISTEP[:-1]
|
||||
FUNCTIONS_CHAINING = {
|
||||
"send_message": {
|
||||
"name": "send_message",
|
||||
"description": "Sends a message to the human user",
|
||||
"description": "Sends a message to the human user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -26,7 +26,7 @@ FUNCTIONS_CHAINING = {
|
||||
# https://json-schema.org/understanding-json-schema/reference/array.html
|
||||
"minutes": {
|
||||
"type": "integer",
|
||||
"description": "Number of minutes to ignore heartbeats for. Max value of 360 minutes (6 hours).",
|
||||
"description": f"Number of minutes to ignore heartbeats for. Max value of {MAX_PAUSE_HEARTBEATS} minutes ({MAX_PAUSE_HEARTBEATS//60} hours).",
|
||||
},
|
||||
},
|
||||
"required": ["minutes"],
|
||||
@@ -45,7 +45,7 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution, use to chain multiple functions.",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
"required": ["message", "request_heartbeat"],
|
||||
@@ -67,7 +67,7 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution, use to chain multiple functions.",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
"required": ["name", "content", "request_heartbeat"],
|
||||
@@ -93,7 +93,7 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution, use to chain multiple functions.",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
"required": ["name", "old_content", "new_content", "request_heartbeat"],
|
||||
@@ -140,7 +140,7 @@ FUNCTIONS_CHAINING = {
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
"required": ["query", "page", "request_heartbeat"],
|
||||
"required": ["query", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
"recall_memory_search_date": {
|
||||
@@ -192,7 +192,7 @@ FUNCTIONS_CHAINING = {
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
"required": ["start_date", "end_date", "page", "request_heartbeat"],
|
||||
"required": ["start_date", "end_date", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
"archival_memory_insert": {
|
||||
@@ -232,7 +232,7 @@ FUNCTIONS_CHAINING = {
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
"required": ["query", "page", "request_heartbeat"],
|
||||
"required": ["query", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
"read_from_text_file": {
|
||||
@@ -269,7 +269,7 @@ FUNCTIONS_CHAINING = {
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The name of the file to read.",
|
||||
"description": "The name of the file to append to.",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
@@ -295,9 +295,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL for the request",
|
||||
"description": "The URL for the request.",
|
||||
},
|
||||
"payload": {
|
||||
"payload_json": {
|
||||
"type": "string",
|
||||
"description": "A JSON string representing the request payload.",
|
||||
},
|
||||
|
||||
@@ -1,12 +1,26 @@
|
||||
import os
|
||||
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
|
||||
|
||||
def get_system_text(key):
|
||||
filename = f"{key}.txt"
|
||||
file_path = os.path.join(os.path.dirname(__file__), "system", filename)
|
||||
|
||||
# first look in prompts/system/*.txt
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
return file.read().strip()
|
||||
else:
|
||||
raise FileNotFoundError(f"No file found for key {key}, path={file_path}")
|
||||
# try looking in ~/.memgpt/system_prompts/*.txt
|
||||
user_system_prompts_dir = os.path.join(MEMGPT_DIR, "system_prompts")
|
||||
# create directory if it doesn't exist
|
||||
if not os.path.exists(user_system_prompts_dir):
|
||||
os.makedirs(user_system_prompts_dir)
|
||||
# look inside for a matching system prompt
|
||||
file_path = os.path.join(user_system_prompts_dir, filename)
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
return file.read().strip()
|
||||
else:
|
||||
raise FileNotFoundError(f"No file found for key {key}, path={file_path}")
|
||||
|
||||
@@ -423,3 +423,17 @@ def get_human_text(name: str):
|
||||
file = os.path.basename(file_path)
|
||||
if f"{name}.txt" == file or name == file:
|
||||
return open(file_path, "r").read().strip()
|
||||
|
||||
|
||||
def get_schema_diff(schema_a, schema_b):
|
||||
# Assuming f_schema and linked_function['json_schema'] are your JSON schemas
|
||||
f_schema_json = json.dumps(schema_a, indent=2)
|
||||
linked_function_json = json.dumps(schema_b, indent=2)
|
||||
|
||||
# Compute the difference using difflib
|
||||
difference = list(difflib.ndiff(f_schema_json.splitlines(keepends=True), linked_function_json.splitlines(keepends=True)))
|
||||
|
||||
# Filter out lines that don't represent changes
|
||||
difference = [line for line in difference if line.startswith("+ ") or line.startswith("- ")]
|
||||
|
||||
return "".join(difference)
|
||||
|
||||
19
poetry.lock
generated
19
poetry.lock
generated
@@ -534,6 +534,17 @@ files = [
|
||||
{file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docstring-parser"
|
||||
version = "0.15"
|
||||
description = "Parse Python docstrings in reST, Google and Numpydoc format"
|
||||
optional = false
|
||||
python-versions = ">=3.6,<4.0"
|
||||
files = [
|
||||
{file = "docstring_parser-0.15-py3-none-any.whl", hash = "sha256:d1679b86250d269d06a99670924d6bce45adc00b08069dae8c47d98e89b667a9"},
|
||||
{file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.1.3"
|
||||
@@ -2133,26 +2144,31 @@ python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:c4eb71b88a22c1008f764b3121b36a9d25340f9920b870508356050a365d9ca1"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:3ce2d3678dbf822cff213b1902f2e59756313e543efd516a2b4f15bb0353bd6c"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:2e27857a15c8a810d0b66455b8c8a79013640b6267a9b4ea808a5fe1f47711f2"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:5cd05700c8f18c9dafef63ac2ed3b1099ca06017ca0c32deea13093cea1b8671"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-win32.whl", hash = "sha256:951d280c1daafac2fd6a664b031f7f98b27eb2def55d39c92a19087bd8041c5d"},
|
||||
{file = "PyMuPDF-1.23.6-cp310-none-win_amd64.whl", hash = "sha256:19d1711d5908c4527ad2deef5af2d066649f3f9a12950faf30be5f7251d18abc"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:3f0f9b76bc4f039e7587003cbd40684d93a98441549dd033cab38ca07d61988d"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e047571d799b30459ad7ee0bc6e68900a7f6b928876f956c976f279808814e72"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:1cbcf05c06f314fdf3042ceee674e9a0ac7fae598347d5442e2138c6046d4e82"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:e33f8ec5ba7265fe78b30332840b8f454184addfa79f9c27f160f19789aa5ffd"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-win32.whl", hash = "sha256:2c141f33e2733e48de8524dfd2de56d889feef0c7773b20a8cd216c03ab24793"},
|
||||
{file = "PyMuPDF-1.23.6-cp311-none-win_amd64.whl", hash = "sha256:8fd9c4ee1dd4744a515b9190d8ba9133348b0d94c362293ed77726aa1c13b0a6"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:4d06751d5cd213e96f84f2faaa71a51cf4d641851e07579247ca1190121f173b"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:526b26a5207e923aab65877ad305644402851823a352cb92d362053426899354"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:0f852d125defc26716878b1796f4d68870e9065041d00cf46bde317fd8d30e68"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:5bdf7020b90987412381acc42427dd1b7a03d771ee9ec273de003e570164ec1a"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-win32.whl", hash = "sha256:e2d64799c6d9a3735be9e162a5d11061c0b7fbcb1e5fc7446e0993d0f815a93a"},
|
||||
{file = "PyMuPDF-1.23.6-cp312-none-win_amd64.whl", hash = "sha256:c8ea81964c1433ea163ad4b53c56053a87a9ef6e1bd7a879d4d368a3988b60d1"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:761501a4965264e81acdd8f2224f993020bf24474e9b34fcdb5805a6826eda1c"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:fd8388e82b6045807d19addf310d8119d32908e89f76cc8bbf8cf1ec36fce947"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:4ac9673a6d6ee7e80cb242dacb43f9ca097b502d9c5e44687dbdffc2bce7961a"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:6e319c1f49476e07b9a12017c2d031687617713f8a46b7adcec03c636ed04607"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-win32.whl", hash = "sha256:1103eea4ab727e32b9cb93347b35f71562033018c333a7f3a17d115e980fea4a"},
|
||||
{file = "PyMuPDF-1.23.6-cp38-none-win_amd64.whl", hash = "sha256:991a37e1cba43775ce094da87cf0bf72172a5532a09644003276bc8bfdfe9f1a"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:57725e15872f7ab67a9fb3e06e5384d1047b2121e85755c93a6d4266d3ca8983"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:224c341fe254adda97c8f06a4c5838cdbcf609fa89e70b1fb179752533378f2f"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:271bdf6059bb8347f9c9c6b721329bd353a933681b1fc62f43241b410e7ab7ae"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:57e22bea69690450197b34dcde16bd9fe0265ac4425b4033535ccc5c044246fb"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-win32.whl", hash = "sha256:2885a26220a32fb45ea443443b72194bb7107d6862d8d546b59e4ad0c8a1f2c9"},
|
||||
{file = "PyMuPDF-1.23.6-cp39-none-win_amd64.whl", hash = "sha256:361cab1be45481bd3dc4e00ec82628ebc189b4f4b6fd9bd78a00cfeed54e0034"},
|
||||
@@ -2171,6 +2187,7 @@ python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e5af77580aad3d1103aeec57009d156bfca429cecda14a17c573fcbe97bafb30"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9925816cbe3e05e920f9be925e5752c2eef42b793885b62075bb0f6a69178598"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:009e2cff166059e13bf71f93919e688f46b8fc11d122433574cfb0cc9134690e"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7132b30e6ad6ff2013344e3a481b2287fe0be3710d80694807dd6e0d8635f085"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-win32.whl", hash = "sha256:9d24ddadc204e895bee5000ddc7507c801643548e59f5a56aad6d32981d17eeb"},
|
||||
{file = "PyMuPDFb-1.23.6-py3-none-win_amd64.whl", hash = "sha256:7bef75988e6979b10ca804cf9487f817aae43b0fff1c6e315b3b9ee0cf1cc32f"},
|
||||
@@ -3510,4 +3527,4 @@ postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.12,>=3.9"
|
||||
content-hash = "32cc1809f381627327c0d8c2334bdee73c3653437fa2b8138df34953d3d2a200"
|
||||
content-hash = "24e6c3cea1895441e07d362a5a2f9a07a045b92b5364531b8b6e3571904199fe"
|
||||
|
||||
@@ -45,6 +45,7 @@ transformers = { version = "4.34.1", optional = true }
|
||||
pre-commit = {version = "^3.5.0", optional = true }
|
||||
pg8000 = {version = "^1.30.3", optional = true}
|
||||
torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
|
||||
docstring-parser = "^0.15"
|
||||
|
||||
[tool.poetry.extras]
|
||||
legacy = ["faiss-cpu", "numpy"]
|
||||
|
||||
109
tests/test_schema_generator.py
Normal file
109
tests/test_schema_generator.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import inspect
|
||||
|
||||
import memgpt.functions.function_sets.base as base_functions
|
||||
import memgpt.functions.function_sets.extras as extras_functions
|
||||
from memgpt.prompts.gpt_functions import FUNCTIONS_CHAINING
|
||||
from memgpt.functions.schema_generator import generate_schema
|
||||
|
||||
|
||||
def send_message(self, message: str):
|
||||
"""
|
||||
Sends a message to the human user.
|
||||
|
||||
Args:
|
||||
message (str): Message contents. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
def send_message_missing_types(self, message):
|
||||
"""
|
||||
Sends a message to the human user.
|
||||
|
||||
Args:
|
||||
message (str): Message contents. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
def send_message_missing_docstring(self, message: str):
|
||||
return None
|
||||
|
||||
|
||||
def test_schema_generator():
|
||||
# Check that a basic function schema converts correctly
|
||||
correct_schema = {
|
||||
"name": "send_message",
|
||||
"description": "Sends a message to the human user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}},
|
||||
"required": ["message"],
|
||||
},
|
||||
}
|
||||
generated_schema = generate_schema(send_message)
|
||||
print(f"\n\nreference_schema={correct_schema}")
|
||||
print(f"\n\ngenerated_schema={generated_schema}")
|
||||
assert correct_schema == generated_schema
|
||||
|
||||
# Check that missing types results in an error
|
||||
try:
|
||||
_ = generate_schema(send_message_missing_types)
|
||||
assert False
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check that missing docstring results in an error
|
||||
try:
|
||||
_ = generate_schema(send_message_missing_docstring)
|
||||
assert False
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def test_schema_generator_with_old_function_set():
|
||||
# Try all the base functions first
|
||||
for attr_name in dir(base_functions):
|
||||
# Get the attribute
|
||||
attr = getattr(base_functions, attr_name)
|
||||
|
||||
# Check if it's a callable function and not a built-in or special method
|
||||
if inspect.isfunction(attr):
|
||||
# Here, 'func' is each function in base_functions
|
||||
# You can now call the function or do something with it
|
||||
print("Function name:", attr)
|
||||
# Example function call (if the function takes no arguments)
|
||||
# result = func()
|
||||
function_name = str(attr_name)
|
||||
real_schema = FUNCTIONS_CHAINING[function_name]
|
||||
generated_schema = generate_schema(attr)
|
||||
print(f"\n\nreference_schema={real_schema}")
|
||||
print(f"\n\ngenerated_schema={generated_schema}")
|
||||
assert real_schema == generated_schema
|
||||
|
||||
# Then try all the extras functions
|
||||
for attr_name in dir(extras_functions):
|
||||
# Get the attribute
|
||||
attr = getattr(extras_functions, attr_name)
|
||||
|
||||
# Check if it's a callable function and not a built-in or special method
|
||||
if inspect.isfunction(attr):
|
||||
if attr_name == "create":
|
||||
continue
|
||||
# Here, 'func' is each function in base_functions
|
||||
# You can now call the function or do something with it
|
||||
print("Function name:", attr)
|
||||
# Example function call (if the function takes no arguments)
|
||||
# result = func()
|
||||
function_name = str(attr_name)
|
||||
real_schema = FUNCTIONS_CHAINING[function_name]
|
||||
generated_schema = generate_schema(attr)
|
||||
print(f"\n\nreference_schema={real_schema}")
|
||||
print(f"\n\ngenerated_schema={generated_schema}")
|
||||
assert real_schema == generated_schema
|
||||
Reference in New Issue
Block a user