feat: Migration command for importing old agents into new DB backend (#802)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Charles Packer
2024-01-11 14:57:21 -08:00
committed by GitHub
parent fd20285840
commit f118e01ad1
30 changed files with 819 additions and 98 deletions

View File

@@ -3,11 +3,13 @@ repos:
rev: v2.3.0
hooks:
- id: check-yaml
exclude: ^docs/
exclude: 'docs/.*|tests/data/.*'
- id: end-of-file-fixer
exclude: ^docs/
exclude: 'docs/.*|tests/data/.*'
- id: trailing-whitespace
exclude: ^docs/
exclude: 'docs/.*|tests/data/.*'
- id: end-of-file-fixer
exclude: 'docs/.*|tests/data/.*'
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:

View File

@@ -254,12 +254,11 @@ class SQLStorageConnector(StorageConnector):
return all_filters
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
session = self.Session()
offset = 0
filters = self.get_filters(filters)
while True:
# Retrieve a chunk of records with the given page_size
db_record_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
db_record_chunk = self.session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
# If the chunk is empty, we've retrieved all records
if not db_record_chunk:
@@ -272,40 +271,35 @@ class SQLStorageConnector(StorageConnector):
offset += page_size
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]:
session = self.Session()
filters = self.get_filters(filters)
if limit:
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
db_records = self.session.query(self.db_model).filter(*filters).limit(limit).all()
else:
db_records = session.query(self.db_model).filter(*filters).all()
db_records = self.session.query(self.db_model).filter(*filters).all()
return [record.to_record() for record in db_records]
def get(self, id: str) -> Optional[Record]:
session = self.Session()
db_record = session.query(self.db_model).get(id)
db_record = self.session.query(self.db_model).get(id)
if db_record is None:
return None
return db_record.to_record()
def size(self, filters: Optional[Dict] = {}) -> int:
# return size of table
session = self.Session()
filters = self.get_filters(filters)
return session.query(self.db_model).filter(*filters).count()
return self.session.query(self.db_model).filter(*filters).count()
def insert(self, record: Record):
session = self.Session()
db_record = self.db_model(**vars(record))
session.add(db_record)
session.commit()
self.session.add(db_record)
self.session.commit()
def insert_many(self, records: List[Record], show_progress=False):
session = self.Session()
iterable = tqdm(records) if show_progress else records
for record in iterable:
db_record = self.db_model(**vars(record))
session.add(db_record)
session.commit()
self.session.add(db_record)
self.session.commit()
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
@@ -315,15 +309,13 @@ class SQLStorageConnector(StorageConnector):
def list_data_sources(self):
assert self.table_type == TableType.ARCHIVAL_MEMORY, f"list_data_sources only implemented for ARCHIVAL_MEMORY"
session = self.Session()
unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
unique_data_sources = self.session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
return unique_data_sources
def query_date(self, start_date, end_date, offset=0, limit=None):
session = self.Session()
filters = self.get_filters({})
query = (
session.query(self.db_model)
self.session.query(self.db_model)
.filter(*filters)
.filter(self.db_model.created_at >= start_date)
.filter(self.db_model.created_at <= end_date)
@@ -336,10 +328,12 @@ class SQLStorageConnector(StorageConnector):
def query_text(self, query, offset=0, limit=None):
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
session = self.Session()
filters = self.get_filters({})
query = (
session.query(self.db_model).filter(*filters).filter(func.lower(self.db_model.text).contains(func.lower(query))).offset(offset)
self.session.query(self.db_model)
.filter(*filters)
.filter(func.lower(self.db_model.text).contains(func.lower(query)))
.offset(offset)
)
if limit:
query = query.limit(limit)
@@ -348,16 +342,14 @@ class SQLStorageConnector(StorageConnector):
return [result.to_record() for result in results]
def delete_table(self):
session = self.Session()
close_all_sessions()
self.db_model.__table__.drop(session.bind)
session.commit()
self.db_model.__table__.drop(self.session.bind)
self.session.commit()
def delete(self, filters: Optional[Dict] = {}):
session = self.Session()
filters = self.get_filters(filters)
session.query(self.db_model).filter(*filters).delete()
session.commit()
self.session.query(self.db_model).filter(*filters).delete()
self.session.commit()
class PostgresStorageConnector(SQLStorageConnector):
@@ -393,13 +385,13 @@ class PostgresStorageConnector(SQLStorageConnector):
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
self.Session = sessionmaker(bind=self.engine)
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
session_maker = sessionmaker(bind=self.engine)
self.session = session_maker()
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
session = self.Session()
filters = self.get_filters(filters)
results = session.scalars(
results = self.session.scalars(
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
).all()
@@ -429,7 +421,8 @@ class SQLLiteStorageConnector(SQLStorageConnector):
self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id, dialect="sqlite")
self.engine = create_engine(f"sqlite:///{self.path}")
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
self.Session = sessionmaker(bind=self.engine)
session_maker = sessionmaker(bind=self.engine)
self.session = session_maker()
import sqlite3

View File

@@ -28,6 +28,13 @@ from memgpt.embeddings import embedding_model
from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT
from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig, User
from memgpt.metadata import MetadataStore
from memgpt.migrate import migrate_all_agents, migrate_all_sources
def migrate():
    """Migrate old agents (pre 0.2.12) to the new database system.

    Runs the full migration in two passes: first every agent folder, then
    every data source. Both helpers come from ``memgpt.migrate``.
    """
    migrate_all_agents()
    migrate_all_sources()
class QuickstartChoice(Enum):

View File

@@ -1,4 +1,5 @@
import builtins
from tqdm import tqdm
import uuid
import questionary
from prettytable import PrettyTable
@@ -61,6 +62,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
provider = questionary.select(
"Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
).ask()
if provider is None:
raise KeyboardInterrupt
# set: model_endpoint_type, model_endpoint
if provider == "openai":
@@ -75,6 +78,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
openai_api_key = questionary.text(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
).ask()
if openai_api_key is None:
raise KeyboardInterrupt
config.openai_key = openai_api_key
config.save()
else:
@@ -85,6 +90,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
"Enter your OpenAI API key (hit enter to use existing key):",
default=default_input,
).ask()
if openai_api_key is None:
raise KeyboardInterrupt
# If the user modified it, use the new one
if openai_api_key != default_input:
config.openai_key = openai_api_key
@@ -93,6 +100,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
model_endpoint_type = "openai"
model_endpoint = "https://api.openai.com/v1"
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
provider = "openai"
elif provider == "azure":
@@ -122,6 +131,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
backend_options,
default=default_model_endpoint_type,
).ask()
if model_endpoint_type is None:
raise KeyboardInterrupt
# set default endpoint
# if OPENAI_API_BASE is set, assume that this is the IP+port the user wanted to use
@@ -131,21 +142,33 @@ def configure_llm_endpoint(config: MemGPTConfig):
if model_endpoint_type in DEFAULT_ENDPOINTS:
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
while not utils.is_valid_url(model_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
elif config.model_endpoint:
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
while not utils.is_valid_url(model_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
else:
# default_model_endpoint = None
model_endpoint = None
model_endpoint = questionary.text("Enter default endpoint:").ask()
if model_endpoint is None:
raise KeyboardInterrupt
while not utils.is_valid_url(model_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = questionary.text("Enter default endpoint:").ask()
if model_endpoint is None:
raise KeyboardInterrupt
else:
model_endpoint = default_model_endpoint
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
@@ -185,6 +208,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
default=config.model if valid_model else hardcoded_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt
# If the user asked for the full list, show it
if model == see_all_option_str:
@@ -194,6 +219,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
choices=fetched_model_options + [other_option_str],
default=config.model if valid_model else fetched_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt
# Finally if the user asked to manually input, allow it
if model == other_option_str:
@@ -202,6 +229,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
model = questionary.text(
"Enter custom model name:",
).ask()
if model is None:
raise KeyboardInterrupt
else: # local models
# ollama also needs model type
@@ -211,6 +240,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
"Enter default model name (required for Ollama, see: https://memgpt.readme.io/docs/ollama):",
default=default_model,
).ask()
if model is None:
raise KeyboardInterrupt
model = None if len(model) == 0 else model
default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
@@ -234,6 +265,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
model = questionary.select(
"Select default model:", choices=model_options, default=config.model if valid_model else model_options[0]
).ask()
if model is None:
raise KeyboardInterrupt
# If we got custom input, ask for raw input
if model == other_option_str:
@@ -241,6 +274,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
default=default_model,
).ask()
if model is None:
raise KeyboardInterrupt
# TODO allow empty string for input?
model = None if len(model) == 0 else model
@@ -249,6 +284,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
default=default_model,
).ask()
if model is None:
raise KeyboardInterrupt
model = None if len(model) == 0 else model
# model wrapper
@@ -258,6 +295,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
choices=available_model_wrappers,
default=DEFAULT_WRAPPER_NAME,
).ask()
if model_wrapper is None:
raise KeyboardInterrupt
# set: context_window
if str(model) not in LLM_MAX_TOKENS:
@@ -275,11 +314,15 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
choices=context_length_options,
default=str(LLM_MAX_TOKENS["DEFAULT"]),
).ask()
if context_window is None:
raise KeyboardInterrupt
# If custom, ask for input
if context_window == "custom":
while True:
context_window = questionary.text("Enter context window (e.g. 8192)").ask()
if context_window is None:
raise KeyboardInterrupt
try:
context_window = int(context_window)
break
@@ -302,6 +345,8 @@ def configure_embedding_endpoint(config: MemGPTConfig):
embedding_provider = questionary.select(
"Select embedding provider:", choices=["openai", "azure", "hugging-face", "local"], default=default_embedding_endpoint_type
).ask()
if embedding_provider is None:
raise KeyboardInterrupt
if embedding_provider == "openai":
# check for key
@@ -315,6 +360,8 @@ def configure_embedding_endpoint(config: MemGPTConfig):
openai_api_key = questionary.text(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
).ask()
if openai_api_key is None:
raise KeyboardInterrupt
config.openai_key = openai_api_key
config.save()
@@ -343,6 +390,8 @@ def configure_embedding_endpoint(config: MemGPTConfig):
# get endpoint
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
if embedding_endpoint is None:
raise KeyboardInterrupt
while not utils.is_valid_url(embedding_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
@@ -353,10 +402,14 @@ def configure_embedding_endpoint(config: MemGPTConfig):
"Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):",
default=default_embedding_model,
).ask()
if embedding_model is None:
raise KeyboardInterrupt
# get model dimentions
default_embedding_dim = config.embedding_dim if config.embedding_dim else "1024"
embedding_dim = questionary.text("Enter embedding model dimentions (e.g. 1024):", default=str(default_embedding_dim)).ask()
if embedding_dim is None:
raise KeyboardInterrupt
try:
embedding_dim = int(embedding_dim)
except Exception as e:
@@ -376,16 +429,22 @@ def configure_cli(config: MemGPTConfig):
# preset
default_preset = config.preset if config.preset and config.preset in preset_options else None
preset = questionary.select("Select default preset:", preset_options, default=default_preset).ask()
if preset is None:
raise KeyboardInterrupt
# persona
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
default_persona = config.persona if config.persona and config.persona in personas else None
persona = questionary.select("Select default persona:", personas, default=default_persona).ask()
if persona is None:
raise KeyboardInterrupt
# human
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
default_human = config.human if config.human and config.human in humans else None
human = questionary.select("Select default human:", humans, default=default_human).ask()
if human is None:
raise KeyboardInterrupt
# TODO: figure out if we should set a default agent or not
agent = None
@@ -399,6 +458,8 @@ def configure_archival_storage(config: MemGPTConfig):
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
if archival_storage_type is None:
raise KeyboardInterrupt
archival_storage_uri, archival_storage_path = config.archival_storage_uri, config.archival_storage_path
# configure postgres
@@ -407,6 +468,8 @@ def configure_archival_storage(config: MemGPTConfig):
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()
if archival_storage_uri is None:
raise KeyboardInterrupt
# TODO: add back
## configure lancedb
@@ -419,8 +482,12 @@ def configure_archival_storage(config: MemGPTConfig):
# configure chroma
if archival_storage_type == "chroma":
chroma_type = questionary.select("Select chroma backend:", ["http", "persistent"], default="persistent").ask()
if chroma_type is None:
raise KeyboardInterrupt
if chroma_type == "http":
archival_storage_uri = questionary.text("Enter chroma ip (e.g. localhost:8000):", default="localhost:8000").ask()
if archival_storage_uri is None:
raise KeyboardInterrupt
if chroma_type == "persistent":
archival_storage_path = os.path.join(MEMGPT_DIR, "chroma")
@@ -435,6 +502,8 @@ def configure_recall_storage(config: MemGPTConfig):
recall_storage_type = questionary.select(
"Select storage backend for recall data:", recall_storage_options, default=config.recall_storage_type
).ask()
if recall_storage_type is None:
raise KeyboardInterrupt
recall_storage_uri, recall_storage_path = config.recall_storage_uri, config.recall_storage_path
# configure postgres
if recall_storage_type == "postgres":
@@ -442,6 +511,8 @@ def configure_recall_storage(config: MemGPTConfig):
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.recall_storage_uri if config.recall_storage_uri else "",
).ask()
if recall_storage_uri is None:
raise KeyboardInterrupt
return recall_storage_type, recall_storage_uri, recall_storage_path
@@ -564,7 +635,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
"""List all agents"""
table = PrettyTable()
table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"]
for agent in ms.list_agents(user_id=user_id):
for agent in tqdm(ms.list_agents(user_id=user_id)):
source_ids = ms.list_attached_sources(agent_id=agent.id)
source_names = [ms.get_source(source_id=source_id).name for source_id in source_ids]
table.add_row(

View File

@@ -32,6 +32,32 @@ from llama_index import (
app = typer.Typer()
def insert_passages_into_source(passages: List[Passage], source_name: str, user_id: uuid.UUID, config: MemGPTConfig):
    """Insert a list of passages into a source by updating storage connectors and metadata store"""
    # NOTE(review): a second function with this exact name is defined immediately
    # below in this module; that later definition shadows this one at import time.
    # Confirm which implementation is intended and delete the duplicate.
    storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
    orig_size = storage.size()  # snapshot used for the post-insert size sanity check

    # insert metadata store: create the Source record if it doesn't exist yet
    ms = MetadataStore(config)
    source = ms.get_source(user_id=user_id, source_name=source_name)
    if not source:
        # create new
        source = Source(user_id=user_id, name=source_name, created_at=get_local_time())
        ms.create_source(source)

    # make sure user_id is set for passages (storage filters rows by user)
    for passage in passages:
        # TODO: attach source IDs
        # passage.source_id = source.id
        passage.user_id = user_id
        passage.data_source = source_name

    # add and save all passages
    storage.insert_many(passages)
    assert orig_size + len(passages) == storage.size(), f"Expected {orig_size + len(passages)} passages, got {storage.size()}"
    storage.save()
def insert_passages_into_source(passages: List[Passage], source_name: str, user_id: uuid.UUID, config: MemGPTConfig):
"""Insert a list of passages into a source by updating storage connectors and metadata store"""
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
@@ -132,6 +158,9 @@ def load_index(
node.embedding = vector
passages.append(Passage(text=node.text, embedding=vector))
if len(passages) == 0:
raise ValueError(f"No passages found in index {dir}")
# create storage connector
config = MemGPTConfig.load()
if user_id is None:

View File

@@ -24,7 +24,7 @@ import memgpt.agent as agent
import memgpt.system as system
import memgpt.constants as constants
import memgpt.errors as errors
from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart, suppress_stdout
from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart, migrate
from memgpt.cli.cli_config import configure, list, add, delete
from memgpt.cli.cli_load import app as load_app
from memgpt.agent_store.storage import StorageConnector, TableType
@@ -43,6 +43,8 @@ app.command(name="folder")(open_folder)
app.command(name="quickstart")(quickstart)
# load data commands
app.add_typer(load_app, name="load")
# migration command
app.command(name="migrate")(migrate)
def clear_line(strip_ui=False):

View File

@@ -223,97 +223,87 @@ class MetadataStore:
Base.metadata.create_all(
self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__]
)
self.Session = sessionmaker(bind=self.engine)
session_maker = sessionmaker(bind=self.engine)
self.session = session_maker()
def create_agent(self, agent: AgentState):
# insert into agent table
session = self.Session()
# make sure agent.name does not already exist for user user_id
if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
if self.session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
raise ValueError(f"Agent with name {agent.name} already exists")
session.add(AgentModel(**vars(agent)))
session.commit()
self.session.add(AgentModel(**vars(agent)))
self.session.commit()
def create_source(self, source: Source):
session = self.Session()
# make sure source.name does not already exist for user
if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
if (
self.session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count()
> 0
):
raise ValueError(f"Source with name {source.name} already exists")
session.add(SourceModel(**vars(source)))
session.commit()
self.session.add(SourceModel(**vars(source)))
self.session.commit()
def create_user(self, user: User):
session = self.Session()
if session.query(UserModel).filter(UserModel.id == user.id).count() > 0:
if self.session.query(UserModel).filter(UserModel.id == user.id).count() > 0:
raise ValueError(f"User with id {user.id} already exists")
session.add(UserModel(**vars(user)))
session.commit()
self.session.add(UserModel(**vars(user)))
self.session.commit()
def update_agent(self, agent: AgentState):
session = self.Session()
session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))
session.commit()
self.session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))
self.session.commit()
def update_user(self, user: User):
session = self.Session()
session.query(UserModel).filter(UserModel.id == user.id).update(vars(user))
session.commit()
self.session.query(UserModel).filter(UserModel.id == user.id).update(vars(user))
self.session.commit()
def update_source(self, source: Source):
session = self.Session()
session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source))
session.commit()
self.session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source))
self.session.commit()
def delete_agent(self, agent_id: str):
session = self.Session()
session.query(AgentModel).filter(AgentModel.id == agent_id).delete()
session.commit()
self.session.query(AgentModel).filter(AgentModel.id == agent_id).delete()
self.session.commit()
def delete_source(self, source_id: str):
session = self.Session()
# delete from sources table
session.query(SourceModel).filter(SourceModel.id == source_id).delete()
self.session.query(SourceModel).filter(SourceModel.id == source_id).delete()
# delete any mappings
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete()
self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete()
session.commit()
self.session.commit()
def delete_user(self, user_id: str):
session = self.Session()
# delete from users table
session.query(UserModel).filter(UserModel.id == user_id).delete()
self.session.query(UserModel).filter(UserModel.id == user_id).delete()
# delete associated agents
session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
self.session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
# delete associated sources
session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
self.session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
# delete associated mappings
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
session.commit()
self.session.commit()
def list_agents(self, user_id: str) -> List[AgentState]:
session = self.Session()
results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
results = self.session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
return [r.to_record() for r in results]
def list_sources(self, user_id: str) -> List[Source]:
session = self.Session()
results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
results = self.session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
return [r.to_record() for r in results]
def get_agent(self, agent_id: str = None, agent_name: str = None, user_id: str = None) -> Optional[AgentState]:
session = self.Session()
if agent_id:
results = session.query(AgentModel).filter(AgentModel.id == agent_id).all()
results = self.session.query(AgentModel).filter(AgentModel.id == agent_id).all()
else:
assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name"
results = session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all()
results = self.session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all()
if len(results) == 0:
return None
@@ -321,20 +311,18 @@ class MetadataStore:
return results[0].to_record()
def get_user(self, user_id: str) -> Optional[User]:
session = self.Session()
results = session.query(UserModel).filter(UserModel.id == user_id).all()
results = self.session.query(UserModel).filter(UserModel.id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
def get_source(self, source_id: str = None, user_id: str = None, source_name: str = None) -> Optional[Source]:
session = self.Session()
if source_id:
results = session.query(SourceModel).filter(SourceModel.id == source_id).all()
results = self.session.query(SourceModel).filter(SourceModel.id == source_id).all()
else:
assert user_id is not None and source_name is not None
results = session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all()
results = self.session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
@@ -342,23 +330,19 @@ class MetadataStore:
# agent source metadata
def attach_source(self, user_id: str, agent_id: str, source_id: str):
session = self.Session()
session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))
session.commit()
self.session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))
self.session.commit()
def list_attached_sources(self, agent_id: str) -> List[Column]:
session = self.Session()
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
results = self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
return [r.source_id for r in results]
def list_attached_agents(self, source_id):
session = self.Session()
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
results = self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
return [r.agent_id for r in results]
def detach_source(self, agent_id: str, source_id: str):
session = self.Session()
session.query(AgentSourceMappingModel).filter(
self.session.query(AgentSourceMappingModel).filter(
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
).delete()
session.commit()
self.session.commit()

448
memgpt/migrate.py Normal file
View File

@@ -0,0 +1,448 @@
import configparser
import datetime
import os
import pickle
import glob
import sys
import traceback
import uuid
import json
import shutil
import typer
from tqdm import tqdm
import questionary
from llama_index import (
StorageContext,
load_index_from_storage,
)
from memgpt.agent import Agent
from memgpt.data_types import AgentState, User, Passage, Source
from memgpt.metadata import MetadataStore
from memgpt.utils import MEMGPT_DIR, version_less_than, OpenAIBackcompatUnpickler, annotate_message_json_list_with_tool_calls
from memgpt.config import MemGPTConfig
from memgpt.cli.cli_config import configure
from memgpt.agent_store.storage import StorageConnector, TableType
# This is the version where the breaking change was made
VERSION_CUTOFF = "0.2.12"
# Migration backup dir (where we'll dump old agents that we successfully migrated)
MIGRATION_BACKUP_FOLDER = "migration_backups"
def wipe_config_and_reconfigure():
    """Wipe (backup) the config file, and launch `memgpt configure`"""
    # Timestamp suffix keeps successive backups from clobbering each other
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    existing_filename = os.path.join(MEMGPT_DIR, "config")
    backup_filename = os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, f"config_backup_{stamp}")

    if not os.path.exists(existing_filename):
        typer.secho(f"Couldn't find an existing config file to delete", fg=typer.colors.RED)
    else:
        # shutil should work cross-platform
        shutil.move(existing_filename, backup_filename)
        typer.secho(f"Deleted config file ({existing_filename}) and saved as backup ({backup_filename})", fg=typer.colors.GREEN)

    # Re-run the interactive configuration to produce a fresh config file
    configure()
