diff --git a/.github/workflows/poetry-publish.yml b/.github/workflows/poetry-publish.yml index 86717ff7..a1f1e646 100644 --- a/.github/workflows/poetry-publish.yml +++ b/.github/workflows/poetry-publish.yml @@ -3,7 +3,7 @@ on: release: types: [published] workflow_dispatch: - + jobs: build-and-publish: name: Build and Publish to PyPI diff --git a/memgpt/agent.py b/memgpt/agent.py index 215e7f0b..2ba08a05 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -1,6 +1,7 @@ import asyncio import inspect import datetime +import glob import pickle import math import os @@ -8,12 +9,14 @@ import json import threading import openai - +from memgpt.persistence_manager import LocalStateManager +from memgpt.config import AgentConfig from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages from .memory import CoreMemory as Memory, summarize_messages, a_summarize_messages from .openai_tools import acompletions_with_backoff as acreate, completions_with_backoff as create from .utils import get_local_time, parse_json, united_diff, printd, count_tokens from .constants import ( + MEMGPT_DIR, FIRST_MESSAGE_ATTEMPTS, MAX_PAUSE_HEARTBEATS, MESSAGE_CHATGPT_FUNCTION_MODEL, @@ -167,6 +170,7 @@ async def call_function(function_to_call, **function_args): class Agent(object): def __init__( self, + config, model, system, functions, @@ -178,6 +182,8 @@ class Agent(object): persistence_manager_init=True, first_message_verify_mono=True, ): + # agent config + self.config = config # gpt-4, gpt-3.5-turbo self.model = model # Store the system instructions (used to rebuild memory) @@ -194,7 +200,8 @@ class Agent(object): ) # Keep track of the total number of messages throughout all time self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) - self.messages_total_init = self.messages_total + # self.messages_total_init = self.messages_total + self.messages_total_init = len(self._messages) - 1 printd(f"AgentAsync initialized, self.messages_total={self.messages_total}") # Interface must implement: @@ -331,6 +338,61 @@ class Agent(object): with open(filename, "w") as file: json.dump(self.to_dict(), file) + def save(self): + """Save agent state locally""" + + timestamp = get_local_time().replace(" ", "_").replace(":", "_") + agent_name = self.config.name # TODO: fix + + # save agent state + filename = f"{timestamp}.json" + os.makedirs(self.config.save_state_dir(), exist_ok=True) + self.save_to_json_file(os.path.join(self.config.save_state_dir(), filename)) + + # save the persistence manager too + filename = f"{timestamp}.persistence.pickle" + os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True) + self.persistence_manager.save(os.path.join(self.config.save_persistence_manager_dir(), filename)) + + @classmethod + def load_agent(cls, interface, agent_config: AgentConfig): + """Load saved agent state""" + # TODO: support loading from specific file + agent_name = agent_config.name + + # load state + directory = agent_config.save_state_dir() + json_files = glob.glob(f"{directory}/*.json") # This will list all .json files in the current directory. + if not json_files: + print(f"/load error: no .json checkpoint files found") + raise ValueError(f"Cannot load {agent_name}") + + # Sort files based on modified timestamp, with the latest file being the first. + filename = max(json_files, key=os.path.getmtime) + state = json.load(open(filename, "r")) + + # load persistence manager + filename = os.path.basename(filename).replace(".json", ".persistence.pickle") + directory = agent_config.save_persistence_manager_dir() + persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config) + + messages = state["messages"] + agent = cls( + config=agent_config, + model=state["model"], + system=state["system"], + functions=state["functions"], + interface=interface, + persistence_manager=persistence_manager, + persistence_manager_init=False, + persona_notes=state["memory"]["persona"], + human_notes=state["memory"]["human"], + messages_total=state["messages_total"] if "messages_total" in state else len(messages) - 1, + ) + agent._messages = messages + agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"]) + return agent + @classmethod def load(cls, state, interface, persistence_manager): model = state["model"] @@ -875,6 +937,9 @@ class AgentAsync(Agent): if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user": printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue") + from pprint import pprint + + pprint(input_message_sequence[-1]) # Step 1: send the conversation and available functions to GPT if not skip_verify and (first_message or self.messages_total == self.messages_total_init): @@ -901,9 +966,9 @@ class AgentAsync(Agent): # Add the extra metadata to the assistant response # (e.g. enough metadata to enable recreating the API call) - assert "api_response" not in all_response_messages[0] + assert "api_response" not in all_response_messages[0], f"api_response already in {all_response_messages[0]}" all_response_messages[0]["api_response"] = response_message_copy - assert "api_args" not in all_response_messages[0] + assert "api_args" not in all_response_messages[0], f"api_args already in {all_response_messages[0]}" all_response_messages[0]["api_args"] = { "model": self.model, "messages": input_message_sequence, @@ -933,6 +998,7 @@ class AgentAsync(Agent): except Exception as e: printd(f"step() failed\nuser_message = {user_message}\nerror = {e}") + print(f"step() failed\nuser_message = {user_message}\nerror = {e}") # If we got a context alert, try trimming the messages length, then try again if "maximum context length" in str(e): @@ -943,6 +1009,7 @@ class AgentAsync(Agent): return await self.step(user_message, first_message=first_message) else: printd(f"step() failed with openai.InvalidRequestError, but didn't recognize the error message: '{str(e)}'") + print(e) raise e async def summarize_messages_inplace(self, cutoff=None): diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index 2909b297..50609286 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -40,7 +40,7 @@ def create_memgpt_autogen_agent_from_config( autogen_memgpt_agent = create_autogen_memgpt_agent( name, - preset=presets.DEFAULT, + preset=presets.DEFAULT_PRESET, model=model, persona_description=persona_desc, user_description=user_desc, @@ -50,7 +50,7 @@ def create_memgpt_autogen_agent_from_config( if human_input_mode != "ALWAYS": coop_agent1 = create_autogen_memgpt_agent( name, - preset=presets.DEFAULT, + preset=presets.DEFAULT_PRESET, model=model, persona_description=persona_desc, user_description=user_desc, @@ -65,7 +65,7 @@ def create_memgpt_autogen_agent_from_config( else: coop_agent2 = create_autogen_memgpt_agent( name, - preset=presets.DEFAULT, + preset=presets.DEFAULT_PRESET, model=model, persona_description=persona_desc, user_description=user_desc, @@ -86,7 +86,7 @@ def create_memgpt_autogen_agent_from_config( def create_autogen_memgpt_agent( autogen_name, - preset=presets.DEFAULT, + preset=presets.DEFAULT_PRESET, model=constants.DEFAULT_MEMGPT_MODEL, persona_description=personas.DEFAULT, user_description=humans.DEFAULT, diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py new file mode 100644 index 00000000..30ae919f --- /dev/null +++ b/memgpt/cli/cli.py @@ -0,0 +1,148 @@ +import typer +import sys +import io +import logging +import asyncio +import os +from prettytable import PrettyTable +import questionary +import openai + +from llama_index import set_global_service_context +from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext + +import memgpt.interface # for printing to terminal +from memgpt.cli.cli_config import configure +import memgpt.agent as agent +import memgpt.system as system +import memgpt.presets as presets +import memgpt.constants as constants +import memgpt.personas.personas as personas +import memgpt.humans.humans as humans +import memgpt.utils as utils +from memgpt.utils import printd +from memgpt.persistence_manager import LocalStateManager +from memgpt.config import MemGPTConfig, AgentConfig +from memgpt.constants import MEMGPT_DIR +from memgpt.agent import AgentAsync +from memgpt.embeddings import embedding_model + + +def run( + persona: str = typer.Option(None, help="Specify persona"), + agent: str = typer.Option(None, help="Specify agent save file"), + human: str = typer.Option(None, help="Specify human"), + model: str = typer.Option(None, help="Specify the LLM model"), + preset: str = typer.Option(None, help="Specify preset"), + data_source: str = typer.Option(None, help="Specify data source to attach to agent"), + first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"), + debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"), + no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"), + yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"), +): + """Start chatting with an MemGPT agent + + Example usage: `memgpt run --agent myagent --data-source mydata --persona mypersona --human myhuman --model gpt-3.5-turbo` + + :param persona: Specify persona + :param agent: Specify agent name (will load existing state if the agent exists, or create a new one with that name) + :param human: Specify human + :param model: Specify the LLM model + :param data_source: Specify data source to attach to agent (if new agent is being created) + + """ + + # setup logger + utils.DEBUG = debug + logging.getLogger().setLevel(logging.CRITICAL) + if debug: + logging.getLogger().setLevel(logging.DEBUG) + + if not MemGPTConfig.exists(): # if no config, run configure + if yes: + # use defaults + config = MemGPTConfig() + else: + # use input + configure() + config = MemGPTConfig.load() + else: # load config + config = MemGPTConfig.load() + + # override with command line arguments + if debug: + config.debug = debug + if no_verify: + config.no_verify = no_verify + + # determine agent to use, if not provided + if not yes and not agent: + agent_files = utils.list_agent_config_files() + agents = [AgentConfig.load(f).name for f in agent_files] + + if len(agents) > 0: + select_agent = questionary.confirm("Would you like to select an existing agent?").ask() + if select_agent: + agent = questionary.select("Select agent:", choices=agents).ask() + + # configure llama index + config = MemGPTConfig.load() + original_stdout = sys.stdout # unfortunate hack required to suppress confusing print statements from llama index + sys.stdout = io.StringIO() + embed_model = embedding_model(config) + service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size) + set_global_service_context(service_context) + sys.stdout = original_stdout + + # create agent config + if agent and AgentConfig.exists(agent): # use existing agent + typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN) + agent_config = AgentConfig.load(agent) + printd("State path:", agent_config.save_state_dir()) + printd("Persistent manager path:", agent_config.save_persistence_manager_dir()) + printd("Index path:", agent_config.save_agent_index_dir()) + # persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load + # TODO: load prior agent state + assert not any( + [persona, human, model] + ), f"Cannot override existing agent state with command line arguments: {persona}, {human}, {model}" + + # load existing agent + memgpt_agent = AgentAsync.load_agent(memgpt.interface, agent_config) + else: # create new agent + # create new agent config: override defaults with args if provided + typer.secho("Creating new agent...", fg=typer.colors.GREEN) + agent_config = AgentConfig( + name=agent if agent else None, + persona=persona if persona else config.default_persona, + human=human if human else config.default_human, + model=model if model else config.model, + preset=preset if preset else config.preset, + ) + + # attach data source to agent + agent_config.attach_data_source(data_source) + + # TODO: allow configrable state manager (only local is supported right now) + persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill + + # save new agent config + agent_config.save() + typer.secho(f"Created new agent {agent_config.name}.", fg=typer.colors.GREEN) + + # create agent + memgpt_agent = presets.use_preset( + agent_config.preset, + agent_config, + agent_config.model, + agent_config.persona, + agent_config.human, + memgpt.interface, + persistence_manager, + ) + + # start event loop + from memgpt.main import run_agent_loop + + loop = asyncio.get_event_loop() + loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config)) # TODO: add back no_verify diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py new file mode 100644 index 00000000..905a007a --- /dev/null +++ b/memgpt/cli/cli_config.py @@ -0,0 +1,199 @@ +import questionary +import openai +from prettytable import PrettyTable +import typer +import os +import shutil + +# from memgpt.cli import app +from memgpt import utils + +import memgpt.humans.humans as humans +import memgpt.personas.personas as personas +from memgpt.config import MemGPTConfig, AgentConfig +from memgpt.constants import MEMGPT_DIR + +app = typer.Typer() + + +@app.command() +def configure(): + """Updates default MemGPT configurations""" + + from memgpt.presets import DEFAULT_PRESET, preset_options + + MemGPTConfig.create_config_dir() + + # openai credentials + use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?").ask() + if use_openai: + # search for key in enviornment + openai_key = os.getenv("OPENAI_API_KEY") + if not openai_key: + openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask() + + # azure credentials + use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?").ask() + use_azure_deployment_ids = False + if use_azure: + # search for key in enviornment + azure_key = os.getenv("AZURE_API_KEY") + azure_endpoint = (os.getenv("AZURE_ENDPOINT"),) + azure_version = (os.getenv("AZURE_VERSION"),) + azure_deployment = (os.getenv("AZURE_OPENAI_DEPLOYMENT"),) + azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") + + if all([azure_key, azure_endpoint, azure_version]): + print(f"Using Microsoft endpoint {azure_endpoint}.") + if all([azure_deployment, azure_embedding_deployment]): + print(f"Using deployment id {azure_deployment}") + use_azure_deployment_ids = True + + # configure openai + openai.api_type = "azure" + openai.api_key = azure_key + openai.api_base = azure_endpoint + openai.api_version = azure_version + else: + print("Missing enviornment variables for Azure. Please set then run `memgpt configure` again.") + # TODO: allow for manual setting + use_azure = False + + # TODO: configure local model + + # configure provider + use_local = not use_openai and os.getenv("OPENAI_API_BASE") + endpoint_options = [] + if os.getenv("OPENAI_API_BASE") is not None: + endpoint_options.append(os.getenv("OPENAI_API_BASE")) + if os.getenv("AZURE_ENDPOINT") is not None: + endpoint_options += ["azure"] + if use_openai: + endpoint_options += ["openai"] + + assert len(endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE." + if len(endpoint_options) == 1: + default_endpoint = endpoint_options[0] + else: + default_endpoint = questionary.select("Select default endpoint:", endpoint_options).ask() + + # configure preset + default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask() + + # default model + if use_openai or use_azure: + model_options = [] + if use_openai: + model_options += ["gpt-3.5-turbo", "gpt-3.5", "gpt-4"] + default_model = questionary.select( + "Select default model (recommended: gpt-4):", choices=["gpt-3.5-turbo", "gpt-3.5", "gpt-4"], default="gpt-4" + ).ask() + else: + default_model = "local" # TODO: figure out if this is ok? this is for local endpoint + + # defaults + personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()] + print(personas) + default_persona = questionary.select("Select default persona:", personas, default="sam_pov").ask() + humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()] + print(humans) + default_human = questionary.select("Select default human:", humans, default="cs_phd").ask() + + # TODO: figure out if we should set a default agent or not + default_agent = None + # agents = [os.path.basename(f).replace(".json", "") for f in utils.list_agent_config_files()] + # if len(agents) > 0: # agents have been created + # default_agent = questionary.select( + # "Select default agent:", + # agents + # ).ask() + # else: + # default_agent = None + + # TODO: allow configuring embedding model + + config = MemGPTConfig( + model=default_model, + preset=default_preset, + model_endpoint=default_endpoint, + default_persona=default_persona, + default_human=default_human, + default_agent=default_agent, + openai_key=openai_key if use_openai else None, + azure_key=azure_key if use_azure else None, + azure_endpoint=azure_endpoint if use_azure else None, + azure_version=azure_version if use_azure else None, + azure_deployment=azure_deployment if use_azure_deployment_ids else None, + azure_embedding_deployment=azure_embedding_deployment if use_azure_deployment_ids else None, + ) + print(f"Saving config to {config.config_path}") + config.save() + + +@app.command() +def list(option: str): + if option == "agents": + """List all agents""" + table = PrettyTable() + table.field_names = ["Name", "Model", "Persona", "Human", "Data Source"] + for agent_file in utils.list_agent_config_files(): + agent_name = os.path.basename(agent_file).replace(".json", "") + agent_config = AgentConfig.load(agent_name) + table.add_row([agent_name, agent_config.model, agent_config.persona, agent_config.human, agent_config.data_source]) + print(table) + elif option == "humans": + """List all humans""" + table = PrettyTable() + table.field_names = ["Name", "Text"] + for human_file in utils.list_human_files(): + text = open(human_file, "r").read() + name = os.path.basename(human_file).replace("txt", "") + table.add_row([name, text]) + print(table) + elif option == "personas": + """List all personas""" + table = PrettyTable() + table.field_names = ["Name", "Text"] + for persona_file in utils.list_persona_files(): + print(persona_file) + text = open(persona_file, "r").read() + name = os.path.basename(persona_file).replace(".txt", "") + table.add_row([name, text]) + print(table) + elif option == "sources": + """List all data sources""" + table = PrettyTable() + table.field_names = ["Name", "Create Time", "Agents"] + for data_source_file in os.listdir(os.path.join(MEMGPT_DIR, "archival")): + name = os.path.basename(data_source_file) + table.add_row([name, "TODO", "TODO"]) + print(table) + else: + raise ValueError(f"Unknown option {option}") + + +@app.command() +def add( + option: str, # [human, persona] + name: str = typer.Option(help="Name of human/persona"), + text: str = typer.Option(None, help="Text of human/persona"), + filename: str = typer.Option(None, "-f", help="Specify filename"), +): + """Add a person/human""" + + if option == "persona": + directory = os.path.join(MEMGPT_DIR, "personas") + elif option == "human": + directory = os.path.join(MEMGPT_DIR, "humans") + else: + raise ValueError(f"Unknown kind {kind}") + + if filename: + assert text is None, f"Cannot provide both filename and text" + # copy file to directory + shutil.copyfile(filename, os.path.join(directory, name)) + if text: + assert filename is None, f"Cannot provide both filename and text" + # write text to file + with open(os.path.join(directory, name), "w") as f: + f.write(text) diff --git a/memgpt/connectors/connector.py b/memgpt/cli/cli_load.py similarity index 95% rename from memgpt/connectors/connector.py rename to memgpt/cli/cli_load.py index 4b4c399a..9bbb112f 100644 --- a/memgpt/connectors/connector.py +++ b/memgpt/cli/cli_load.py @@ -8,12 +8,8 @@ memgpt load --name [ADDITIONAL ARGS] """ -from llama_index import download_loader from typing import List -import os import typer -from memgpt.constants import MEMGPT_DIR -from memgpt.utils import estimate_openai_cost, get_index, save_index app = typer.Typer() @@ -26,6 +22,7 @@ def load_directory( recursive: bool = typer.Option(False, help="Recursively search for files in directory."), ): from llama_index import SimpleDirectoryReader + from memgpt.utils import get_index, save_index if recursive: assert input_dir is not None, "Must provide input directory if recursive is True." @@ -53,6 +50,7 @@ def load_webpage( urls: List[str] = typer.Option(None, help="List of urls to load."), ): from llama_index import SimpleWebPageReader + from memgpt.utils import get_index, save_index docs = SimpleWebPageReader(html_to_text=True).load_data(urls) @@ -76,6 +74,7 @@ def load_database( dbname: str = typer.Option(None, help="Database name."), ): from llama_index.readers.database import DatabaseReader + from memgpt.utils import get_index, save_index print(dump_path, scheme) diff --git a/memgpt/config.py b/memgpt/config.py index cbed4897..1945d37c 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -1,7 +1,12 @@ import glob +import random +import string import json import os +import uuid import textwrap +from dataclasses import dataclass +import configparser import questionary @@ -15,6 +20,11 @@ import memgpt.interface as interface from memgpt.personas.personas import get_persona_text from memgpt.humans.humans import get_human_text from memgpt.constants import MEMGPT_DIR +import memgpt.constants as constants +import memgpt.personas.personas as personas +import memgpt.humans.humans as humans +from memgpt.presets import DEFAULT_PRESET, preset_options + model_choices = [ questionary.Choice("gpt-4"), @@ -25,6 +35,245 @@ model_choices = [ ] +@dataclass +class MemGPTConfig: + config_path: str = f"{MEMGPT_DIR}/config" + anon_clientid: str = None + + # preset + preset: str = DEFAULT_PRESET + + # model parameters + # provider: str = "openai" # openai, azure, local (TODO) + model_endpoint: str = "openai" + model: str = "gpt-4" # gpt-4, gpt-3.5-turbo, local + + # model parameters: openai + openai_key: str = None + + # model parameters: azure + azure_key: str = None + azure_endpoint: str = None + azure_version: str = None + azure_deployment: str = None + azure_embedding_deployment: str = None + + # persona parameters + default_persona: str = personas.DEFAULT + default_human: str = humans.DEFAULT + default_agent: str = None + + # embedding parameters + embedding_model: str = "openai" + embedding_dim: int = 768 + embedding_chunk_size: int = 300 # number of tokens + + # database configs: archival + archival_storage_type: str = "local" # local, db + archival_storage_path: str = None # TODO: set to memgpt dir + archival_storage_uri: str = None # TODO: eventually allow external vector DB + + # database configs: recall + recall_storage_type: str = "local" # local, db + recall_storage_path: str = None # TODO: set to memgpt dir + recall_storage_uri: str = None # TODO: eventually allow external vector DB + + # database configs: agent state + persistence_manager_type: str = None # in-memory, db + persistence_manager_save_file: str = None # local file + persistence_manager_uri: str = None # db URI + + @staticmethod + def generate_uuid() -> str: + return uuid.UUID(int=uuid.getnode()).hex + + @classmethod + def load(cls) -> "MemGPTConfig": + config = configparser.ConfigParser() + if os.path.exists(MemGPTConfig.config_path): + config.read(MemGPTConfig.config_path) + + # read config values + model = config.get("defaults", "model") + preset = config.get("defaults", "preset") + model_endpoint = config.get("defaults", "model_endpoint") + default_persona = config.get("defaults", "persona") + default_human = config.get("defaults", "human") + default_agent = config.get("defaults", "agent") if config.has_option("defaults", "agent") else None + + openai_key, openai_model = None, None + if "openai" in config: + openai_key = config.get("openai", "key") + + azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = None, None, None, None, None + if "azure" in config: + azure_key = config.get("azure", "key") + azure_endpoint = config.get("azure", "endpoint") + azure_version = config.get("azure", "version") + azure_deployment = config.get("azure", "deployment") + azure_embedding_deployment = config.get("azure", "embedding_deployment") + + embedding_model = config.get("embedding", "model") + embedding_dim = config.getint("embedding", "dim") + embedding_chunk_size = config.getint("embedding", "chunk_size") + + anon_clientid = config.get("client", "anon_clientid") + + return cls( + model=model, + preset=preset, + model_endpoint=model_endpoint, + default_persona=default_persona, + default_human=default_human, + default_agent=default_agent, + openai_key=openai_key, + azure_key=azure_key, + azure_endpoint=azure_endpoint, + azure_version=azure_version, + azure_deployment=azure_deployment, + azure_embedding_deployment=azure_embedding_deployment, + embedding_model=embedding_model, + embedding_dim=embedding_dim, + embedding_chunk_size=embedding_chunk_size, + anon_clientid=anon_clientid, + ) + + anon_clientid = MemGPTConfig.generate_uuid() + config = cls(anon_clientid=anon_clientid) + config.save() # save updated config + return config + + def save(self): + config = configparser.ConfigParser() + + # CLI defaults + config.add_section("defaults") + config.set("defaults", "model", self.model) + config.set("defaults", "preset", self.preset) + assert self.model_endpoint is not None, "Endpoint must be set" + config.set("defaults", "model_endpoint", self.model_endpoint) + config.set("defaults", "persona", self.default_persona) + config.set("defaults", "human", self.default_human) + if self.default_agent: + config.set("defaults", "agent", self.default_agent) + + # security credentials + if self.openai_key: + config.add_section("openai") + config.set("openai", "key", self.openai_key) + + if self.azure_key: + config.add_section("azure") + config.set("azure", "key", self.azure_key) + config.set("azure", "endpoint", self.azure_endpoint) + config.set("azure", "version", self.azure_version) + config.set("azure", "deployment", self.azure_deployment) + config.set("azure", "embedding_deployment", self.azure_embedding_deployment) + + # embeddings + config.add_section("embedding") + config.set("embedding", "model", self.embedding_model) + config.set("embedding", "dim", str(self.embedding_dim)) + config.set("embedding", "chunk_size", str(self.embedding_chunk_size)) + + # client + config.add_section("client") + if not self.anon_clientid: + self.anon_clientid = self.generate_uuid() + config.set("client", "anon_clientid", self.anon_clientid) + + with open(self.config_path, "w") as f: + config.write(f) + + @staticmethod + def exists(): + return os.path.exists(MemGPTConfig.config_path) + + @staticmethod + def create_config_dir(): + if not os.path.exists(MEMGPT_DIR): + os.makedirs(MEMGPT_DIR, exist_ok=True) + + folders = ["personas", "humans", "archival", "agents"] + for folder in folders: + if not os.path.exists(os.path.join(MEMGPT_DIR, folder)): + os.makedirs(os.path.join(MEMGPT_DIR, folder)) + + +@dataclass +class AgentConfig: + """ + Configuration for a specific instance of an agent + """ + + def __init__(self, persona, human, model, preset=DEFAULT_PRESET, name=None, data_source=None, agent_config_path=None, create_time=None): + if name is None: + self.name = f"agent_{self.generate_agent_id()}" + else: + self.name = name + self.persona = persona + self.human = human + self.model = model + self.preset = preset + self.data_source = data_source + self.create_time = create_time if create_time is not None else utils.get_local_time() + + # save agent config + self.agent_config_path = ( + os.path.join(MEMGPT_DIR, "agents", self.name, "config.json") if agent_config_path is None else agent_config_path + ) + # assert not os.path.exists(self.agent_config_path), f"Agent config file already exists at {self.agent_config_path}" + self.save() + + def generate_agent_id(self, length=6): + ## random character based + # characters = string.ascii_lowercase + string.digits + # return ''.join(random.choices(characters, k=length)) + + # count based + agent_count = len(utils.list_agent_config_files()) + return str(agent_count + 1) + + def attach_data_source(self, data_source: str): + # TODO: add warning that only once source can be attached + # i.e. previous source will be overriden + self.data_source = data_source + self.save() + + def save_state_dir(self): + # directory to save agent state + return os.path.join(MEMGPT_DIR, "agents", self.name, "agent_state") + + def save_persistence_manager_dir(self): + # directory to save persistent manager state + return os.path.join(MEMGPT_DIR, "agents", self.name, "persistence_manager") + + def save_agent_index_dir(self): + # save llama index inside of persistent manager directory + return os.path.join(self.save_persistence_manager_dir(), "index") + + def save(self): + # save state of persistence manager + os.makedirs(os.path.join(MEMGPT_DIR, "agents", self.name), exist_ok=True) + with open(self.agent_config_path, "w") as f: + json.dump(vars(self), f, indent=4) + + @staticmethod + def exists(name: str): + """Check if agent config exists""" + agent_config_path = os.path.join(MEMGPT_DIR, "agents", name) + return os.path.exists(agent_config_path) + + @classmethod + def load(cls, name: str): + """Load agent config from JSON file""" + agent_config_path = os.path.join(MEMGPT_DIR, "agents", name, "config.json") + assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}" + with open(agent_config_path, "r") as f: + agent_config = json.load(f) + return cls(**agent_config) + + class Config: personas_dir = os.path.join("memgpt", "personas", "examples") custom_personas_dir = os.path.join(MEMGPT_DIR, "personas") diff --git a/memgpt/embeddings.py b/memgpt/embeddings.py new file mode 100644 index 00000000..af95c6b0 --- /dev/null +++ b/memgpt/embeddings.py @@ -0,0 +1,32 @@ +from memgpt.config import MemGPTConfig +import typer +from llama_index.embeddings import OpenAIEmbedding + + +def embedding_model(config: MemGPTConfig): + # TODO: use embedding_endpoint in the future + if config.model_endpoint == "openai": + return OpenAIEmbedding() + elif config.model_endpoint == "azure": + return OpenAIEmbedding( + model="text-embedding-ada-002", + deployment_name=config.azure_embedding_deployment, + api_key=config.azure_key, + api_base=config.azure_endpoint, + api_type="azure", + api_version=config.azure_version, + ) + else: + # default to hugging face model + from llama_index.embeddings import HuggingFaceEmbedding + + model = "BAAI/bge-small-en-v1.5" + typer.secho( + f"Warning: defaulting to HuggingFace embedding model {model} since model endpoint is not OpenAI or Azure.", + fg=typer.colors.YELLOW, + ) + typer.secho(f"Warning: ensure torch and transformers are installed") + # return f"local:{model}" + + # loads BAAI/bge-small-en-v1.5 + return HuggingFaceEmbedding(model_name=model) diff --git a/memgpt/main.py b/memgpt/main.py index 5ae09e6f..15197ce1 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -1,4 +1,7 @@ import asyncio +import shutil +import configparser +import uuid import logging import glob import os @@ -9,6 +12,7 @@ import questionary import typer from rich.console import Console +from prettytable import PrettyTable console = Console() @@ -21,14 +25,17 @@ import memgpt.constants as constants import memgpt.personas.personas as personas import memgpt.humans.humans as humans from memgpt.persistence_manager import ( + LocalStateManager, InMemoryStateManager, InMemoryStateManagerWithPreloadedArchivalMemory, InMemoryStateManagerWithFaiss, ) - -from memgpt.config import Config +from memgpt.cli.cli import run +from memgpt.cli.cli_config import configure, list, add +from memgpt.cli.cli_load import app as load_app +from memgpt.config import Config, MemGPTConfig, AgentConfig from memgpt.constants import MEMGPT_DIR -from memgpt.connectors import connector +from memgpt.agent import AgentAsync from memgpt.openai_tools import ( configure_azure_support, check_azure_embeddings, @@ -37,7 +44,12 @@ from memgpt.openai_tools import ( import asyncio app = typer.Typer() -app.add_typer(connector.app, name="load") +app.command(name="run")(run) +app.command(name="configure")(configure) +app.command(name="list")(list) +app.command(name="add")(add) +# load data commands +app.add_typer(load_app, name="load") def clear_line(): @@ -111,7 +123,9 @@ def load(memgpt_agent, filename): @app.callback(invoke_without_command=True) # make default command -def run( +# @app.command("legacy-run") +def legacy_run( + ctx: typer.Context, persona: str = typer.Option(None, help="Specify persona"), human: str = typer.Option(None, help="Specify human"), model: str = typer.Option(constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"), @@ -144,6 +158,13 @@ def run( help="Use Azure OpenAI (requires additional environment variables)", ), # TODO: just pass in? ): + if ctx.invoked_subcommand is not None: + return + + typer.secho("Warning: Running legacy run command. Run `memgpt run` instead.", fg=typer.colors.RED, bold=True) + if not questionary.confirm("Continue with legacy CLI?", default=False).ask(): + return + loop = asyncio.get_event_loop() loop.run_until_complete( main( @@ -208,7 +229,7 @@ async def main( memgpt_persona = persona if memgpt_persona is None: memgpt_persona = ( - personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT, + personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT_PRESET, None, # represents the personas dir in pymemgpt package ) else: @@ -304,7 +325,8 @@ async def main( chosen_persona = cfg.memgpt_persona memgpt_agent = presets.use_preset( - presets.DEFAULT, + presets.DEFAULT_PRESET, + None, # no agent config to provide cfg.model, personas.get_persona_text(*chosen_persona), humans.get_human_text(*chosen_human), @@ -314,12 +336,6 @@ async def main( print_messages = memgpt.interface.print_messages await print_messages(memgpt_agent.messages) - counter = 0 - user_input = None - skip_next_user_input = False - user_message = None - USER_GOES_FIRST = first - if cfg.load_type == "sql": # TODO: move this into config.py in a clean manner if not os.path.exists(cfg.archival_storage_files): print(f"File {cfg.archival_storage_files} does not exist") @@ -338,6 +354,17 @@ async def main( if load_save_file: load(memgpt_agent, cfg.agent_save_file) + # run agent loop + await run_agent_loop(memgpt_agent, first, no_verify, cfg, legacy=True) + + +async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, legacy=False): + counter = 0 + user_input = None + skip_next_user_input = False + user_message = None + USER_GOES_FIRST = first + # auto-exit for if "GITHUB_ACTIONS" in os.environ: return @@ -359,6 +386,10 @@ async def main( ).ask_async() clear_line() + # Gracefully exit on Ctrl-C/D + if user_input is None: + user_input = "/exit" + user_input = user_input.rstrip() if user_input.startswith("!"): @@ -373,30 +404,40 @@ async def main( # Handle CLI commands # Commands to not get passed as input to MemGPT if user_input.startswith("/"): - if user_input.lower() == "/exit": - # autosave - save(memgpt_agent=memgpt_agent, cfg=cfg) - break + if legacy: + # legacy agent save functions (TODO: eventually remove) + if user_input.lower() == "/exit": + # autosave + save(memgpt_agent=memgpt_agent, cfg=cfg) + break - elif user_input.lower() == "/savechat": - filename = utils.get_local_time().replace(" ", "_").replace(":", "_") - filename = f"{filename}.pkl" - directory = os.path.join(MEMGPT_DIR, "saved_chats") - try: - if not os.path.exists(directory): - os.makedirs(directory) - with open(os.path.join(directory, filename), "wb") as f: - pickle.dump(memgpt_agent.messages, f) - print(f"Saved messages to: {filename}") - except Exception as e: - print(f"Saving chat to {filename} failed with: {e}") - continue + elif user_input.lower() == "/savechat": + filename = utils.get_local_time().replace(" ", "_").replace(":", "_") + filename = f"{filename}.pkl" + directory = os.path.join(MEMGPT_DIR, "saved_chats") + try: + if not os.path.exists(directory): + os.makedirs(directory) + with open(os.path.join(directory, filename), "wb") as f: + pickle.dump(memgpt_agent.messages, f) + print(f"Saved messages to: {filename}") + except Exception as e: + print(f"Saving chat to {filename} failed with: {e}") + continue - elif user_input.lower() == "/save": - save(memgpt_agent=memgpt_agent, cfg=cfg) - continue + elif user_input.lower() == "/save": + save(memgpt_agent=memgpt_agent, cfg=cfg) + continue + else: + # updated agent save functions + if user_input.lower() == "/exit": + memgpt_agent.save() + break + elif user_input.lower() == "/save" or user_input.lower() == "/savechat": + memgpt_agent.save() + continue - elif user_input.lower() == "/load" or user_input.lower().startswith("/load "): + if user_input.lower() == "/load" or user_input.lower().startswith("/load "): command = user_input.strip().split() filename = command[1] if len(command) > 1 else None load(memgpt_agent=memgpt_agent, filename=filename) diff --git a/memgpt/memory.py b/memgpt/memory.py index 8de814b4..f401eece 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -9,15 +9,16 @@ from typing import Optional, List, Tuple from .constants import MESSAGE_SUMMARY_WARNING_TOKENS, MEMGPT_DIR from .utils import cosine_similarity, get_local_time, printd, count_tokens from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM +from memgpt import utils from .openai_tools import ( acompletions_with_backoff as acreate, async_get_embedding_with_backoff, get_embedding_with_backoff, completions_with_backoff as create, ) - from llama_index import ( VectorStoreIndex, + EmptyIndex, get_response_synthesizer, load_index_from_storage, StorageContext, @@ -640,36 +641,67 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory): class LocalArchivalMemory(ArchivalMemory): """Archival memory built on top of Llama Index""" - def __init__(self, archival_memory_database: Optional[str] = None, top_k: Optional[int] = 100): + def __init__(self, agent_config, top_k: Optional[int] = 100): """Init function for archival memory :param archiva_memory_database: name of dataset to pre-fill archival with :type archival_memory_database: str """ - if archival_memory_database is not None: - # TODO: load form ~/.memgpt/archival - directory = f"{MEMGPT_DIR}/archival/{archival_memory_database}" - assert os.path.exists(directory), f"Archival memory database {archival_memory_database} does not exist" + self.top_k = top_k + self.agent_config = agent_config + + # locate saved index + if self.agent_config.data_source is not None: # connected data source + directory = f"{MEMGPT_DIR}/archival/{self.agent_config.data_source}" + assert os.path.exists(directory), f"Archival memory database {self.agent_config.data_source} does not exist" + elif self.agent_config.name is not None: + directory = agent_config.save_agent_index_dir() + if not os.path.exists(directory): + # no existing archival storage + directory = None + + # load/create index + if directory: storage_context = StorageContext.from_defaults(persist_dir=directory) self.index = load_index_from_storage(storage_context) else: - self.index = VectorStoreIndex() - self.top_k = top_k - self.retriever = VectorIndexRetriever( - index=self.index, # does this get refreshed? - similarity_top_k=self.top_k, - ) + self.index = EmptyIndex() + + # create retriever + if isinstance(self.index, EmptyIndex): + self.retriever = None # cant create retriever over empty indes + else: + self.retriever = VectorIndexRetriever( + index=self.index, # does this get refreshed? + similarity_top_k=self.top_k, + ) + # TODO: have some mechanism for cleanup otherwise will lead to OOM self.cache = {} - def insert(self, memory_string): + def save(self): + """Save the index to disk""" + if self.agent_config.data_source: # update original archival index + # TODO: this corrupts the originally loaded data. do we want to do this? + utils.save_index(self.index, self.agent_config.data_source) + else: + utils.save_agent_index(self.index, self.agent_config) + + async def insert(self, memory_string): self.index.insert(memory_string) - async def a_insert(self, memory_string): - return self.insert(memory_string) + # TODO: figure out if this needs to be refreshed (probably not) + self.retriever = VectorIndexRetriever( + index=self.index, + similarity_top_k=self.top_k, + ) + + async def search(self, query_string, count=None, start=None): + if self.retriever is None: + print("Warning: archival memory is empty") + return [], 0 - def search(self, query_string, count=None, start=None): start = start if start else 0 count = count if count else self.top_k count = min(count + start, self.top_k) diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py index 2c38d8c1..51d3b6bf 100644 --- a/memgpt/persistence_manager.py +++ b/memgpt/persistence_manager.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod +import os import pickle - +from memgpt.config import AgentConfig from .memory import ( DummyRecallMemory, DummyRecallMemoryWithEmbeddings, @@ -107,22 +108,33 @@ class LocalStateManager(PersistenceManager): recall_memory_cls = DummyRecallMemory archival_memory_cls = LocalArchivalMemory - def __init__(self, archival_memory_db=None): + def __init__(self, agent_config: AgentConfig): # Memory held in-state useful for debugging stateful versions self.memory = None self.messages = [] self.all_messages = [] - self.archival_memory = LocalArchivalMemory(archival_memory_database=archival_memory_db) + self.archival_memory = LocalArchivalMemory(agent_config=agent_config) + self.agent_config = agent_config @staticmethod - def load(filename): + def load(filename, agent_config: AgentConfig): + """ Load a LocalStateManager from a file. """ "" with open(filename, "rb") as f: - return pickle.load(f) + manager = pickle.load(f) + + manager.archival_memory = LocalArchivalMemory(agent_config=agent_config) + return manager def save(self, filename): with open(filename, "wb") as fh: + # TODO: fix this hacky solution to pickle the retriever + self.archival_memory.save() + self.archival_memory = None pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL) + # re-load archival (TODO: dont do this) + self.archival_memory = LocalArchivalMemory(agent_config=self.agent_config) + def init(self, agent): printd(f"Initializing InMemoryStateManager with agent object") self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] diff --git a/memgpt/presets.py b/memgpt/presets.py index 76ff8fae..0fce2928 100644 --- a/memgpt/presets.py +++ b/memgpt/presets.py @@ -1,16 +1,17 @@ from .prompts import gpt_functions from .prompts import gpt_system -from .agent import AgentAsync -from .utils import printd + +DEFAULT_PRESET = "memgpt_chat" +preset_options = [DEFAULT_PRESET] -DEFAULT = "memgpt_chat" - - -def use_preset(preset_name, model, persona, human, interface, persistence_manager): +def use_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager): """Storing combinations of SYSTEM + FUNCTION prompts""" - if preset_name == "memgpt_chat": + from memgpt.agent import AgentAsync + from memgpt.utils import printd + + if preset_name == DEFAULT_PRESET: functions = [ "send_message", "pause_heartbeats", @@ -30,6 +31,7 @@ def use_preset(preset_name, model, persona, human, interface, persistence_manage preset_name = "memgpt_gpt35_extralong" return AgentAsync( + config=agent_config, model=model, system=gpt_system.get_system_text(preset_name), functions=available_functions, diff --git a/memgpt/utils.py b/memgpt/utils.py index d2fbc9d5..e2146bc7 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -14,6 +14,7 @@ import sqlite3 import fitz from tqdm import tqdm import typer +import memgpt from memgpt.openai_tools import async_get_embedding_with_backoff from memgpt.constants import MEMGPT_DIR from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext @@ -364,6 +365,10 @@ def get_index(name, docs): :param docs: Documents to be embedded :type docs: List[Document] """ + from memgpt.config import MemGPTConfig # avoid circular import + + # TODO: configure to work for local + print("Warning: get_index(docs) only supported for OpenAI") # check if directory exists dir = f"{MEMGPT_DIR}/archival/{name}" @@ -387,17 +392,36 @@ def get_index(name, docs): typer.secho("Aborting.", fg="red") exit() - embed_model = OpenAIEmbedding() - service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=300) - set_global_service_context(service_context) + # read embedding confirguration + # TODO: in the future, make an IngestData class that loads the config once + # config = MemGPTConfig.load() + # chunk_size = config.embedding_chunk_size + # model = config.embedding_model # TODO: actually use this + # dim = config.embedding_dim # TODO: actually use this + # embed_model = OpenAIEmbedding() + # service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=chunk_size) + # set_global_service_context(service_context) # index documents index = VectorStoreIndex.from_documents(docs) return index +def save_agent_index(index, agent_config): + """Save agent index inside of ~/.memgpt/agents/ + + :param index: Index to save + :type index: VectorStoreIndex + :param agent_name: Name of agent that the archival memory belonds to + :type agent_name: str + """ + dir = agent_config.save_agent_index_dir() + os.makedirs(dir, exist_ok=True) + index.storage_context.persist(dir) + + def save_index(index, name): - """Save index to a specificed name in ~/.memgpt + """Save index ~/.memgpt/archival/ to load into agents :param index: Index to save :type index: VectorStoreIndex @@ -422,3 +446,34 @@ def save_index(index, name): os.makedirs(dir, exist_ok=True) index.storage_context.persist(dir) print(dir) + + +def list_agent_config_files(): + """List all agents config files""" + return os.listdir(os.path.join(MEMGPT_DIR, "agents")) + + +def list_human_files(): + """List all humans files""" + defaults_dir = os.path.join(memgpt.__path__[0], "humans", "examples") + user_dir = os.path.join(MEMGPT_DIR, "humans") + + memgpt_defaults = os.listdir(defaults_dir) + memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")] + + user_added = os.listdir(user_dir) + user_added = [os.path.join(user_dir, f) for f in user_added] + return memgpt_defaults + user_added + + +def list_persona_files(): + """List all personas files""" + defaults_dir = os.path.join(memgpt.__path__[0], "personas", "examples") + user_dir = os.path.join(MEMGPT_DIR, "personas") + + memgpt_defaults = os.listdir(defaults_dir) + memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")] + + user_added = os.listdir(user_dir) + user_added = [os.path.join(user_dir, f) for f in user_added] + return memgpt_defaults + user_added diff --git a/poetry.lock b/poetry.lock index f50ca85f..fd2c27b9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1337,6 +1337,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "prettytable" +version = "3.9.0" +description = "A simple Python library for easily displaying tabular data in a visually appealing ASCII table format" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "prettytable-3.9.0-py3-none-any.whl", hash = "sha256:a71292ab7769a5de274b146b276ce938786f56c31cf7cea88b6f3775d82fe8c8"}, + {file = "prettytable-3.9.0.tar.gz", hash = "sha256:f4ed94803c23073a90620b201965e5dc0bccf1760b7a7eaf3158cab8aaffdf34"}, +] + +[package.dependencies] +wcwidth = "*" + +[package.extras] +tests = ["pytest", "pytest-cov", "pytest-lazy-fixture"] + [[package]] name = "prompt-toolkit" version = "3.0.36" @@ -2494,4 +2512,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.9" -content-hash = "72cadac0b6c167e5b890c7062e7f163e4976b29a0083b6109c1c3a8f5bb02d25" +content-hash = "13e73dc8fe9e19792903e9659b53cfa24bcc7abbd73fd0d691daeb22136cdaa0" diff --git a/pyproject.toml b/pyproject.toml index 413811a7..932748d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ pytest = "^7.4.3" llama-index = "^0.8.53.post3" setuptools = "^68.2.2" datasets = "^2.14.6" +prettytable = "^3.9.0" [build-system] diff --git a/requirements-local.txt b/requirements-local.txt new file mode 100644 index 00000000..18eb5dfa --- /dev/null +++ b/requirements-local.txt @@ -0,0 +1,3 @@ +transformers==4.34.1 +huggingface-hub==0.17.3 +torch==2.1.0 diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 6803e7d0..ba89989d 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -1,21 +1,18 @@ import tempfile import asyncio import os -from memgpt.connectors.connector import load_directory, load_database, load_webpage -import memgpt.agent as agent -import memgpt.system as system -import memgpt.utils as utils +import asyncio +from datasets import load_dataset + +import memgpt +from memgpt.cli.cli_load import load_directory, load_database, load_webpage import memgpt.presets as presets -import memgpt.constants as constants import memgpt.personas.personas as personas import memgpt.humans.humans as humans from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager -from memgpt.config import Config +from memgpt.config import AgentConfig from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL -from memgpt.connectors import connector import memgpt.interface # for printing to terminal -import asyncio -from datasets import load_dataset def test_load_directory(): @@ -37,12 +34,21 @@ def test_load_directory(): recursive=True, ) + # create agents with defaults + agent_config = AgentConfig( + persona=personas.DEFAULT, + human=humans.DEFAULT, + model=DEFAULT_MEMGPT_MODEL, + data_source="tmp_hf_dataset", + ) + # create state manager based off loaded data - persistence_manager = LocalStateManager(archival_memory_db="tmp_hf_dataset") + persistence_manager = LocalStateManager(agent_config=agent_config) # create agent memgpt_agent = presets.use_preset( - presets.DEFAULT, + presets.DEFAULT_PRESET, + agent_config, DEFAULT_MEMGPT_MODEL, personas.get_persona_text(personas.DEFAULT), humans.get_human_text(humans.DEFAULT), @@ -92,11 +98,21 @@ def test_load_database(): query=f"SELECT * FROM {list(table_names)[0]}", ) - persistence_manager = LocalStateManager(archival_memory_db="tmp_db_dataset") + # create agents with defaults + agent_config = AgentConfig( + persona=personas.DEFAULT_PRESET, + human=humans.DEFAULT, + model=DEFAULT_MEMGPT_MODEL, + data_source="tmp_hf_dataset", + ) + + # create state manager based off loaded data + persistence_manager = LocalStateManager(agent_config=agent_config) # create agent memgpt_agent = presets.use_preset( presets.DEFAULT, + agent_config, DEFAULT_MEMGPT_MODEL, personas.get_persona_text(personas.DEFAULT), humans.get_human_text(humans.DEFAULT),