diff --git a/memgpt/data_types.py b/memgpt/data_types.py index c45192b1..00ab845a 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -123,9 +123,13 @@ class Message(Record): openai_message_dict: dict, model: Optional[str] = None, # model used to make function call allow_functions_style: bool = False, # allow deprecated functions style? + created_at: Optional[datetime] = None, ): """Convert a ChatCompletion message object into a Message object (synced to DB)""" + assert "role" in openai_message_dict, openai_message_dict + assert "content" in openai_message_dict, openai_message_dict + # If we're going from deprecated function form if openai_message_dict["role"] == "function": if not allow_functions_style: @@ -135,6 +139,7 @@ class Message(Record): # Convert from 'function' response to a 'tool' response # NOTE: this does not conventionally include a tool_call_id, it's on the caster to provide it return Message( + created_at=created_at, user_id=user_id, agent_id=agent_id, model=model, @@ -166,6 +171,7 @@ class Message(Record): ] return Message( + created_at=created_at, user_id=user_id, agent_id=agent_id, model=model, @@ -197,6 +203,7 @@ class Message(Record): # If we're going from tool-call style return Message( + created_at=created_at, user_id=user_id, agent_id=agent_id, model=model, diff --git a/memgpt/migrate.py b/memgpt/migrate.py index da70830e..066c6fd9 100644 --- a/memgpt/migrate.py +++ b/memgpt/migrate.py @@ -1,5 +1,5 @@ import configparser -import datetime +from datetime import datetime import os import pickle import glob @@ -8,6 +8,8 @@ import traceback import uuid import json import shutil +from typing import Optional +import pytz import typer from tqdm import tqdm @@ -21,10 +23,17 @@ from llama_index import ( from memgpt.agent import Agent from memgpt.data_types import AgentState, User, Passage, Source, Message from memgpt.metadata import MetadataStore -from memgpt.utils import MEMGPT_DIR, version_less_than, OpenAIBackcompatUnpickler, 
annotate_message_json_list_with_tool_calls +from memgpt.utils import ( +    MEMGPT_DIR, +    version_less_than, +    OpenAIBackcompatUnpickler, +    annotate_message_json_list_with_tool_calls, +    parse_formatted_time, +) from memgpt.config import MemGPTConfig from memgpt.cli.cli_config import configure from memgpt.agent_store.storage import StorageConnector, TableType +from memgpt.persistence_manager import PersistenceManager, LocalStateManager # This is the version where the breaking change was made VERSION_CUTOFF = "0.2.12" @@ -33,19 +42,19 @@ VERSION_CUTOFF = "0.2.12" MIGRATION_BACKUP_FOLDER = "migration_backups" -def wipe_config_and_reconfigure(run_configure=True): +def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True): """Wipe (backup) the config file, and launch `memgpt configure`""" - if not os.path.exists(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER)): - os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER)) - os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents")) + if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)): + os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)) + os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents")) # Get the current timestamp in a readable format (e.g., YYYYMMDD_HHMMSS) - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Construct the new backup directory name with the timestamp - backup_filename = os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, f"config_backup_{timestamp}") - existing_filename = os.path.join(MEMGPT_DIR, "config") + backup_filename = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, f"config_backup_{timestamp}") + existing_filename = os.path.join(data_dir, "config") # Check if the existing file exists before moving if os.path.exists(existing_filename): @@ -63,11 +72,11 @@ def wipe_config_and_reconfigure(run_configure=True): MemGPTConfig.load() -def config_is_compatible(allow_empty=False, echo=False) -> bool: +def 
config_is_compatible(data_dir: str = MEMGPT_DIR, allow_empty=False, echo=False) -> bool: """Check if the config is OK to use with 0.2.12, or if it needs to be deleted""" # NOTE: don't use built-in load(), since that will apply defaults # memgpt_config = MemGPTConfig.load() - memgpt_config_file = os.path.join(MEMGPT_DIR, "config") + memgpt_config_file = os.path.join(data_dir, "config") if not os.path.exists(memgpt_config_file): return True if allow_empty else False parser = configparser.ConfigParser() @@ -92,9 +101,9 @@ def config_is_compatible(allow_empty=False, echo=False) -> bool: return True -def agent_is_migrateable(agent_name: str) -> bool: +def agent_is_migrateable(agent_name: str, data_dir: str = MEMGPT_DIR) -> bool: """Determine whether or not the agent folder is a migration target""" - agent_folder = os.path.join(MEMGPT_DIR, "agents", agent_name) + agent_folder = os.path.join(data_dir, "agents", agent_name) if not os.path.exists(agent_folder): raise ValueError(f"Folder {agent_folder} does not exist") @@ -115,14 +124,14 @@ def agent_is_migrateable(agent_name: str) -> bool: return False -def migrate_source(source_name: str): +def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None): """ Migrate an old source folder (`~/.memgpt/sources/{source_name}`). """ # 1. Load the VectorIndex from ~/.memgpt/sources/{source_name}/index # TODO - source_path = os.path.join(MEMGPT_DIR, "archival", source_name, "nodes.pkl") + source_path = os.path.join(data_dir, "archival", source_name, "nodes.pkl") assert os.path.exists(source_path), f"Source {source_name} does not exist at {source_path}" # load state from old checkpoint file @@ -130,15 +139,21 @@ def migrate_source(source_name: str): # 2. 
Create a new AgentState using the agent config + agent internal state config = MemGPTConfig.load() + if ms is None: + ms = MetadataStore(config) # gets default user - ms = MetadataStore(config) user_id = uuid.UUID(config.anon_clientid) user = ms.get_user(user_id=user_id) if user is None: - raise ValueError( - f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents." - ) + ms.create_user(User(id=user_id)) + user = ms.get_user(user_id=user_id) + if user is None: + typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED) + sys.exit(1) + # raise ValueError( + # f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents." + # ) # insert source into metadata store source = Source(user_id=user.id, name=source_name) @@ -179,7 +194,7 @@ def migrate_source(source_name: str): assert source is not None, f"Failed to load source {source_name} from database after migration" -def migrate_agent(agent_name: str): +def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None): """Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`) Steps: @@ -191,7 +206,7 @@ def migrate_agent(agent_name: str): # 1. Load the agent state JSON from the old folder # TODO - agent_folder = os.path.join(MEMGPT_DIR, "agents", agent_name) + agent_folder = os.path.join(data_dir, "agents", agent_name) # migration_file = os.path.join(agent_folder, MIGRATION_FILE_NAME) # load state from old checkpoint file @@ -255,22 +270,45 @@ def migrate_agent(agent_name: str): # 2. Create a new AgentState using the agent config + agent internal state config = MemGPTConfig.load() + if ms is None: + ms = MetadataStore(config) # gets default user - ms = MetadataStore(config) user_id = uuid.UUID(config.anon_clientid) user = ms.get_user(user_id=user_id) if user is None: - raise ValueError( - f"Failed to load user {str(user_id)} from database. 
Please make sure to migrate your config before migrating agents." - ) + ms.create_user(User(id=user_id)) + user = ms.get_user(user_id=user_id) + if user is None: + typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED) + sys.exit(1) + # raise ValueError( + # f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents." + # ) # ms.create_user(User(id=user_id)) # user = ms.get_user(user_id=user_id) # if user is None: # typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED) # sys.exit(1) + # create an agent_id ahead of time + agent_id = uuid.uuid4() + + # create all the Messages in the database + # message_objs = [] + # for message_dict in annotate_message_json_list_with_tool_calls(state_dict["messages"]): + # message_obj = Message.dict_to_message( + # user_id=user.id, + # agent_id=agent_id, + # openai_message_dict=message_dict, + # model=state_dict["model"] if "model" in state_dict else None, + # # allow_functions_style=False, + # allow_functions_style=True, + # ) + # message_objs.append(message_obj) + agent_state = AgentState( + id=agent_id, name=agent_config["name"], user_id=user.id, persona=agent_config["persona"], # eg 'sam_pov' @@ -281,18 +319,97 @@ def migrate_agent(agent_name: str): persona=state_dict["memory"]["persona"], system=state_dict["system"], functions=state_dict["functions"], # this shouldn't matter, since Agent.__init__ will re-link - messages=annotate_message_json_list_with_tool_calls(state_dict["messages"]), + # messages=[str(m.id) for m in message_objs], # this is a list of uuids, not message dicts ), - llm_config=user.default_llm_config, - embedding_config=user.default_embedding_config, + llm_config=config.default_llm_config, + embedding_config=config.default_embedding_config, ) + persistence_manager = LocalStateManager(agent_state=agent_state) + + # First clean up the recall message history to add tool call ids + full_message_history_buffer 
= annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]]) + for i in range(len(data["all_messages"])): + data["all_messages"][i]["message"] = full_message_history_buffer[i] + + # Figure out what messages in recall are in-context, and which are out-of-context + agent_message_cache = state_dict["messages"] + recall_message_full = data["all_messages"] + + def messages_are_equal(msg1, msg2): + return msg1["role"] == msg2["role"] and msg1["content"] == msg2["content"] + + in_context_messages = [] + out_of_context_messages = [] + assert len(agent_message_cache) <= len(recall_message_full), (len(agent_message_cache), len(recall_message_full)) + for d in recall_message_full: + # unpack into "timestamp" and "message" + recall_message = d["message"] + recall_timestamp = d["timestamp"] + try: + recall_datetime = parse_formatted_time(recall_timestamp).astimezone(pytz.utc) + except ValueError: + recall_datetime = datetime.strptime(recall_timestamp, "%Y-%m-%d %I:%M:%S %p").astimezone(pytz.utc) + + # message object + message_obj = Message.dict_to_message( + created_at=recall_datetime, + user_id=user.id, + agent_id=agent_id, + openai_message_dict=recall_message, + allow_functions_style=True, + ) + + # message is either in-context, or out-of-context + message_is_in_context = [messages_are_equal(recall_message, cache_message) for cache_message in agent_message_cache] + assert sum(message_is_in_context) <= 1, message_is_in_context + + if any(message_is_in_context): + in_context_messages.append(message_obj) + else: + out_of_context_messages.append(message_obj) + + assert len(in_context_messages) > 0 + assert len(in_context_messages) == len(agent_message_cache), (len(in_context_messages), len(agent_message_cache)) + # assert ( + # len(in_context_messages) + len(out_of_context_messages) == state_dict["messages_total"] + # ), f"{len(in_context_messages)} + {len(out_of_context_messages)} != {state_dict['messages_total']}" + + # Now we can insert the 
messages into the actual recall database + # So when we construct the agent from the state, they will be available + persistence_manager.recall_memory.insert_many(out_of_context_messages) + persistence_manager.recall_memory.insert_many(in_context_messages) + + # Overwrite the agent_state message object + agent_state.state["messages"] = [str(m.id) for m in in_context_messages] # this is a list of uuids, not message dicts + + ## 4. Insert into recall + # TODO should this be 'messages', or 'all_messages'? + # all_messages in recall will have fields "timestamp" and "message" + # full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]]) + # We want to keep the timestamp + # for i in range(len(data["all_messages"])): + # data["all_messages"][i]["message"] = full_message_history_buffer[i] + # messages_to_insert = [ + # Message.dict_to_message( + # user_id=user.id, + # agent_id=agent_id, + # openai_message_dict=msg, + # allow_functions_style=True, + # ) + # # for msg in data["all_messages"] + # for msg in full_message_history_buffer + # ] + # agent.persistence_manager.recall_memory.insert_many(messages_to_insert) + # print("Finished migrating recall memory") + # 3. Instantiate a new Agent by passing AgentState to Agent.__init__ # NOTE: the Agent.__init__ will trigger a save, which will write to the DB try: agent = Agent( agent_state=agent_state, - messages_total=state_dict["messages_total"], # TODO: do we need this? + # messages_total=state_dict["messages_total"], # TODO: do we need this? + messages_total=len(in_context_messages) + len(out_of_context_messages), interface=None, ) except Exception as e: @@ -308,17 +425,6 @@ def migrate_agent(agent_name: str): # Wrap the rest in a try-except so that we can cleanup by deleting the agent if we fail try: - ## 4. Insert into recall - # TODO should this be 'messages', or 'all_messages'? 
- # all_messages in recall will have fields "timestamp" and "message" - full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]]) - # We want to keep the timestamp - for i in range(len(data["all_messages"])): - data["all_messages"][i]["message"] = full_message_history_buffer[i] - messages_to_insert = [Message.dict_to_message(msg, allow_functions_style=True) for msg in data["all_messages"]] - agent.persistence_manager.recall_memory.insert_many(messages_to_insert) - # print("Finished migrating recall memory") - # TODO should we also assign data["messages"] to RecallMemory.messages? # 5. Insert into archival @@ -350,7 +456,7 @@ def migrate_agent(agent_name: str): raise try: - new_agent_folder = os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents", agent_name) + new_agent_folder = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents", agent_name) shutil.move(agent_folder, new_agent_folder) except Exception as e: print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}") @@ -358,20 +464,20 @@ def migrate_agent(agent_name: str): # def migrate_all_agents(stop_on_fail=True): -def migrate_all_agents(stop_on_fail: bool = False) -> dict: - """Scan over all agent folders in MEMGPT_DIR and migrate each agent.""" +def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -> dict: + """Scan over all agent folders in data_dir and migrate each agent.""" - if not os.path.exists(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER)): - os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER)) - os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents")) + if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)): + os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)) + os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents")) - if not config_is_compatible(echo=True): + if not config_is_compatible(data_dir, echo=True): 
typer.secho(f"Your current config file is incompatible with MemGPT versions >= {VERSION_CUTOFF}", fg=typer.colors.RED) if questionary.confirm( "To migrate old MemGPT agents, you must delete your config file and run `memgpt configure`. Would you like to proceed?" ).ask(): try: - wipe_config_and_reconfigure() + wipe_config_and_reconfigure(data_dir) except Exception as e: typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED) raise @@ -379,7 +485,7 @@ def migrate_all_agents(stop_on_fail: bool = False) -> dict: typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED) raise KeyboardInterrupt() - agents_dir = os.path.join(MEMGPT_DIR, "agents") + agents_dir = os.path.join(data_dir, "agents") # Ensure the directory exists if not os.path.exists(agents_dir): @@ -392,13 +498,16 @@ def migrate_all_agents(stop_on_fail: bool = False) -> dict: count = 0 failures = [] candidates = [] + config = MemGPTConfig.load() + print(config) + ms = MetadataStore(config) try: for agent_name in tqdm(agent_folders, desc="Migrating agents"): # Assuming migrate_agent is a function that takes the agent name and performs migration try: - if agent_is_migrateable(agent_name=agent_name): + if agent_is_migrateable(agent_name=agent_name, data_dir=data_dir): candidates.append(agent_name) - migrate_agent(agent_name) + migrate_agent(agent_name, data_dir=data_dir, ms=ms) count += 1 else: continue @@ -423,6 +532,7 @@ def migrate_all_agents(stop_on_fail: bool = False) -> dict: if count > 0: typer.secho(f"✅ {count}/{len(candidates)} agents were successfully migrated to the new database format", fg=typer.colors.GREEN) + del ms return { "agent_folders": len(agent_folders), "migration_candidates": len(candidates), @@ -431,10 +541,10 @@ def migrate_all_agents(stop_on_fail: bool = False) -> dict: } -def migrate_all_sources(stop_on_fail: bool = False) -> dict: - """Scan over all agent folders in MEMGPT_DIR and migrate each agent.""" +def 
migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -> dict: + """Scan over all agent folders in data_dir and migrate each agent.""" - sources_dir = os.path.join(MEMGPT_DIR, "archival") + sources_dir = os.path.join(data_dir, "archival") # Ensure the directory exists if not os.path.exists(sources_dir): @@ -447,12 +557,14 @@ def migrate_all_sources(stop_on_fail: bool = False) -> dict: count = 0 failures = [] candidates = [] + config = MemGPTConfig.load() + ms = MetadataStore(config) try: for source_name in tqdm(source_folders, desc="Migrating data sources"): # Assuming migrate_agent is a function that takes the agent name and performs migration try: candidates.append(source_name) - migrate_source(source_name) + migrate_source(source_name, data_dir, ms=ms) count += 1 except Exception as e: failures.append({"name": source_name, "reason": str(e)}) @@ -475,6 +587,7 @@ def migrate_all_sources(stop_on_fail: bool = False) -> dict: if count > 0: typer.secho(f"✅ {count}/{len(candidates)} sources were successfully migrated to the new database format", fg=typer.colors.GREEN) + del ms return { "source_folders": len(source_folders), "migration_candidates": len(candidates), diff --git a/tests/data/memgpt-0.2.11/config b/tests/data/memgpt-0.2.11/config index 5262e0ec..9f48e447 100644 --- a/tests/data/memgpt-0.2.11/config +++ b/tests/data/memgpt-0.2.11/config @@ -9,20 +9,27 @@ model_endpoint = https://api.openai.com/v1 model_endpoint_type = openai context_window = 8192 -[openai] -key = FAKE_KEY - [embedding] embedding_endpoint_type = openai embedding_endpoint = https://api.openai.com/v1 +embedding_model = text-embedding-ada-002 embedding_dim = 1536 embedding_chunk_size = 300 [archival_storage] -type = local +type = chroma +path = /Users/sarahwooders/.memgpt/chroma + +[recall_storage] +type = sqlite +path = /Users/sarahwooders/.memgpt + +[metadata_storage] +type = sqlite +path = /Users/sarahwooders/.memgpt [version] -memgpt_version = 0.2.11 +memgpt_version = 
0.2.12 [client] anon_clientid = 00000000000000000000d67f40108c5c diff --git a/tests/test_migrate.py b/tests/test_migrate.py new file mode 100644 index 00000000..d7be1727 --- /dev/null +++ b/tests/test_migrate.py @@ -0,0 +1,14 @@ +import os +from memgpt.migrate import migrate_all_agents, migrate_all_sources + + +def test_migrate_0211(): + data_dir = "tests/data/memgpt-0.2.11" + # os.environ["MEMGPT_CONFIG_PATH"] = os.path.join(data_dir, "config") + # print(f"MEMGPT_CONFIG_PATH={os.environ['MEMGPT_CONFIG_PATH']}") + res = migrate_all_agents(data_dir) + assert res["failed_migrations"] == 0, f"Failed migrations: {res}" + res = migrate_all_sources(data_dir) + assert res["failed_migrations"] == 0, f"Failed migrations: {res}" + + # TODO: assert everything is in the DB