def config_is_compatible() -> bool:
    """Check if the config is OK to use with 0.2.12, or if it needs to be deleted"""
    # NOTE: don't use built-in load(), since that will apply defaults
    # memgpt_config = MemGPTConfig.load()
    parser = configparser.ConfigParser()
    parser.read(os.path.join(MEMGPT_DIR, "config"))

    # Pull the recorded version out of the [version] section, if present
    version = None
    if "version" in parser and "memgpt_version" in parser["version"]:
        version = parser["version"]["memgpt_version"]

    if version is None:
        typer.secho(f"Current config version is missing", fg=typer.colors.RED)
        return False
    if version_less_than(version, VERSION_CUTOFF):
        typer.secho(f"Current config version ({version}) is older than cutoff ({VERSION_CUTOFF})", fg=typer.colors.RED)
        return False

    typer.secho(f"Current config version {version} is compatible!", fg=typer.colors.GREEN)
    return True
def agent_is_migrateable(agent_name: str) -> bool:
    """Determine whether or not the agent folder is a migration target.

    An agent is migrateable when its config.json either has no recorded
    memgpt_version or records a version older than VERSION_CUTOFF.

    Raises:
        ValueError: if the agent folder or its config file is missing/unreadable.
    """
    agent_folder = os.path.join(MEMGPT_DIR, "agents", agent_name)

    if not os.path.exists(agent_folder):
        raise ValueError(f"Folder {agent_folder} does not exist")

    agent_config_file = os.path.join(agent_folder, "config.json")
    if not os.path.exists(agent_config_file):
        raise ValueError(f"Agent folder {agent_folder} does not have a config file")

    try:
        with open(agent_config_file, "r") as fh:
            agent_config = json.load(fh)
    except Exception as e:
        raise ValueError(f"Failed to load agent config file ({agent_config_file}), error = {e}")

    # BUG FIX: agent_config is a plain dict (from json.load), so the previous
    # hasattr(agent_config, "memgpt_version") was always False — every agent,
    # including already-migrated ones, looked migrateable. Use key membership.
    if "memgpt_version" not in agent_config or version_less_than(agent_config["memgpt_version"], VERSION_CUTOFF):
        return True
    else:
        return False
