feat: Migration command for importing old agents into new DB backend (#802)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
@@ -3,11 +3,13 @@ repos:
|
||||
rev: v2.3.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
exclude: ^docs/
|
||||
exclude: 'docs/.*|tests/data/.*'
|
||||
- id: end-of-file-fixer
|
||||
exclude: ^docs/
|
||||
exclude: 'docs/.*|tests/data/.*'
|
||||
- id: trailing-whitespace
|
||||
exclude: ^docs/
|
||||
exclude: 'docs/.*|tests/data/.*'
|
||||
- id: end-of-file-fixer
|
||||
exclude: 'docs/.*|tests/data/.*'
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.10.0
|
||||
hooks:
|
||||
|
||||
@@ -254,12 +254,11 @@ class SQLStorageConnector(StorageConnector):
|
||||
return all_filters
|
||||
|
||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
|
||||
session = self.Session()
|
||||
offset = 0
|
||||
filters = self.get_filters(filters)
|
||||
while True:
|
||||
# Retrieve a chunk of records with the given page_size
|
||||
db_record_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
|
||||
db_record_chunk = self.session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
|
||||
|
||||
# If the chunk is empty, we've retrieved all records
|
||||
if not db_record_chunk:
|
||||
@@ -272,40 +271,35 @@ class SQLStorageConnector(StorageConnector):
|
||||
offset += page_size
|
||||
|
||||
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
if limit:
|
||||
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
|
||||
db_records = self.session.query(self.db_model).filter(*filters).limit(limit).all()
|
||||
else:
|
||||
db_records = session.query(self.db_model).filter(*filters).all()
|
||||
db_records = self.session.query(self.db_model).filter(*filters).all()
|
||||
return [record.to_record() for record in db_records]
|
||||
|
||||
def get(self, id: str) -> Optional[Record]:
|
||||
session = self.Session()
|
||||
db_record = session.query(self.db_model).get(id)
|
||||
db_record = self.session.query(self.db_model).get(id)
|
||||
if db_record is None:
|
||||
return None
|
||||
return db_record.to_record()
|
||||
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
# return size of table
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
return session.query(self.db_model).filter(*filters).count()
|
||||
return self.session.query(self.db_model).filter(*filters).count()
|
||||
|
||||
def insert(self, record: Record):
|
||||
session = self.Session()
|
||||
db_record = self.db_model(**vars(record))
|
||||
session.add(db_record)
|
||||
session.commit()
|
||||
self.session.add(db_record)
|
||||
self.session.commit()
|
||||
|
||||
def insert_many(self, records: List[Record], show_progress=False):
|
||||
session = self.Session()
|
||||
iterable = tqdm(records) if show_progress else records
|
||||
for record in iterable:
|
||||
db_record = self.db_model(**vars(record))
|
||||
session.add(db_record)
|
||||
session.commit()
|
||||
self.session.add(db_record)
|
||||
self.session.commit()
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
|
||||
@@ -315,15 +309,13 @@ class SQLStorageConnector(StorageConnector):
|
||||
|
||||
def list_data_sources(self):
|
||||
assert self.table_type == TableType.ARCHIVAL_MEMORY, f"list_data_sources only implemented for ARCHIVAL_MEMORY"
|
||||
session = self.Session()
|
||||
unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
|
||||
unique_data_sources = self.session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
|
||||
return unique_data_sources
|
||||
|
||||
def query_date(self, start_date, end_date, offset=0, limit=None):
|
||||
session = self.Session()
|
||||
filters = self.get_filters({})
|
||||
query = (
|
||||
session.query(self.db_model)
|
||||
self.session.query(self.db_model)
|
||||
.filter(*filters)
|
||||
.filter(self.db_model.created_at >= start_date)
|
||||
.filter(self.db_model.created_at <= end_date)
|
||||
@@ -336,10 +328,12 @@ class SQLStorageConnector(StorageConnector):
|
||||
|
||||
def query_text(self, query, offset=0, limit=None):
|
||||
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
|
||||
session = self.Session()
|
||||
filters = self.get_filters({})
|
||||
query = (
|
||||
session.query(self.db_model).filter(*filters).filter(func.lower(self.db_model.text).contains(func.lower(query))).offset(offset)
|
||||
self.session.query(self.db_model)
|
||||
.filter(*filters)
|
||||
.filter(func.lower(self.db_model.text).contains(func.lower(query)))
|
||||
.offset(offset)
|
||||
)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
@@ -348,16 +342,14 @@ class SQLStorageConnector(StorageConnector):
|
||||
return [result.to_record() for result in results]
|
||||
|
||||
def delete_table(self):
|
||||
session = self.Session()
|
||||
close_all_sessions()
|
||||
self.db_model.__table__.drop(session.bind)
|
||||
session.commit()
|
||||
self.db_model.__table__.drop(self.session.bind)
|
||||
self.session.commit()
|
||||
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
session.query(self.db_model).filter(*filters).delete()
|
||||
session.commit()
|
||||
self.session.query(self.db_model).filter(*filters).delete()
|
||||
self.session.commit()
|
||||
|
||||
|
||||
class PostgresStorageConnector(SQLStorageConnector):
|
||||
@@ -393,13 +385,13 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
|
||||
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
|
||||
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
session_maker = sessionmaker(bind=self.engine)
|
||||
self.session = session_maker()
|
||||
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
results = session.scalars(
|
||||
results = self.session.scalars(
|
||||
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
||||
).all()
|
||||
|
||||
@@ -429,7 +421,8 @@ class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id, dialect="sqlite")
|
||||
self.engine = create_engine(f"sqlite:///{self.path}")
|
||||
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
session_maker = sessionmaker(bind=self.engine)
|
||||
self.session = session_maker()
|
||||
|
||||
import sqlite3
|
||||
|
||||
|
||||
@@ -28,6 +28,13 @@ from memgpt.embeddings import embedding_model
|
||||
from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT
|
||||
from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig, User
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.migrate import migrate_all_agents, migrate_all_sources
|
||||
|
||||
|
||||
def migrate():
|
||||
"""Migrate old agents (pre 0.2.12) to the new database system"""
|
||||
migrate_all_agents()
|
||||
migrate_all_sources()
|
||||
|
||||
|
||||
class QuickstartChoice(Enum):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import builtins
|
||||
from tqdm import tqdm
|
||||
import uuid
|
||||
import questionary
|
||||
from prettytable import PrettyTable
|
||||
@@ -61,6 +62,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
provider = questionary.select(
|
||||
"Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
|
||||
).ask()
|
||||
if provider is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# set: model_endpoint_type, model_endpoint
|
||||
if provider == "openai":
|
||||
@@ -75,6 +78,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
openai_api_key = questionary.text(
|
||||
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
|
||||
).ask()
|
||||
if openai_api_key is None:
|
||||
raise KeyboardInterrupt
|
||||
config.openai_key = openai_api_key
|
||||
config.save()
|
||||
else:
|
||||
@@ -85,6 +90,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
"Enter your OpenAI API key (hit enter to use existing key):",
|
||||
default=default_input,
|
||||
).ask()
|
||||
if openai_api_key is None:
|
||||
raise KeyboardInterrupt
|
||||
# If the user modified it, use the new one
|
||||
if openai_api_key != default_input:
|
||||
config.openai_key = openai_api_key
|
||||
@@ -93,6 +100,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
model_endpoint_type = "openai"
|
||||
model_endpoint = "https://api.openai.com/v1"
|
||||
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
provider = "openai"
|
||||
|
||||
elif provider == "azure":
|
||||
@@ -122,6 +131,8 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
backend_options,
|
||||
default=default_model_endpoint_type,
|
||||
).ask()
|
||||
if model_endpoint_type is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# set default endpoint
|
||||
# if OPENAI_API_BASE is set, assume that this is the IP+port the user wanted to use
|
||||
@@ -131,21 +142,33 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
if model_endpoint_type in DEFAULT_ENDPOINTS:
|
||||
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
while not utils.is_valid_url(model_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
elif config.model_endpoint:
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
while not utils.is_valid_url(model_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
else:
|
||||
# default_model_endpoint = None
|
||||
model_endpoint = None
|
||||
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
while not utils.is_valid_url(model_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
else:
|
||||
model_endpoint = default_model_endpoint
|
||||
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
|
||||
@@ -185,6 +208,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
|
||||
default=config.model if valid_model else hardcoded_model_options[0],
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# If the user asked for the full list, show it
|
||||
if model == see_all_option_str:
|
||||
@@ -194,6 +219,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
choices=fetched_model_options + [other_option_str],
|
||||
default=config.model if valid_model else fetched_model_options[0],
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# Finally if the user asked to manually input, allow it
|
||||
if model == other_option_str:
|
||||
@@ -202,6 +229,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
model = questionary.text(
|
||||
"Enter custom model name:",
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
else: # local models
|
||||
# ollama also needs model type
|
||||
@@ -211,6 +240,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
"Enter default model name (required for Ollama, see: https://memgpt.readme.io/docs/ollama):",
|
||||
default=default_model,
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
model = None if len(model) == 0 else model
|
||||
|
||||
default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
|
||||
@@ -234,6 +265,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
model = questionary.select(
|
||||
"Select default model:", choices=model_options, default=config.model if valid_model else model_options[0]
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# If we got custom input, ask for raw input
|
||||
if model == other_option_str:
|
||||
@@ -241,6 +274,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
|
||||
default=default_model,
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
# TODO allow empty string for input?
|
||||
model = None if len(model) == 0 else model
|
||||
|
||||
@@ -249,6 +284,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
|
||||
default=default_model,
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
model = None if len(model) == 0 else model
|
||||
|
||||
# model wrapper
|
||||
@@ -258,6 +295,8 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
choices=available_model_wrappers,
|
||||
default=DEFAULT_WRAPPER_NAME,
|
||||
).ask()
|
||||
if model_wrapper is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# set: context_window
|
||||
if str(model) not in LLM_MAX_TOKENS:
|
||||
@@ -275,11 +314,15 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi
|
||||
choices=context_length_options,
|
||||
default=str(LLM_MAX_TOKENS["DEFAULT"]),
|
||||
).ask()
|
||||
if context_window is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# If custom, ask for input
|
||||
if context_window == "custom":
|
||||
while True:
|
||||
context_window = questionary.text("Enter context window (e.g. 8192)").ask()
|
||||
if context_window is None:
|
||||
raise KeyboardInterrupt
|
||||
try:
|
||||
context_window = int(context_window)
|
||||
break
|
||||
@@ -302,6 +345,8 @@ def configure_embedding_endpoint(config: MemGPTConfig):
|
||||
embedding_provider = questionary.select(
|
||||
"Select embedding provider:", choices=["openai", "azure", "hugging-face", "local"], default=default_embedding_endpoint_type
|
||||
).ask()
|
||||
if embedding_provider is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
if embedding_provider == "openai":
|
||||
# check for key
|
||||
@@ -315,6 +360,8 @@ def configure_embedding_endpoint(config: MemGPTConfig):
|
||||
openai_api_key = questionary.text(
|
||||
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
|
||||
).ask()
|
||||
if openai_api_key is None:
|
||||
raise KeyboardInterrupt
|
||||
config.openai_key = openai_api_key
|
||||
config.save()
|
||||
|
||||
@@ -343,6 +390,8 @@ def configure_embedding_endpoint(config: MemGPTConfig):
|
||||
|
||||
# get endpoint
|
||||
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
if embedding_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
while not utils.is_valid_url(embedding_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
@@ -353,10 +402,14 @@ def configure_embedding_endpoint(config: MemGPTConfig):
|
||||
"Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):",
|
||||
default=default_embedding_model,
|
||||
).ask()
|
||||
if embedding_model is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# get model dimentions
|
||||
default_embedding_dim = config.embedding_dim if config.embedding_dim else "1024"
|
||||
embedding_dim = questionary.text("Enter embedding model dimentions (e.g. 1024):", default=str(default_embedding_dim)).ask()
|
||||
if embedding_dim is None:
|
||||
raise KeyboardInterrupt
|
||||
try:
|
||||
embedding_dim = int(embedding_dim)
|
||||
except Exception as e:
|
||||
@@ -376,16 +429,22 @@ def configure_cli(config: MemGPTConfig):
|
||||
# preset
|
||||
default_preset = config.preset if config.preset and config.preset in preset_options else None
|
||||
preset = questionary.select("Select default preset:", preset_options, default=default_preset).ask()
|
||||
if preset is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# persona
|
||||
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
|
||||
default_persona = config.persona if config.persona and config.persona in personas else None
|
||||
persona = questionary.select("Select default persona:", personas, default=default_persona).ask()
|
||||
if persona is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# human
|
||||
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
|
||||
default_human = config.human if config.human and config.human in humans else None
|
||||
human = questionary.select("Select default human:", humans, default=default_human).ask()
|
||||
if human is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# TODO: figure out if we should set a default agent or not
|
||||
agent = None
|
||||
@@ -399,6 +458,8 @@ def configure_archival_storage(config: MemGPTConfig):
|
||||
archival_storage_type = questionary.select(
|
||||
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
|
||||
).ask()
|
||||
if archival_storage_type is None:
|
||||
raise KeyboardInterrupt
|
||||
archival_storage_uri, archival_storage_path = config.archival_storage_uri, config.archival_storage_path
|
||||
|
||||
# configure postgres
|
||||
@@ -407,6 +468,8 @@ def configure_archival_storage(config: MemGPTConfig):
|
||||
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
||||
default=config.archival_storage_uri if config.archival_storage_uri else "",
|
||||
).ask()
|
||||
if archival_storage_uri is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# TODO: add back
|
||||
## configure lancedb
|
||||
@@ -419,8 +482,12 @@ def configure_archival_storage(config: MemGPTConfig):
|
||||
# configure chroma
|
||||
if archival_storage_type == "chroma":
|
||||
chroma_type = questionary.select("Select chroma backend:", ["http", "persistent"], default="persistent").ask()
|
||||
if chroma_type is None:
|
||||
raise KeyboardInterrupt
|
||||
if chroma_type == "http":
|
||||
archival_storage_uri = questionary.text("Enter chroma ip (e.g. localhost:8000):", default="localhost:8000").ask()
|
||||
if archival_storage_uri is None:
|
||||
raise KeyboardInterrupt
|
||||
if chroma_type == "persistent":
|
||||
archival_storage_path = os.path.join(MEMGPT_DIR, "chroma")
|
||||
|
||||
@@ -435,6 +502,8 @@ def configure_recall_storage(config: MemGPTConfig):
|
||||
recall_storage_type = questionary.select(
|
||||
"Select storage backend for recall data:", recall_storage_options, default=config.recall_storage_type
|
||||
).ask()
|
||||
if recall_storage_type is None:
|
||||
raise KeyboardInterrupt
|
||||
recall_storage_uri, recall_storage_path = config.recall_storage_uri, config.recall_storage_path
|
||||
# configure postgres
|
||||
if recall_storage_type == "postgres":
|
||||
@@ -442,6 +511,8 @@ def configure_recall_storage(config: MemGPTConfig):
|
||||
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
||||
default=config.recall_storage_uri if config.recall_storage_uri else "",
|
||||
).ask()
|
||||
if recall_storage_uri is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
return recall_storage_type, recall_storage_uri, recall_storage_path
|
||||
|
||||
@@ -564,7 +635,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
"""List all agents"""
|
||||
table = PrettyTable()
|
||||
table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"]
|
||||
for agent in ms.list_agents(user_id=user_id):
|
||||
for agent in tqdm(ms.list_agents(user_id=user_id)):
|
||||
source_ids = ms.list_attached_sources(agent_id=agent.id)
|
||||
source_names = [ms.get_source(source_id=source_id).name for source_id in source_ids]
|
||||
table.add_row(
|
||||
|
||||
@@ -32,6 +32,32 @@ from llama_index import (
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def insert_passages_into_source(passages: List[Passage], source_name: str, user_id: uuid.UUID, config: MemGPTConfig):
|
||||
"""Insert a list of passages into a source by updating storage connectors and metadata store"""
|
||||
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
|
||||
orig_size = storage.size()
|
||||
|
||||
# insert metadata store
|
||||
ms = MetadataStore(config)
|
||||
source = ms.get_source(user_id=user_id, source_name=source_name)
|
||||
if not source:
|
||||
# create new
|
||||
source = Source(user_id=user_id, name=source_name, created_at=get_local_time())
|
||||
ms.create_source(source)
|
||||
|
||||
# make sure user_id is set for passages
|
||||
for passage in passages:
|
||||
# TODO: attach source IDs
|
||||
# passage.source_id = source.id
|
||||
passage.user_id = user_id
|
||||
passage.data_source = source_name
|
||||
|
||||
# add and save all passages
|
||||
storage.insert_many(passages)
|
||||
assert orig_size + len(passages) == storage.size(), f"Expected {orig_size + len(passages)} passages, got {storage.size()}"
|
||||
storage.save()
|
||||
|
||||
|
||||
def insert_passages_into_source(passages: List[Passage], source_name: str, user_id: uuid.UUID, config: MemGPTConfig):
|
||||
"""Insert a list of passages into a source by updating storage connectors and metadata store"""
|
||||
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
|
||||
@@ -132,6 +158,9 @@ def load_index(
|
||||
node.embedding = vector
|
||||
passages.append(Passage(text=node.text, embedding=vector))
|
||||
|
||||
if len(passages) == 0:
|
||||
raise ValueError(f"No passages found in index {dir}")
|
||||
|
||||
# create storage connector
|
||||
config = MemGPTConfig.load()
|
||||
if user_id is None:
|
||||
|
||||
@@ -24,7 +24,7 @@ import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as constants
|
||||
import memgpt.errors as errors
|
||||
from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart, suppress_stdout
|
||||
from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart, migrate
|
||||
from memgpt.cli.cli_config import configure, list, add, delete
|
||||
from memgpt.cli.cli_load import app as load_app
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
@@ -43,6 +43,8 @@ app.command(name="folder")(open_folder)
|
||||
app.command(name="quickstart")(quickstart)
|
||||
# load data commands
|
||||
app.add_typer(load_app, name="load")
|
||||
# migration command
|
||||
app.command(name="migrate")(migrate)
|
||||
|
||||
|
||||
def clear_line(strip_ui=False):
|
||||
|
||||
@@ -223,97 +223,87 @@ class MetadataStore:
|
||||
Base.metadata.create_all(
|
||||
self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__]
|
||||
)
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
session_maker = sessionmaker(bind=self.engine)
|
||||
self.session = session_maker()
|
||||
|
||||
def create_agent(self, agent: AgentState):
|
||||
# insert into agent table
|
||||
session = self.Session()
|
||||
# make sure agent.name does not already exist for user user_id
|
||||
if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
|
||||
if self.session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
|
||||
raise ValueError(f"Agent with name {agent.name} already exists")
|
||||
session.add(AgentModel(**vars(agent)))
|
||||
session.commit()
|
||||
self.session.add(AgentModel(**vars(agent)))
|
||||
self.session.commit()
|
||||
|
||||
def create_source(self, source: Source):
|
||||
session = self.Session()
|
||||
# make sure source.name does not already exist for user
|
||||
if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
|
||||
if (
|
||||
self.session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count()
|
||||
> 0
|
||||
):
|
||||
raise ValueError(f"Source with name {source.name} already exists")
|
||||
session.add(SourceModel(**vars(source)))
|
||||
session.commit()
|
||||
self.session.add(SourceModel(**vars(source)))
|
||||
self.session.commit()
|
||||
|
||||
def create_user(self, user: User):
|
||||
session = self.Session()
|
||||
if session.query(UserModel).filter(UserModel.id == user.id).count() > 0:
|
||||
if self.session.query(UserModel).filter(UserModel.id == user.id).count() > 0:
|
||||
raise ValueError(f"User with id {user.id} already exists")
|
||||
session.add(UserModel(**vars(user)))
|
||||
session.commit()
|
||||
self.session.add(UserModel(**vars(user)))
|
||||
self.session.commit()
|
||||
|
||||
def update_agent(self, agent: AgentState):
|
||||
session = self.Session()
|
||||
session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))
|
||||
session.commit()
|
||||
self.session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))
|
||||
self.session.commit()
|
||||
|
||||
def update_user(self, user: User):
|
||||
session = self.Session()
|
||||
session.query(UserModel).filter(UserModel.id == user.id).update(vars(user))
|
||||
session.commit()
|
||||
self.session.query(UserModel).filter(UserModel.id == user.id).update(vars(user))
|
||||
self.session.commit()
|
||||
|
||||
def update_source(self, source: Source):
|
||||
session = self.Session()
|
||||
session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source))
|
||||
session.commit()
|
||||
self.session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source))
|
||||
self.session.commit()
|
||||
|
||||
def delete_agent(self, agent_id: str):
|
||||
session = self.Session()
|
||||
session.query(AgentModel).filter(AgentModel.id == agent_id).delete()
|
||||
session.commit()
|
||||
self.session.query(AgentModel).filter(AgentModel.id == agent_id).delete()
|
||||
self.session.commit()
|
||||
|
||||
def delete_source(self, source_id: str):
|
||||
session = self.Session()
|
||||
|
||||
# delete from sources table
|
||||
session.query(SourceModel).filter(SourceModel.id == source_id).delete()
|
||||
self.session.query(SourceModel).filter(SourceModel.id == source_id).delete()
|
||||
|
||||
# delete any mappings
|
||||
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete()
|
||||
self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete()
|
||||
|
||||
session.commit()
|
||||
self.session.commit()
|
||||
|
||||
def delete_user(self, user_id: str):
|
||||
session = self.Session()
|
||||
|
||||
# delete from users table
|
||||
session.query(UserModel).filter(UserModel.id == user_id).delete()
|
||||
self.session.query(UserModel).filter(UserModel.id == user_id).delete()
|
||||
|
||||
# delete associated agents
|
||||
session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
|
||||
self.session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
|
||||
|
||||
# delete associated sources
|
||||
session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
|
||||
self.session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
|
||||
|
||||
# delete associated mappings
|
||||
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
|
||||
self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
|
||||
|
||||
session.commit()
|
||||
self.session.commit()
|
||||
|
||||
def list_agents(self, user_id: str) -> List[AgentState]:
|
||||
session = self.Session()
|
||||
results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
|
||||
results = self.session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
|
||||
return [r.to_record() for r in results]
|
||||
|
||||
def list_sources(self, user_id: str) -> List[Source]:
|
||||
session = self.Session()
|
||||
results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
|
||||
results = self.session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
|
||||
return [r.to_record() for r in results]
|
||||
|
||||
def get_agent(self, agent_id: str = None, agent_name: str = None, user_id: str = None) -> Optional[AgentState]:
|
||||
session = self.Session()
|
||||
if agent_id:
|
||||
results = session.query(AgentModel).filter(AgentModel.id == agent_id).all()
|
||||
results = self.session.query(AgentModel).filter(AgentModel.id == agent_id).all()
|
||||
else:
|
||||
assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name"
|
||||
results = session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all()
|
||||
results = self.session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all()
|
||||
|
||||
if len(results) == 0:
|
||||
return None
|
||||
@@ -321,20 +311,18 @@ class MetadataStore:
|
||||
return results[0].to_record()
|
||||
|
||||
def get_user(self, user_id: str) -> Optional[User]:
|
||||
session = self.Session()
|
||||
results = session.query(UserModel).filter(UserModel.id == user_id).all()
|
||||
results = self.session.query(UserModel).filter(UserModel.id == user_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
def get_source(self, source_id: str = None, user_id: str = None, source_name: str = None) -> Optional[Source]:
|
||||
session = self.Session()
|
||||
if source_id:
|
||||
results = session.query(SourceModel).filter(SourceModel.id == source_id).all()
|
||||
results = self.session.query(SourceModel).filter(SourceModel.id == source_id).all()
|
||||
else:
|
||||
assert user_id is not None and source_name is not None
|
||||
results = session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all()
|
||||
results = self.session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
@@ -342,23 +330,19 @@ class MetadataStore:
|
||||
|
||||
# agent source metadata
|
||||
def attach_source(self, user_id: str, agent_id: str, source_id: str):
|
||||
session = self.Session()
|
||||
session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))
|
||||
session.commit()
|
||||
self.session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))
|
||||
self.session.commit()
|
||||
|
||||
def list_attached_sources(self, agent_id: str) -> List[Column]:
|
||||
session = self.Session()
|
||||
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
|
||||
results = self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
|
||||
return [r.source_id for r in results]
|
||||
|
||||
def list_attached_agents(self, source_id):
|
||||
session = self.Session()
|
||||
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
|
||||
results = self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
|
||||
return [r.agent_id for r in results]
|
||||
|
||||
def detach_source(self, agent_id: str, source_id: str):
|
||||
session = self.Session()
|
||||
session.query(AgentSourceMappingModel).filter(
|
||||
self.session.query(AgentSourceMappingModel).filter(
|
||||
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
|
||||
).delete()
|
||||
session.commit()
|
||||
self.session.commit()
|
||||
|
||||
448
memgpt/migrate.py
Normal file
448
memgpt/migrate.py
Normal file
@@ -0,0 +1,448 @@
|
||||
import configparser
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import glob
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
import json
|
||||
import shutil
|
||||
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
import questionary
|
||||
|
||||
from llama_index import (
|
||||
StorageContext,
|
||||
load_index_from_storage,
|
||||
)
|
||||
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.data_types import AgentState, User, Passage, Source
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.utils import MEMGPT_DIR, version_less_than, OpenAIBackcompatUnpickler, annotate_message_json_list_with_tool_calls
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.cli.cli_config import configure
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
|
||||
# This is the version where the breaking change was made
|
||||
VERSION_CUTOFF = "0.2.12"
|
||||
|
||||
# Migration backup dir (where we'll dump old agents that we successfully migrated)
|
||||
MIGRATION_BACKUP_FOLDER = "migration_backups"
|
||||
|
||||
|
||||
def wipe_config_and_reconfigure():
|
||||
"""Wipe (backup) the config file, and launch `memgpt configure`"""
|
||||
|
||||
# Get the current timestamp in a readable format (e.g., YYYYMMDD_HHMMSS)
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# Construct the new backup directory name with the timestamp
|
||||
backup_filename = os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, f"config_backup_{timestamp}")
|
||||
existing_filename = os.path.join(MEMGPT_DIR, "config")
|
||||
|
||||
# Check if the existing file exists before moving
|
||||
if os.path.exists(existing_filename):
|
||||
# shutil should work cross-platform
|
||||
shutil.move(existing_filename, backup_filename)
|
||||
typer.secho(f"Deleted config file ({existing_filename}) and saved as backup ({backup_filename})", fg=typer.colors.GREEN)
|
||||
else:
|
||||
typer.secho(f"Couldn't find an existing config file to delete", fg=typer.colors.RED)
|
||||
|
||||
# Run configure
|
||||
configure()
|
||||
|
||||
|
||||
def config_is_compatible() -> bool:
|
||||
"""Check if the config is OK to use with 0.2.12, or if it needs to be deleted"""
|
||||
# NOTE: don't use built-in load(), since that will apply defaults
|
||||
# memgpt_config = MemGPTConfig.load()
|
||||
memgpt_config_file = os.path.join(MEMGPT_DIR, "config")
|
||||
parser = configparser.ConfigParser()
|
||||
parser.read(memgpt_config_file)
|
||||
|
||||
if "version" in parser and "memgpt_version" in parser["version"]:
|
||||
version = parser["version"]["memgpt_version"]
|
||||
else:
|
||||
version = None
|
||||
|
||||
if version is None:
|
||||
typer.secho(f"Current config version is missing", fg=typer.colors.RED)
|
||||
return False
|
||||
elif version_less_than(version, VERSION_CUTOFF):
|
||||
typer.secho(f"Current config version ({version}) is older than cutoff ({VERSION_CUTOFF})", fg=typer.colors.RED)
|
||||
return False
|
||||
else:
|
||||
typer.secho(f"Current config version {version} is compatible!", fg=typer.colors.GREEN)
|
||||
return True
|
||||
|
||||
|
||||
def agent_is_migrateable(agent_name: str) -> bool:
|
||||
"""Determine whether or not the agent folder is a migration target"""
|
||||
agent_folder = os.path.join(MEMGPT_DIR, "agents", agent_name)
|
||||
|
||||
if not os.path.exists(agent_folder):
|
||||
raise ValueError(f"Folder {agent_folder} does not exist")
|
||||
|
||||
agent_config_file = os.path.join(agent_folder, "config.json")
|
||||
if not os.path.exists(agent_config_file):
|
||||
raise ValueError(f"Agent folder {agent_folder} does not have a config file")
|
||||
|
||||
try:
|
||||
with open(agent_config_file, "r") as fh:
|
||||
agent_config = json.load(fh)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load agent config file ({agent_config_file}), error = {e}")
|
||||
|
||||
if not hasattr(agent_config, "memgpt_version") or version_less_than(agent_config.memgpt_version, VERSION_CUTOFF):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def migrate_source(source_name: str):
|
||||
"""
|
||||
Migrate an old source folder (`~/.memgpt/sources/{source_name}`).
|
||||
"""
|
||||
|
||||
# 1. Load the VectorIndex from ~/.memgpt/sources/{source_name}/index
|
||||
# TODO
|
||||
source_path = os.path.join(MEMGPT_DIR, "archival", source_name, "nodes.pkl")
|
||||
assert os.path.exists(source_path), f"Source {source_name} does not exist at {source_path}"
|
||||
|
||||
# load state from old checkpoint file
|
||||
from memgpt.cli.cli_load import load_index
|
||||
|
||||
# 2. Create a new AgentState using the agent config + agent internal state
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# gets default user
|
||||
ms = MetadataStore(config)
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
user = ms.get_user(user_id=user_id)
|
||||
if user is None:
|
||||
raise ValueError(
|
||||
f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
|
||||
)
|
||||
|
||||
# insert source into metadata store
|
||||
source = Source(user_id=user.id, name=source_name)
|
||||
ms.create_source(source)
|
||||
|
||||
try:
|
||||
nodes = pickle.load(open(source_path, "rb"))
|
||||
passages = []
|
||||
for node in nodes:
|
||||
# print(len(node.embedding))
|
||||
# TODO: make sure embedding config matches embedding size?
|
||||
passages.append(Passage(user_id=user.id, data_source=source_name, text=node.text, embedding=node.embedding))
|
||||
|
||||
assert len(passages) > 0, f"Source {source_name} has no passages"
|
||||
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
|
||||
conn.insert_many(passages)
|
||||
print(f"Inserted {len(passages)} to {source_name}")
|
||||
except Exception as e:
|
||||
# delete from metadata store
|
||||
ms.delete_source(source.id)
|
||||
raise ValueError(f"Failed to migrate {source_name}: {str(e)}")
|
||||
|
||||
# basic checks
|
||||
source = ms.get_source(user_id=user.id, source_name=source_name)
|
||||
assert source is not None, f"Failed to load source {source_name} from database after migration"
|
||||
|
||||
|
||||
def migrate_agent(agent_name: str):
|
||||
"""Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`)
|
||||
|
||||
Steps:
|
||||
1. Load the agent state JSON from the old folder
|
||||
2. Create a new AgentState using the agent config + agent internal state
|
||||
3. Instantiate a new Agent by passing AgentState to Agent.__init__
|
||||
(This will automatically run into a new database)
|
||||
"""
|
||||
|
||||
# 1. Load the agent state JSON from the old folder
|
||||
# TODO
|
||||
agent_folder = os.path.join(MEMGPT_DIR, "agents", agent_name)
|
||||
# migration_file = os.path.join(agent_folder, MIGRATION_FILE_NAME)
|
||||
|
||||
# load state from old checkpoint file
|
||||
agent_ckpt_directory = os.path.join(agent_folder, "agent_state")
|
||||
json_files = glob.glob(os.path.join(agent_ckpt_directory, "*.json")) # This will list all .json files in the current directory.
|
||||
if not json_files:
|
||||
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}")
|
||||
# NOTE this is a soft fail, just allow it to pass
|
||||
# return
|
||||
|
||||
# Sort files based on modified timestamp, with the latest file being the first.
|
||||
state_filename = max(json_files, key=os.path.getmtime)
|
||||
state_dict = json.load(open(state_filename, "r"))
|
||||
|
||||
# print(state_dict.keys())
|
||||
# print(state_dict["memory"])
|
||||
# dict_keys(['model', 'system', 'functions', 'messages', 'messages_total', 'memory'])
|
||||
|
||||
# load old data from the persistence manager
|
||||
persistence_filename = os.path.basename(state_filename).replace(".json", ".persistence.pickle")
|
||||
persistence_filename = os.path.join(agent_folder, "persistence_manager", persistence_filename)
|
||||
archival_filename = os.path.join(agent_folder, "persistence_manager", "index", "nodes.pkl")
|
||||
if not os.path.exists(persistence_filename):
|
||||
raise ValueError(f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}")
|
||||
|
||||
try:
|
||||
with open(persistence_filename, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
except ModuleNotFoundError as e:
|
||||
# Patch for stripped openai package
|
||||
# ModuleNotFoundError: No module named 'openai.openai_object'
|
||||
with open(persistence_filename, "rb") as f:
|
||||
unpickler = OpenAIBackcompatUnpickler(f)
|
||||
data = unpickler.load()
|
||||
|
||||
from memgpt.openai_backcompat.openai_object import OpenAIObject
|
||||
|
||||
def convert_openai_objects_to_dict(obj):
|
||||
if isinstance(obj, OpenAIObject):
|
||||
# Convert to dict or handle as needed
|
||||
# print(f"detected OpenAIObject on {obj}")
|
||||
return obj.to_dict_recursive()
|
||||
elif isinstance(obj, dict):
|
||||
return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [convert_openai_objects_to_dict(v) for v in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
data = convert_openai_objects_to_dict(data)
|
||||
|
||||
# data will contain:
|
||||
# print("data.keys()", data.keys())
|
||||
# manager.all_messages = data["all_messages"]
|
||||
# manager.messages = data["messages"]
|
||||
# manager.recall_memory = data["recall_memory"]
|
||||
|
||||
agent_config_filename = os.path.join(agent_folder, "config.json")
|
||||
with open(agent_config_filename, "r") as fh:
|
||||
agent_config = json.load(fh)
|
||||
|
||||
# 2. Create a new AgentState using the agent config + agent internal state
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# gets default user
|
||||
ms = MetadataStore(config)
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
user = ms.get_user(user_id=user_id)
|
||||
if user is None:
|
||||
raise ValueError(
|
||||
f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
|
||||
)
|
||||
# ms.create_user(User(id=user_id))
|
||||
# user = ms.get_user(user_id=user_id)
|
||||
# if user is None:
|
||||
# typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
|
||||
# sys.exit(1)
|
||||
|
||||
agent_state = AgentState(
|
||||
name=agent_config["name"],
|
||||
user_id=user.id,
|
||||
persona=agent_config["persona"], # eg 'sam_pov'
|
||||
human=agent_config["human"], # eg 'basic'
|
||||
preset=agent_config["preset"], # eg 'memgpt_chat'
|
||||
state=dict(
|
||||
human=state_dict["memory"]["human"],
|
||||
persona=state_dict["memory"]["persona"],
|
||||
system=state_dict["system"],
|
||||
functions=state_dict["functions"], # this shouldn't matter, since Agent.__init__ will re-link
|
||||
messages=annotate_message_json_list_with_tool_calls(state_dict["messages"]),
|
||||
),
|
||||
llm_config=user.default_llm_config,
|
||||
embedding_config=user.default_embedding_config,
|
||||
)
|
||||
|
||||
# 3. Instantiate a new Agent by passing AgentState to Agent.__init__
|
||||
# NOTE: the Agent.__init__ will trigger a save, which will write to the DB
|
||||
try:
|
||||
agent = Agent(
|
||||
agent_state=agent_state,
|
||||
messages_total=state_dict["messages_total"], # TODO: do we need this?
|
||||
interface=None,
|
||||
)
|
||||
except Exception as e:
|
||||
# if "Agent with name" in str(e):
|
||||
# print(e)
|
||||
# return
|
||||
# elif "was specified in agent.state.functions":
|
||||
# print(e)
|
||||
# return
|
||||
# else:
|
||||
# raise
|
||||
raise
|
||||
|
||||
# Wrap the rest in a try-except so that we can cleanup by deleting the agent if we fail
|
||||
try:
|
||||
## 4. Insert into recall
|
||||
# TODO should this be 'messages', or 'all_messages'?
|
||||
# all_messages in recall will have fields "timestamp" and "message"
|
||||
full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
|
||||
# We want to keep the timestamp
|
||||
for i in range(len(data["all_messages"])):
|
||||
data["all_messages"][i]["message"] = full_message_history_buffer[i]
|
||||
messages_to_insert = [agent.persistence_manager.json_to_message(msg) for msg in data["all_messages"]]
|
||||
agent.persistence_manager.recall_memory.insert_many(messages_to_insert)
|
||||
# print("Finished migrating recall memory")
|
||||
|
||||
# TODO should we also assign data["messages"] to RecallMemory.messages?
|
||||
|
||||
# 5. Insert into archival
|
||||
if os.path.exists(archival_filename):
|
||||
nodes = pickle.load(open(archival_filename, "rb"))
|
||||
passages = []
|
||||
for node in nodes:
|
||||
# print(len(node.embedding))
|
||||
# TODO: make sure embeding size matches embedding config?
|
||||
passages.append(Passage(user_id=user.id, agent_id=agent_state.id, text=node.text, embedding=node.embedding))
|
||||
if len(passages) > 0:
|
||||
agent.persistence_manager.archival_memory.storage.insert_many(passages)
|
||||
print(f"Inserted {len(passages)} passages into archival memory")
|
||||
|
||||
else:
|
||||
print("No archival memory found at", archival_filename)
|
||||
|
||||
except:
|
||||
ms.delete_agent(agent_state.id)
|
||||
raise
|
||||
|
||||
try:
|
||||
new_agent_folder = os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents", agent_name)
|
||||
shutil.move(agent_folder, new_agent_folder)
|
||||
except Exception as e:
|
||||
print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}")
|
||||
raise
|
||||
|
||||
|
||||
# def migrate_all_agents(stop_on_fail=True):
|
||||
def migrate_all_agents(stop_on_fail: bool = False) -> dict:
|
||||
"""Scan over all agent folders in MEMGPT_DIR and migrate each agent."""
|
||||
|
||||
if not os.path.exists(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER)):
|
||||
os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER))
|
||||
os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents"))
|
||||
|
||||
if not config_is_compatible():
|
||||
typer.secho(f"Your current config file is incompatible with MemGPT versions >= {VERSION_CUTOFF}", fg=typer.colors.RED)
|
||||
if questionary.confirm(
|
||||
"To migrate old MemGPT agents, you must delete your config file and run `memgpt configure`. Would you like to proceed?"
|
||||
).ask():
|
||||
try:
|
||||
wipe_config_and_reconfigure()
|
||||
except Exception as e:
|
||||
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
|
||||
raise
|
||||
else:
|
||||
typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED)
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
agents_dir = os.path.join(MEMGPT_DIR, "agents")
|
||||
|
||||
# Ensure the directory exists
|
||||
if not os.path.exists(agents_dir):
|
||||
raise ValueError(f"Directory {agents_dir} does not exist.")
|
||||
|
||||
# Get a list of all folders in agents_dir
|
||||
agent_folders = [f for f in os.listdir(agents_dir) if os.path.isdir(os.path.join(agents_dir, f))]
|
||||
|
||||
# Iterate over each folder with a tqdm progress bar
|
||||
count = 0
|
||||
failures = []
|
||||
candidates = []
|
||||
try:
|
||||
for agent_name in tqdm(agent_folders, desc="Migrating agents"):
|
||||
# Assuming migrate_agent is a function that takes the agent name and performs migration
|
||||
try:
|
||||
if agent_is_migrateable(agent_name=agent_name):
|
||||
candidates.append(agent_name)
|
||||
migrate_agent(agent_name)
|
||||
count += 1
|
||||
else:
|
||||
continue
|
||||
except Exception as e:
|
||||
failures.append({"name": agent_name, "reason": str(e)})
|
||||
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
|
||||
traceback.print_exc()
|
||||
if stop_on_fail:
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
typer.secho(f"User cancelled operation", fg=typer.colors.RED)
|
||||
|
||||
if len(candidates) == 0:
|
||||
typer.secho(f"No migration candidates found ({len(agent_folders)} agent folders total)", fg=typer.colors.GREEN)
|
||||
else:
|
||||
typer.secho(f"Inspected {len(agent_folders)} agent folders")
|
||||
if len(failures) > 0:
|
||||
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
|
||||
for fail in failures:
|
||||
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
|
||||
typer.secho(f"❌ {len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
|
||||
if count > 0:
|
||||
typer.secho(f"✅ {count}/{len(candidates)} agents were successfully migrated to the new database format", fg=typer.colors.GREEN)
|
||||
|
||||
return {
|
||||
"agent_folders": len(agent_folders),
|
||||
"migration_candidates": len(candidates),
|
||||
"successful_migrations": count,
|
||||
"failed_migrations": len(failures),
|
||||
}
|
||||
|
||||
|
||||
def migrate_all_sources(stop_on_fail: bool = False) -> dict:
|
||||
"""Scan over all agent folders in MEMGPT_DIR and migrate each agent."""
|
||||
|
||||
sources_dir = os.path.join(MEMGPT_DIR, "archival")
|
||||
|
||||
# Ensure the directory exists
|
||||
if not os.path.exists(sources_dir):
|
||||
raise ValueError(f"Directory {sources_dir} does not exist.")
|
||||
|
||||
# Get a list of all folders in agents_dir
|
||||
source_folders = [f for f in os.listdir(sources_dir) if os.path.isdir(os.path.join(sources_dir, f))]
|
||||
|
||||
# Iterate over each folder with a tqdm progress bar
|
||||
count = 0
|
||||
failures = []
|
||||
candidates = []
|
||||
try:
|
||||
for source_name in tqdm(source_folders, desc="Migrating data sources"):
|
||||
# Assuming migrate_agent is a function that takes the agent name and performs migration
|
||||
try:
|
||||
candidates.append(source_name)
|
||||
migrate_source(source_name)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
failures.append({"name": source_name, "reason": str(e)})
|
||||
traceback.print_exc()
|
||||
if stop_on_fail:
|
||||
raise
|
||||
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
|
||||
except KeyboardInterrupt:
|
||||
typer.secho(f"User cancelled operation", fg=typer.colors.RED)
|
||||
|
||||
if len(candidates) == 0:
|
||||
typer.secho(f"No migration candidates found ({len(source_folders)} source folders total)", fg=typer.colors.GREEN)
|
||||
else:
|
||||
typer.secho(f"Inspected {len(source_folders)} source folders")
|
||||
if len(failures) > 0:
|
||||
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
|
||||
for fail in failures:
|
||||
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
|
||||
typer.secho(f"❌ {len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
|
||||
if count > 0:
|
||||
typer.secho(f"✅ {count}/{len(candidates)} sources were successfully migrated to the new database format", fg=typer.colors.GREEN)
|
||||
|
||||
return {
|
||||
"source_folders": len(source_folders),
|
||||
"migration_candidates": len(candidates),
|
||||
"successful_migrations": count,
|
||||
"failed_migrations": len(failures),
|
||||
}
|
||||
@@ -10,9 +10,12 @@ from memgpt.data_types import Message, ToolCall, AgentState
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def parse_formatted_time(formatted_time):
|
||||
def parse_formatted_time(formatted_time: str):
|
||||
# parse times returned by memgpt.utils.get_formatted_time()
|
||||
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
try:
|
||||
return datetime.strptime(formatted_time.strip(), "%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
except:
|
||||
return datetime.strptime(formatted_time.strip(), "%Y-%m-%d %I:%M:%S %p")
|
||||
|
||||
|
||||
class PersistenceManager(ABC):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
import copy
|
||||
import re
|
||||
import json
|
||||
import os
|
||||
@@ -6,8 +7,11 @@ import pickle
|
||||
import platform
|
||||
import random
|
||||
import subprocess
|
||||
import uuid
|
||||
import sys
|
||||
import io
|
||||
from typing import List
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from contextlib import contextmanager
|
||||
import difflib
|
||||
@@ -456,6 +460,86 @@ NOUN_BANK = [
|
||||
]
|
||||
|
||||
|
||||
def annotate_message_json_list_with_tool_calls(messages: List[dict]):
|
||||
"""Add in missing tool_call_id fields to a list of messages using function call style
|
||||
|
||||
Walk through the list forwards:
|
||||
- If we encounter an assistant message that calls a function ("function_call") but doesn't have a "tool_call_id" field
|
||||
- Generate the tool_call_id
|
||||
- Then check if the subsequent message is a role == "function" message
|
||||
- If so, then att
|
||||
"""
|
||||
tool_call_index = None
|
||||
tool_call_id = None
|
||||
updated_messages = []
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
if "role" not in message:
|
||||
raise ValueError(f"message missing 'role' field:\n{message}")
|
||||
|
||||
# If we find a function call w/o a tool call ID annotation, annotate it
|
||||
if message["role"] == "assistant" and "function_call" in message:
|
||||
if "tool_call_id" in message and message["tool_call_id"] is not None:
|
||||
printd(f"Message already has tool_call_id")
|
||||
tool_call_id = message["tool_call_id"]
|
||||
else:
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
message["tool_call_id"] = tool_call_id
|
||||
tool_call_index = i
|
||||
|
||||
# After annotating the call, we expect to find a follow-up response (also unannotated)
|
||||
elif message["role"] == "function":
|
||||
# We should have a new tool call id in the buffer
|
||||
if tool_call_id is None:
|
||||
# raise ValueError(
|
||||
print(
|
||||
f"Got a function call role, but did not have a saved tool_call_id ready to use (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
# allow a soft fail in this case
|
||||
message["tool_call_id"] = str(uuid.uuid4())
|
||||
elif "tool_call_id" in message:
|
||||
raise ValueError(
|
||||
f"Got a function call role, but it already had a saved tool_call_id (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
elif i != tool_call_index + 1:
|
||||
raise ValueError(
|
||||
f"Got a function call role, saved tool_call_id came earlier than i-1 (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
else:
|
||||
message["tool_call_id"] = tool_call_id
|
||||
tool_call_id = None # wipe the buffer
|
||||
|
||||
elif message["role"] == "tool":
|
||||
raise NotImplementedError(
|
||||
f"tool_call_id annotation is meant for deprecated functions style, but got role 'tool' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
|
||||
else:
|
||||
# eg role == 'user', nothing to do here
|
||||
pass
|
||||
|
||||
updated_messages.append(copy.deepcopy(message))
|
||||
|
||||
return updated_messages
|
||||
|
||||
|
||||
def version_less_than(version_a: str, version_b: str) -> bool:
|
||||
"""Compare versions to check if version_a is less than version_b."""
|
||||
# Regular expression to match version strings of the format int.int.int
|
||||
version_pattern = re.compile(r"^\d+\.\d+\.\d+$")
|
||||
|
||||
# Assert that version strings match the required format
|
||||
if not version_pattern.match(version_a) or not version_pattern.match(version_b):
|
||||
raise ValueError("Version strings must be in the format 'int.int.int'")
|
||||
|
||||
# Split the version strings into parts
|
||||
parts_a = [int(part) for part in version_a.split(".")]
|
||||
parts_b = [int(part) for part in version_b.split(".")]
|
||||
|
||||
# Compare version parts
|
||||
return parts_a < parts_b
|
||||
|
||||
|
||||
def create_random_username() -> str:
|
||||
"""Generate a random username by combining an adjective and a noun."""
|
||||
adjective = random.choice(ADJECTIVE_BANK).capitalize()
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
20
tests/data/memgpt-0.2.11/agents/agent_test/config.json
Normal file
20
tests/data/memgpt-0.2.11/agents/agent_test/config.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"name": "agent_test",
|
||||
"persona": "sam_pov",
|
||||
"human": "basic",
|
||||
"preset": "memgpt_chat",
|
||||
"context_window": 8192,
|
||||
"model": "gpt-4",
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://api.openai.com/v1",
|
||||
"model_wrapper": null,
|
||||
"embedding_endpoint_type": "openai",
|
||||
"embedding_endpoint": "https://api.openai.com/v1",
|
||||
"embedding_model": null,
|
||||
"embedding_dim": 1536,
|
||||
"embedding_chunk_size": 300,
|
||||
"data_sources": [],
|
||||
"create_time": "2024-01-11 12:42:25 PM",
|
||||
"memgpt_version": "0.2.11",
|
||||
"agent_config_path": "/Users/sarahwooders/.memgpt/agents/agent_test/config.json"
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"name": "agent_test_attach",
|
||||
"persona": "sam_pov",
|
||||
"human": "basic",
|
||||
"preset": "memgpt_chat",
|
||||
"context_window": 8192,
|
||||
"model": "gpt-4",
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://api.openai.com/v1",
|
||||
"model_wrapper": null,
|
||||
"embedding_endpoint_type": "openai",
|
||||
"embedding_endpoint": "https://api.openai.com/v1",
|
||||
"embedding_model": null,
|
||||
"embedding_dim": 1536,
|
||||
"embedding_chunk_size": 300,
|
||||
"data_sources": [
|
||||
"test"
|
||||
],
|
||||
"create_time": "2024-01-11 12:41:37 PM",
|
||||
"memgpt_version": "0.2.11",
|
||||
"agent_config_path": "/Users/sarahwooders/.memgpt/agents/agent_test_attach/config.json"
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"name": "agent_test_empty_archival",
|
||||
"persona": "sam_pov",
|
||||
"human": "basic",
|
||||
"preset": "memgpt_chat",
|
||||
"context_window": 8192,
|
||||
"model": "gpt-4",
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://api.openai.com/v1",
|
||||
"model_wrapper": null,
|
||||
"embedding_endpoint_type": "openai",
|
||||
"embedding_endpoint": "https://api.openai.com/v1",
|
||||
"embedding_model": null,
|
||||
"embedding_dim": 1536,
|
||||
"embedding_chunk_size": 300,
|
||||
"data_sources": [],
|
||||
"create_time": "2024-01-11 12:44:07 PM",
|
||||
"memgpt_version": "0.2.11",
|
||||
"agent_config_path": "/Users/sarahwooders/.memgpt/agents/agent_test_empty_archival/config.json"
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
||||
<EFBFBD>]<5D>.
|
||||
BIN
tests/data/memgpt-0.2.11/archival/test/nodes.pkl
Normal file
BIN
tests/data/memgpt-0.2.11/archival/test/nodes.pkl
Normal file
Binary file not shown.
29
tests/data/memgpt-0.2.11/config
Normal file
29
tests/data/memgpt-0.2.11/config
Normal file
@@ -0,0 +1,29 @@
|
||||
[defaults]
|
||||
preset = memgpt_chat
|
||||
persona = sam_pov
|
||||
human = basic
|
||||
|
||||
[model]
|
||||
model = gpt-4
|
||||
model_endpoint = https://api.openai.com/v1
|
||||
model_endpoint_type = openai
|
||||
context_window = 8192
|
||||
|
||||
[openai]
|
||||
key = FAKE_KEY
|
||||
|
||||
[embedding]
|
||||
embedding_endpoint_type = openai
|
||||
embedding_endpoint = https://api.openai.com/v1
|
||||
embedding_dim = 1536
|
||||
embedding_chunk_size = 300
|
||||
|
||||
[archival_storage]
|
||||
type = local
|
||||
|
||||
[version]
|
||||
memgpt_version = 0.2.11
|
||||
|
||||
[client]
|
||||
anon_clientid = 00000000000000000000d67f40108c5c
|
||||
|
||||
Reference in New Issue
Block a user