def migrate_source(source_name: str):
    """
    Migrate an old source folder (`~/.memgpt/sources/{source_name}`).

    Loads the pickled vector-index nodes from the old archival folder, creates
    a Source row in the metadata store for the default user, and inserts the
    node texts/embeddings as Passage records via the storage connector.
    On any failure the freshly created Source row is rolled back.

    Raises:
        ValueError: if the default user is missing or the migration fails.
    """
    # 1. Load the VectorIndex from ~/.memgpt/sources/{source_name}/index
    # TODO
    source_path = os.path.join(MEMGPT_DIR, "archival", source_name, "nodes.pkl")
    assert os.path.exists(source_path), f"Source {source_name} does not exist at {source_path}"

    # load state from old checkpoint file
    from memgpt.cli.cli_load import load_index  # NOTE(review): imported but unused here — kept in case of import side effects; confirm

    # 2. Create a new AgentState using the agent config + agent internal state
    config = MemGPTConfig.load()

    # gets default user
    ms = MetadataStore(config)
    user_id = uuid.UUID(config.anon_clientid)
    user = ms.get_user(user_id=user_id)
    if user is None:
        raise ValueError(
            f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
        )

    # insert source into metadata store
    source = Source(user_id=user.id, name=source_name)
    ms.create_source(source)

    try:
        # FIX: use a context manager so the pickle file handle is not leaked
        with open(source_path, "rb") as f:
            nodes = pickle.load(f)
        passages = []
        for node in nodes:
            # TODO: make sure embedding config matches embedding size?
            passages.append(Passage(user_id=user.id, data_source=source_name, text=node.text, embedding=node.embedding))
        assert len(passages) > 0, f"Source {source_name} has no passages"
        conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
        conn.insert_many(passages)
        print(f"Inserted {len(passages)} to {source_name}")
    except Exception as e:
        # delete from metadata store so a failed migration leaves no orphan Source row
        ms.delete_source(source.id)
        # chain the original cause for easier debugging
        raise ValueError(f"Failed to migrate {source_name}: {str(e)}") from e

    # basic checks
    source = ms.get_source(user_id=user.id, source_name=source_name)
    assert source is not None, f"Failed to load source {source_name} from database after migration"
def migrate_agent(agent_name: str):
    """Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`)

    Steps:
    1. Load the agent state JSON from the old folder
    2. Create a new AgentState using the agent config + agent internal state
    3. Instantiate a new Agent by passing AgentState to Agent.__init__
       (This will automatically run into a new database)
    4. Insert the old persistence-manager message history into recall memory
    5. Insert any pickled archival nodes into archival memory
    Finally the old agent folder is moved into the migration backup folder.

    Raises:
        ValueError: if no checkpoint JSON or persistence pickle is found.
    """
    # 1. Load the agent state JSON from the old folder
    # TODO
    agent_folder = os.path.join(MEMGPT_DIR, "agents", agent_name)
    # migration_file = os.path.join(agent_folder, MIGRATION_FILE_NAME)

    # load state from old checkpoint file
    agent_ckpt_directory = os.path.join(agent_folder, "agent_state")
    json_files = glob.glob(os.path.join(agent_ckpt_directory, "*.json"))  # This will list all .json files in the current directory.
    if not json_files:
        raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}")
        # NOTE this is a soft fail, just allow it to pass
        # return

    # Sort files based on modified timestamp, with the latest file being the first.
    state_filename = max(json_files, key=os.path.getmtime)
    state_dict = json.load(open(state_filename, "r"))
    # print(state_dict.keys())
    # print(state_dict["memory"])
    # dict_keys(['model', 'system', 'functions', 'messages', 'messages_total', 'memory'])

    # load old data from the persistence manager
    # the persistence pickle shares the checkpoint's basename with a different suffix
    persistence_filename = os.path.basename(state_filename).replace(".json", ".persistence.pickle")
    persistence_filename = os.path.join(agent_folder, "persistence_manager", persistence_filename)
    archival_filename = os.path.join(agent_folder, "persistence_manager", "index", "nodes.pkl")
    if not os.path.exists(persistence_filename):
        raise ValueError(f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}")

    try:
        with open(persistence_filename, "rb") as f:
            data = pickle.load(f)
    except ModuleNotFoundError as e:
        # Patch for stripped openai package
        # ModuleNotFoundError: No module named 'openai.openai_object'
        # Fall back to a custom unpickler that remaps the old openai module paths
        with open(persistence_filename, "rb") as f:
            unpickler = OpenAIBackcompatUnpickler(f)
            data = unpickler.load()

        from memgpt.openai_backcompat.openai_object import OpenAIObject

        def convert_openai_objects_to_dict(obj):
            # Recursively replace OpenAIObject instances with plain dicts so the
            # unpickled data can be handled like normal JSON-style structures
            if isinstance(obj, OpenAIObject):
                # Convert to dict or handle as needed
                # print(f"detected OpenAIObject on {obj}")
                return obj.to_dict_recursive()
            elif isinstance(obj, dict):
                return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert_openai_objects_to_dict(v) for v in obj]
            else:
                return obj

        data = convert_openai_objects_to_dict(data)

    # data will contain:
    # print("data.keys()", data.keys())
    # manager.all_messages = data["all_messages"]
    # manager.messages = data["messages"]
    # manager.recall_memory = data["recall_memory"]

    agent_config_filename = os.path.join(agent_folder, "config.json")
    with open(agent_config_filename, "r") as fh:
        agent_config = json.load(fh)

    # 2. Create a new AgentState using the agent config + agent internal state
    config = MemGPTConfig.load()

    # gets default user
    ms = MetadataStore(config)
    user_id = uuid.UUID(config.anon_clientid)
    user = ms.get_user(user_id=user_id)
    if user is None:
        raise ValueError(
            f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
        )
        # ms.create_user(User(id=user_id))
        # user = ms.get_user(user_id=user_id)
        # if user is None:
        #     typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
        #     sys.exit(1)

    agent_state = AgentState(
        name=agent_config["name"],
        user_id=user.id,
        persona=agent_config["persona"],  # eg 'sam_pov'
        human=agent_config["human"],  # eg 'basic'
        preset=agent_config["preset"],  # eg 'memgpt_chat'
        state=dict(
            human=state_dict["memory"]["human"],
            persona=state_dict["memory"]["persona"],
            system=state_dict["system"],
            functions=state_dict["functions"],  # this shouldn't matter, since Agent.__init__ will re-link
            # old checkpoints may lack tool_call_id fields, so annotate them here
            messages=annotate_message_json_list_with_tool_calls(state_dict["messages"]),
        ),
        llm_config=user.default_llm_config,
        embedding_config=user.default_embedding_config,
    )

    # 3. Instantiate a new Agent by passing AgentState to Agent.__init__
    # NOTE: the Agent.__init__ will trigger a save, which will write to the DB
    try:
        agent = Agent(
            agent_state=agent_state,
            messages_total=state_dict["messages_total"],  # TODO: do we need this?
            interface=None,
        )
    except Exception as e:
        # if "Agent with name" in str(e):
        #     print(e)
        #     return
        # elif "was specified in agent.state.functions":
        #     print(e)
        #     return
        # else:
        #     raise
        raise

    # Wrap the rest in a try-except so that we can cleanup by deleting the agent if we fail
    try:
        ## 4. Insert into recall
        # TODO should this be 'messages', or 'all_messages'?
        # all_messages in recall will have fields "timestamp" and "message"
        full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
        # We want to keep the timestamp
        for i in range(len(data["all_messages"])):
            data["all_messages"][i]["message"] = full_message_history_buffer[i]
        messages_to_insert = [agent.persistence_manager.json_to_message(msg) for msg in data["all_messages"]]
        agent.persistence_manager.recall_memory.insert_many(messages_to_insert)
        # print("Finished migrating recall memory")

        # TODO should we also assign data["messages"] to RecallMemory.messages?

        # 5. Insert into archival
        if os.path.exists(archival_filename):
            nodes = pickle.load(open(archival_filename, "rb"))
            passages = []
            for node in nodes:
                # print(len(node.embedding))
                # TODO: make sure embeding size matches embedding config?
                passages.append(Passage(user_id=user.id, agent_id=agent_state.id, text=node.text, embedding=node.embedding))
            if len(passages) > 0:
                agent.persistence_manager.archival_memory.storage.insert_many(passages)
                print(f"Inserted {len(passages)} passages into archival memory")
        else:
            print("No archival memory found at", archival_filename)
    except:
        # roll back the half-migrated agent so a retry starts clean
        ms.delete_agent(agent_state.id)
        raise

    # Move the old folder into the backup area so it won't be re-migrated
    try:
        new_agent_folder = os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents", agent_name)
        shutil.move(agent_folder, new_agent_folder)
    except Exception as e:
        print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}")
        raise
def migrate_all_agents(stop_on_fail: bool = False) -> dict:
    """Scan over all agent folders in MEMGPT_DIR and migrate each agent.

    Args:
        stop_on_fail: If True, re-raise the first migration failure instead of
            recording it and continuing with the next agent.

    Returns:
        Summary dict with keys: agent_folders, migration_candidates,
        successful_migrations, failed_migrations.
    """
    # FIX: create the backup tree with exist_ok=True. Previously, if the backup
    # root already existed (e.g. created by wipe_config_and_reconfigure), the
    # "agents" subfolder was never created, which broke the shutil.move at the
    # end of migrate_agent.
    os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER), exist_ok=True)
    os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents"), exist_ok=True)

    if not config_is_compatible():
        typer.secho(f"Your current config file is incompatible with MemGPT versions >= {VERSION_CUTOFF}", fg=typer.colors.RED)
        if questionary.confirm(
            "To migrate old MemGPT agents, you must delete your config file and run `memgpt configure`. Would you like to proceed?"
        ).ask():
            try:
                wipe_config_and_reconfigure()
            except Exception as e:
                typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
                raise
        else:
            typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED)
            raise KeyboardInterrupt()

    agents_dir = os.path.join(MEMGPT_DIR, "agents")

    # Ensure the directory exists
    if not os.path.exists(agents_dir):
        raise ValueError(f"Directory {agents_dir} does not exist.")

    # Get a list of all folders in agents_dir
    agent_folders = [f for f in os.listdir(agents_dir) if os.path.isdir(os.path.join(agents_dir, f))]

    # Iterate over each folder with a tqdm progress bar
    count = 0
    failures = []
    candidates = []
    try:
        for agent_name in tqdm(agent_folders, desc="Migrating agents"):
            try:
                if agent_is_migrateable(agent_name=agent_name):
                    candidates.append(agent_name)
                    migrate_agent(agent_name)
                    count += 1
                else:
                    continue
            except Exception as e:
                # record the failure and keep going unless the caller asked to stop
                failures.append({"name": agent_name, "reason": str(e)})
                traceback.print_exc()
                if stop_on_fail:
                    raise
    except KeyboardInterrupt:
        typer.secho(f"User cancelled operation", fg=typer.colors.RED)

    # Summarize results for the user
    if len(candidates) == 0:
        typer.secho(f"No migration candidates found ({len(agent_folders)} agent folders total)", fg=typer.colors.GREEN)
    else:
        typer.secho(f"Inspected {len(agent_folders)} agent folders")
        if len(failures) > 0:
            typer.secho(f"Failed migrations:", fg=typer.colors.RED)
            for fail in failures:
                typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
            typer.secho(f"{len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
        if count > 0:
            typer.secho(f"{count}/{len(candidates)} agents were successfully migrated to the new database format", fg=typer.colors.GREEN)

    return {
        "agent_folders": len(agent_folders),
        "migration_candidates": len(candidates),
        "successful_migrations": count,
        "failed_migrations": len(failures),
    }
def migrate_all_sources(stop_on_fail: bool = False) -> dict:
    """Scan over all data source folders in MEMGPT_DIR/archival and migrate each one.

    Args:
        stop_on_fail: If True, re-raise the first migration failure instead of
            recording it and continuing with the next source.

    Returns:
        Summary dict with keys: source_folders, migration_candidates,
        successful_migrations, failed_migrations.
    """
    sources_dir = os.path.join(MEMGPT_DIR, "archival")

    # The archival directory must exist before we can enumerate sources
    if not os.path.exists(sources_dir):
        raise ValueError(f"Directory {sources_dir} does not exist.")

    # Every subfolder of the archival directory is treated as a source
    source_folders = [entry for entry in os.listdir(sources_dir) if os.path.isdir(os.path.join(sources_dir, entry))]

    num_migrated = 0
    failures = []
    candidates = []
    try:
        # tqdm gives the user a progress bar over the source folders
        for source_name in tqdm(source_folders, desc="Migrating data sources"):
            candidates.append(source_name)
            try:
                migrate_source(source_name)
            except Exception as e:
                # record and continue unless the caller asked us to stop
                failures.append({"name": source_name, "reason": str(e)})
                traceback.print_exc()
                if stop_on_fail:
                    raise
            else:
                num_migrated += 1
    except KeyboardInterrupt:
        typer.secho(f"User cancelled operation", fg=typer.colors.RED)

    # Summarize results for the user
    if len(candidates) == 0:
        typer.secho(f"No migration candidates found ({len(source_folders)} source folders total)", fg=typer.colors.GREEN)
    else:
        typer.secho(f"Inspected {len(source_folders)} source folders")
        if len(failures) > 0:
            typer.secho(f"Failed migrations:", fg=typer.colors.RED)
            for failure in failures:
                typer.secho(f"{failure['name']}: {failure['reason']}", fg=typer.colors.RED)
            typer.secho(f"{len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
        if num_migrated > 0:
            typer.secho(f"{num_migrated}/{len(candidates)} sources were successfully migrated to the new database format", fg=typer.colors.GREEN)

    return {
        "source_folders": len(source_folders),
        "migration_candidates": len(candidates),
        "successful_migrations": num_migrated,
        "failed_migrations": len(failures),
    }

View File

@@ -10,9 +10,12 @@ from memgpt.data_types import Message, ToolCall, AgentState
from datetime import datetime
def parse_formatted_time(formatted_time):
def parse_formatted_time(formatted_time: str):
# parse times returned by memgpt.utils.get_formatted_time()
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
try:
return datetime.strptime(formatted_time.strip(), "%Y-%m-%d %I:%M:%S %p %Z%z")
except:
return datetime.strptime(formatted_time.strip(), "%Y-%m-%d %I:%M:%S %p")
class PersistenceManager(ABC):

View File

@@ -1,4 +1,5 @@
from datetime import datetime
import copy
import re
import json
import os
@@ -6,8 +7,11 @@ import pickle
import platform
import random
import subprocess
import uuid
import sys
import io
from typing import List
from urllib.parse import urlparse
from contextlib import contextmanager
import difflib
@@ -456,6 +460,86 @@ NOUN_BANK = [
]
def annotate_message_json_list_with_tool_calls(messages: List[dict]):
    """Add in missing tool_call_id fields to a list of messages using function call style

    Walk through the list forwards:
    - If we encounter an assistant message that calls a function ("function_call") but doesn't have a "tool_call_id" field
      - Generate the tool_call_id
      - Then check if the subsequent message is a role == "function" message
        - If so, then attach the same tool_call_id to that function-response message

    NOTE: the dicts inside `messages` are annotated in place (mutated), and the
    returned list holds deep copies of them.

    Raises:
        ValueError: on a message without a "role", a function response that is
            already annotated, or one that doesn't directly follow its call.
        NotImplementedError: if a (newer-style) role == "tool" message is seen.
    """
    tool_call_index = None  # index of the most recently annotated function call
    tool_call_id = None  # id buffered for the matching "function" response
    updated_messages = []
    for i, message in enumerate(messages):
        if "role" not in message:
            raise ValueError(f"message missing 'role' field:\n{message}")

        # If we find a function call w/o a tool call ID annotation, annotate it
        if message["role"] == "assistant" and "function_call" in message:
            if "tool_call_id" in message and message["tool_call_id"] is not None:
                printd(f"Message already has tool_call_id")
                tool_call_id = message["tool_call_id"]
            else:
                tool_call_id = str(uuid.uuid4())
                message["tool_call_id"] = tool_call_id
            tool_call_index = i

        # After annotating the call, we expect to find a follow-up response (also unannotated)
        elif message["role"] == "function":
            # We should have a new tool call id in the buffer
            if tool_call_id is None:
                # raise ValueError(
                print(
                    f"Got a function call role, but did not have a saved tool_call_id ready to use (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )
                # allow a soft fail in this case
                message["tool_call_id"] = str(uuid.uuid4())
            elif "tool_call_id" in message:
                raise ValueError(
                    f"Got a function call role, but it already had a saved tool_call_id (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )
            elif i != tool_call_index + 1:
                raise ValueError(
                    f"Got a function call role, saved tool_call_id came earlier than i-1 (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )
            else:
                message["tool_call_id"] = tool_call_id
                tool_call_id = None  # wipe the buffer

        elif message["role"] == "tool":
            raise NotImplementedError(
                f"tool_call_id annotation is meant for deprecated functions style, but got role 'tool' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
            )

        else:
            # eg role == 'user', nothing to do here
            pass

        updated_messages.append(copy.deepcopy(message))

    return updated_messages
def version_less_than(version_a: str, version_b: str) -> bool:
    """Return True iff semantic version `version_a` precedes `version_b`.

    Both arguments must have the form 'int.int.int'; anything else raises
    ValueError.
    """
    # Only plain three-component numeric versions are accepted
    semver_re = re.compile(r"^\d+\.\d+\.\d+$")
    for candidate in (version_a, version_b):
        if not semver_re.match(candidate):
            raise ValueError("Version strings must be in the format 'int.int.int'")

    def numeric_parts(version: str):
        # '0.2.12' -> [0, 2, 12]
        return [int(piece) for piece in version.split(".")]

    # List comparison of numeric components gives the correct ordering
    return numeric_parts(version_a) < numeric_parts(version_b)
def create_random_username() -> str:
"""Generate a random username by combining an adjective and a noun."""
adjective = random.choice(ADJECTIVE_BANK).capitalize()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,20 @@
{
"name": "agent_test",
"persona": "sam_pov",
"human": "basic",
"preset": "memgpt_chat",
"context_window": 8192,
"model": "gpt-4",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": null,
"embedding_endpoint_type": "openai",
"embedding_endpoint": "https://api.openai.com/v1",
"embedding_model": null,
"embedding_dim": 1536,
"embedding_chunk_size": 300,
"data_sources": [],
"create_time": "2024-01-11 12:42:25 PM",
"memgpt_version": "0.2.11",
"agent_config_path": "/Users/sarahwooders/.memgpt/agents/agent_test/config.json"
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,22 @@
{
"name": "agent_test_attach",
"persona": "sam_pov",
"human": "basic",
"preset": "memgpt_chat",
"context_window": 8192,
"model": "gpt-4",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": null,
"embedding_endpoint_type": "openai",
"embedding_endpoint": "https://api.openai.com/v1",
"embedding_model": null,
"embedding_dim": 1536,
"embedding_chunk_size": 300,
"data_sources": [
"test"
],
"create_time": "2024-01-11 12:41:37 PM",
"memgpt_version": "0.2.11",
"agent_config_path": "/Users/sarahwooders/.memgpt/agents/agent_test_attach/config.json"
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,20 @@
{
"name": "agent_test_empty_archival",
"persona": "sam_pov",
"human": "basic",
"preset": "memgpt_chat",
"context_window": 8192,
"model": "gpt-4",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": null,
"embedding_endpoint_type": "openai",
"embedding_endpoint": "https://api.openai.com/v1",
"embedding_model": null,
"embedding_dim": 1536,
"embedding_chunk_size": 300,
"data_sources": [],
"create_time": "2024-01-11 12:44:07 PM",
"memgpt_version": "0.2.11",
"agent_config_path": "/Users/sarahwooders/.memgpt/agents/agent_test_empty_archival/config.json"
}

Binary file not shown.

View File

@@ -0,0 +1,29 @@
[defaults]
preset = memgpt_chat
persona = sam_pov
human = basic
[model]
model = gpt-4
model_endpoint = https://api.openai.com/v1
model_endpoint_type = openai
context_window = 8192
[openai]
key = FAKE_KEY
[embedding]
embedding_endpoint_type = openai
embedding_endpoint = https://api.openai.com/v1
embedding_dim = 1536
embedding_chunk_size = 300
[archival_storage]
type = local
[version]
memgpt_version = 0.2.11
[client]
anon_clientid = 00000000000000000000d67f40108c5c