From 0e935d3ebdd7867d49089cba6d3f2efaa4748df2 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Mon, 11 Dec 2023 16:59:21 -0800 Subject: [PATCH] Add more compehensive tests, make row ids be strings (not integers) --- memgpt/cli/cli_config.py | 30 +- memgpt/config.py | 47 +-- memgpt/connectors/chroma.py | 65 ++-- memgpt/connectors/db.py | 7 +- memgpt/connectors/storage.py | 52 +-- memgpt/data_types.py | 6 +- memgpt/embeddings.py | 8 +- memgpt/main.py | 9 +- tests/test_storage.py | 664 ++++++++++++++++++++++------------- 9 files changed, 519 insertions(+), 369 deletions(-) diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index f4df71b0..47073292 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -10,7 +10,7 @@ from memgpt import utils from memgpt.config import MemGPTConfig, AgentConfig from memgpt.constants import MEMGPT_DIR -from memgpt.connectors.storage import StorageConnector +from memgpt.connectors.storage import StorageConnector, TableType from memgpt.constants import LLM_MAX_TOKENS from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME from memgpt.local_llm.utils import get_available_wrappers @@ -601,3 +601,31 @@ def add( # write text to file with open(os.path.join(directory, name), "w") as f: f.write(text) + + +@app.command() +def delete( + option: str, + name: str = typer.Option(help="Name of human/persona/agent/source to delete"), +): + if option == "agent": + # delete state/config + # TODO: this will eventually need to go through the storage connector + agent_config = AgentConfig.load(name) + # remove directory + shutil.rmtree(agent_config.save_dir()) + + # delete memory + recall_storage = StorageConnector.get_recall_storage_connector(agent_config) + recall_storage.delete() + archival_storage = StorageConnector.get_archival_storage_connector(agent_config) + archival_storage.delete() + + elif option == "source": + # TODO: also delete document store + # TODO: remove data from any agents that have loaded it in (?) + storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES) + storage.delete({"data_source": name}) + + else: + raise NotImplementedError diff --git a/memgpt/config.py b/memgpt/config.py index d220958f..e6342095 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -303,50 +303,6 @@ class AgentConfig: os.path.join(MEMGPT_DIR, "agents", self.name, "config.json") if agent_config_path is None else agent_config_path ) - def link_functions(self, function_schemas): - - # need to dynamically link the functions - # the saved agent.functions will just have the schemas, but we need to - # go through the functions library and pull the respective python functions - - # Available functions is a mapping from: - # function_name -> { - # json_schema: schema - # python_function: function - # } - # agent.functions is a list of schemas (OpenAI kwarg functions style, see: https://platform.openai.com/docs/api-reference/chat/create) - # [{'name': ..., 'description': ...}, {...}] - available_functions = load_all_function_sets() - linked_function_set = {} - for f_schema in function_schemas: - # Attempt to find the function in the existing function library - f_name = f_schema.get("name") - if f_name is None: - raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{f_schema}") - linked_function = available_functions.get(f_name) - if linked_function is None: - raise ValueError( - f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}" - ) - # Once we find a matching function, make sure the schema is identical - if json.dumps(f_schema) != json.dumps(linked_function["json_schema"]): - # error_message = ( - # f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different." - # + f"\n>>>agent.state.functions\n{json.dumps(f_schema, indent=2)}" - # + f"\n>>>function library\n{json.dumps(linked_function['json_schema'], indent=2)}" - # ) - schema_diff = get_schema_diff(f_schema, linked_function["json_schema"]) - error_message = ( - f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different.\n" - + "".join(schema_diff) - ) - - # NOTE to handle old configs, instead of erroring here let's just warn - # raise ValueError(error_message) - utils.printd(error_message) - linked_function_set[f_name] = linked_function - return linked_function_set - def generate_agent_id(self, length=6): ## random character based # characters = string.ascii_lowercase + string.digits @@ -362,6 +318,9 @@ class AgentConfig: self.data_sources.append(data_source) self.save() + def save_dir(self): + return os.path.join(MEMGPT_DIR, "agents", self.name) + def save_state_dir(self): # directory to save agent state return os.path.join(MEMGPT_DIR, "agents", self.name, "agent_state") diff --git a/memgpt/connectors/chroma.py b/memgpt/connectors/chroma.py index a2e7a43f..74b32547 100644 --- a/memgpt/connectors/chroma.py +++ b/memgpt/connectors/chroma.py @@ -29,21 +29,37 @@ class ChromaStorageConnector(StorageConnector): # get a collection or create if it doesn't exist already self.collection = self.client.get_or_create_collection(self.table_name) - self.include = ["id", "documents", "embeddings", "metadatas"] + self.include = ["documents", "embeddings", "metadatas"] + + def get_filters(self, filters: Optional[Dict] = {}): + # get all filters for query + if filters is not None: + filter_conditions = {**self.filters, **filters} + else: + filter_conditions = self.filters + + # convert to chroma format + chroma_filters = {"$and": []} + for key, value in filter_conditions.items(): + chroma_filters["$and"].append({key: {"$eq": value}}) + return chroma_filters def get_all_paginated(self, page_size: int, filters: Optional[Dict]) -> Iterator[List[Record]]: offset = 0 filters = self.get_filters(filters) + print(filters) while True: # Retrieve a chunk of records with the given page_size - db_chunks = self.collection.get(offset=offset, limit=page_size, include=self.include, where=filters) + print("querying...", self.collection.count(), offset, page_size) + results = self.collection.get(offset=offset, limit=page_size, include=self.include, where=filters) + print(len(results["embeddings"])) # If the chunk is empty, we've retrieved all records - if not db_chunks: + if len(results["embeddings"]) == 0: break # Yield a list of Record objects converted from the chunk - yield self.results_to_records(db_chunks) + yield self.results_to_records(results) # Increment the offset to get the next chunk in the next iteration offset += page_size @@ -54,8 +70,8 @@ class ChromaStorageConnector(StorageConnector): if "created_at" in metadata: metadata["created_at"] = timestamp_to_datetime(metadata["created_at"]) return [ - self.type(id=id, text=text, embedding=embedding, **metadatas) - for (id, text, embedding, metadatas) in zip(results["ids"], results["documents"], results["embeddings"], results["metadatas"]) + self.type(text=text, embedding=embedding, **metadatas) + for (text, embedding, metadatas) in zip(results["documents"], results["embeddings"], results["metadatas"]) ] def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]: @@ -68,26 +84,12 @@ class ChromaStorageConnector(StorageConnector): results = self.collection.get(ids=[id]) return self.results_to_records(results) - def insert(self, record: Record): - if record.id is None: - record.id = str(self.collection.count()) - metadata = vars(record) - metadata.pop("id") - metadata.pop("text") - metadata.pop("embedding") - self.collection.add(documents=[record.text], embeddings=[record.embedding], ids=[record.id], metadatas=[metadata]) - - def insert_many(self, records: List[Record], show_progress=True): - count = self.collection.count() + def format_records(self, records: List[Record]): metadatas = [] - ids = [] + ids = [str(record.id) for record in records] documents = [record.text for record in records] embeddings = [record.embedding for record in records] for record in records: - if record.id is None: - count += 1 - ids.append(str(count)) - # TODO: ensure that other record.ids dont match metadata = vars(record) metadata.pop("id") metadata.pop("text") @@ -96,6 +98,17 @@ class ChromaStorageConnector(StorageConnector): metadata["created_at"] = datetime_to_timestamp(metadata["created_at"]) metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed metadatas.append(metadata) + return ids, documents, embeddings, metadatas + + def insert(self, record: Record): + ids, documents, embeddings, metadatas = self.format_records([record]) + if not any(embeddings): + self.collection.add(documents=documents, ids=ids, metadatas=metadatas) + else: + self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas) + + def insert_many(self, records: List[Record], show_progress=True): + ids, documents, embeddings, metadatas = self.format_records(records) if not any(embeddings): self.collection.add(documents=documents, ids=ids, metadatas=metadatas) else: @@ -110,8 +123,11 @@ class ChromaStorageConnector(StorageConnector): pass def size(self, filters: Optional[Dict] = {}) -> int: - filters = self.get_filters(filters) - return self.collection.count() + # unfortuantely, need to use pagination to get filtering + count = 0 + for records in self.get_all_paginated(page_size=100, filters=filters): + count += len(records) + return count def list_data_sources(self): raise NotImplementedError @@ -123,6 +139,7 @@ class ChromaStorageConnector(StorageConnector): def query_date(self, start_date, end_date, start=None, count=None): # TODO: no idea if this is correct + # TODO: convert start/end_date into timestamp filters = self.get_filters(filters) filters["created_at"] = { "$gte": start_date, diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 7bb70113..5134c803 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -8,8 +8,9 @@ from sqlalchemy.orm import sessionmaker, mapped_column from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.sql import func from sqlalchemy import Column, BIGINT, String, DateTime -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy_json import mutable_json_type +import uuid import re from tqdm import tqdm @@ -41,7 +42,7 @@ def get_db_model(table_name: str, table_type: TableType): __abstract__ = True # this line is necessary # Assuming passage_id is the primary key - id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) user_id = Column(String, nullable=False) text = Column(String, nullable=False) doc_id = Column(String) @@ -77,7 +78,7 @@ def get_db_model(table_name: str, table_type: TableType): __abstract__ = True # this line is necessary # Assuming message_id is the primary key - id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) user_id = Column(String, nullable=False) agent_id = Column(String, nullable=False) role = Column(String, nullable=False) diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py index 20b790ad..9367a562 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/connectors/storage.py @@ -70,7 +70,6 @@ class StorageConnector: filter_conditions = self.filters print("FILTERS", filter_conditions) return filter_conditions - return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()] def generate_table_name(self, agent_config: AgentConfig, table_type: TableType): @@ -99,59 +98,42 @@ class StorageConnector: raise ValueError(f"Table type {table_type} not implemented") @staticmethod - def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None): - storage_type = MemGPTConfig.load().archival_storage_type + def get_storage_connector(table_type: TableType, storage_type: Optional[str] = None, agent_config: Optional[AgentConfig] = None): - if storage_type == "local": - from memgpt.connectors.local import VectorIndexStorageConnector + # read from config if not provided + if storage_type is None: + storage_type = MemGPTConfig.load().archival_storage_type - return VectorIndexStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) - - elif storage_type == "postgres": + if storage_type == "postgres": from memgpt.connectors.db import PostgresStorageConnector - return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) + return PostgresStorageConnector(agent_config=agent_config, table_type=table_type) elif storage_type == "chroma": from memgpt.connectors.chroma import ChromaStorageConnector - return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) + return ChromaStorageConnector(agent_config=agent_config, table_type=table_type) elif storage_type == "lancedb": from memgpt.connectors.db import LanceDBConnector - return LanceDBConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) + return LanceDBConnector(agent_config=agent_config, table_type=table_type) else: raise NotImplementedError(f"Storage type {storage_type} not implemented") + @staticmethod + def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None): + return StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, agent_config=agent_config) + @staticmethod def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None): - storage_type = MemGPTConfig.load().recall_storage_type - - print("Recall storage type", storage_type) - - if storage_type == "local": - from memgpt.connectors.local import InMemoryStorageConnector - - # maintains in-memory list for storage - return InMemoryStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY) - - elif storage_type == "postgres": - from memgpt.connectors.db import PostgresStorageConnector - - return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY) - - elif storage_type == "chroma": - from memgpt.connectors.chroma import ChromaStorageConnector - - return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY) - - else: - raise NotImplementedError(f"Storage type {storage_type} not implemented") + return StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, agent_config=agent_config) @staticmethod - def list_loaded_data(): + def list_loaded_data(storage_type: Optional[str] = None): # TODO: modify this to simply list loaded data from a given user - storage_type = MemGPTConfig.load().archival_storage_type + if storage_type is None: + storage_type = MemGPTConfig.load().archival_storage_type + if storage_type == "local": from memgpt.connectors.local import VectorIndexStorageConnector diff --git a/memgpt/data_types.py b/memgpt/data_types.py index fd6ab5c6..59cd2416 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -1,4 +1,5 @@ """ This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """ +import uuid from abc import abstractmethod from typing import Optional import numpy as np @@ -18,7 +19,10 @@ class Record: self.user_id = user_id self.agent_id = agent_id self.text = text - self.id = id + if id is None: + self.id = uuid.uuid4() + else: + self.id = id # todo: generate unique uuid # todo: self.role = role (?) diff --git a/memgpt/embeddings.py b/memgpt/embeddings.py index fc3023d4..735ec3f7 100644 --- a/memgpt/embeddings.py +++ b/memgpt/embeddings.py @@ -97,14 +97,14 @@ def embedding_model(): # load config config = MemGPTConfig.load() + endpoint_type = config.embedding_endpoint_type - endpoint = config.embedding_endpoint_type - if endpoint == "openai": + if endpoint_type == "openai": model = OpenAIEmbedding( api_base=config.embedding_endpoint, api_key=config.openai_key, additional_kwargs={"user": config.anon_clientid} ) return model - elif endpoint == "azure": + elif endpoint_type == "azure": # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings model = "text-embedding-ada-002" deployment = config.azure_embedding_deployment if config.azure_embedding_deployment is not None else model @@ -115,7 +115,7 @@ def embedding_model(): azure_endpoint=config.azure_endpoint, api_version=config.azure_version, ) - elif endpoint == "hugging-face": + elif endpoint_type == "hugging-face": embed_model = EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=config.anon_clientid) return embed_model else: diff --git a/memgpt/main.py b/memgpt/main.py index 608edf3c..f8f94d3c 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -21,9 +21,8 @@ from memgpt.interface import CLIInterface as interface # for printing to termin import memgpt.agent as agent import memgpt.system as system import memgpt.constants as constants -import memgpt.errors as errors -from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart -from memgpt.cli.cli_config import configure, list, add +from memgpt.cli.cli import run, attach, version +from memgpt.cli.cli_config import configure, list, add, delete from memgpt.cli.cli_load import app as load_app from memgpt.connectors.storage import StorageConnector @@ -34,9 +33,7 @@ app.command(name="attach")(attach) app.command(name="configure")(configure) app.command(name="list")(list) app.command(name="add")(add) -app.command(name="server")(server) -app.command(name="folder")(open_folder) -app.command(name="quickstart")(quickstart) +app.command(name="delete")(delete) # load data commands app.add_typer(load_app, name="load") diff --git a/tests/test_storage.py b/tests/test_storage.py index 44d3b8fb..49c2ca91 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -17,280 +17,442 @@ from memgpt.embeddings import embedding_model from memgpt.data_types import Message, Passage from memgpt.config import MemGPTConfig, AgentConfig from memgpt.utils import get_local_time +from memgpt.connectors.storage import StorageConnector, TableType +from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN import argparse from datetime import datetime, timedelta -def test_recall_db(): - # os.environ["MEMGPT_CONFIG_PATH"] = "./config" - - storage_type = "postgres" - storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config = MemGPTConfig( - recall_storage_type=storage_type, - recall_storage_uri=storage_uri, - model_endpoint_type="openai", - model_endpoint="https://api.openai.com/v1", - model="gpt4", - ) - print(config.config_path) - assert config.recall_storage_uri is not None - config.save() - print(config) - - agent_config = AgentConfig( - persona=config.persona, - human=config.human, - model=config.model, - ) - - conn = StorageConnector.get_recall_storage_connector(agent_config) - - # construct recall memory messages - message1 = Message( - agent_id=agent_config.name, - role="agent", - text="This is a test message", - user_id=config.anon_clientid, - model=agent_config.model, - created_at=datetime.now(), - ) - message2 = Message( - agent_id=agent_config.name, - role="user", - text="This is a test message", - user_id=config.anon_clientid, - model=agent_config.model, - created_at=datetime.now(), - ) - print(vars(message1)) - - # test insert - conn.insert(message1) - conn.insert_many([message2]) - - # test size - assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}" - assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}' - - # test text query - res = conn.query_text("test") - print(res) - assert len(res) >= 2, f"Expected 2 messages, got {len(res)}" - - # test date query - current_time = datetime.now() - ten_weeks_ago = current_time - timedelta(weeks=1) - res = conn.query_date(start_date=ten_weeks_ago, end_date=current_time) - print(res) - assert len(res) >= 2, f"Expected 2 messages, got {len(res)}" - - print(conn.get_all()) +texts = ["This is a test passage", "This is another test passage", "Cinderella wept"] +start_date = datetime(2009, 10, 5, 18, 00) +dates = [start_date - timedelta(weeks=1), start_date, start_date + timedelta(weeks=1)] +roles = ["user", "agent", "user"] +agent_ids = ["agent1", "agent2", "agent1"] +ids = ["test1", "test2", "test3"] # TODO: generate unique uuid -@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key") -def test_postgres_openai(): - if not os.getenv("PGVECTOR_TEST_DB_URL"): - return # soft pass - if not os.getenv("OPENAI_API_KEY"): - return # soft pass - - # os.environ["MEMGPT_CONFIG_PATH"] = "./config" - config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL")) - print(config.config_path) - assert config.archival_storage_uri is not None - config.archival_storage_uri = config.archival_storage_uri.replace( - "postgres://", "postgresql://" - ) # https://stackoverflow.com/a/64698899 - config.save() - print(config) - - embed_model = embedding_model() - - passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] - - agent_config = AgentConfig( - name="test_agent", - persona=config.persona, - human=config.human, - model=config.model, - ) - - db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) - - # db.delete() - # return - for passage in passage: - db.insert( +def generate_passages(embed_model): + """Generate list of 3 Passage objects""" + # embeddings: use openai if env is set, otherwise local + passages = [] + for (text, _, _, agent_id, id) in zip(texts, dates, roles, agent_ids, ids): + embedding = None + if embed_model: + embedding = embed_model.get_text_embedding(text) + passages.append( Passage( - text=passage, - embedding=embed_model.get_text_embedding(passage), - user_id=config.anon_clientid, - agent_id="test_agent", - data_source="test", - metadata={"test_metadata_key": "test_metadata_value"}, + user_id="test", + text=text, + agent_id=agent_id, + embedding=embedding, + data_source="test_source", ) ) - - print(db.get_all()) - - query = "why was she crying" - query_vec = embed_model.get_text_embedding(query) - res = db.query(None, query_vec, top_k=2) - - assert len(res) == 2, f"Expected 2 results, got {len(res)}" - assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" - - # TODO fix (causes a hang for some reason) - # print("deleting...") - # db.delete() - # print("...finished") + return passages -@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key") -def test_chroma_openai(): - if not os.getenv("OPENAI_API_KEY"): - return # soft pass +def generate_messages(): + """Generate list of 3 Message objects""" + messages = [] + for (text, date, role, agent_id, id) in zip(texts, dates, roles, agent_ids, ids): + messages.append(Message(user_id="test", text=text, agent_id=agent_id, role=role, created_at=date, id=id, model="gpt4")) + print(messages[-1].text) + return messages - config = MemGPTConfig( - archival_storage_type="chroma", - archival_storage_path="./test_chroma", - embedding_endpoint_type="openai", - embedding_dim=1536, - model="gpt4", - model_endpoint_type="openai", - model_endpoint="https://api.openai.com/v1", - ) + +@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"]) +@pytest.mark.parametrize("table_type", [TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY]) +def test_storage(storage_connector, table_type): + + # setup memgpt config + # TODO: set env for different config path + config = MemGPTConfig() + if storage_connector == "postgres": + if not os.getenv("PGVECTOR_TEST_DB_URL"): + print("Skipping test, missing PG URI") + return + config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + config.archival_storage_type = "postgres" + config.recall_storage_type = "postgres" + if storage_connector == "lancedb": + if not os.getenv("LANCEDB_TEST_URL"): + print("Skipping test, missing LanceDB URI") + return + config.archival_storage_uri = os.getenv("LANCEDB_TEST_URL") + config.recall_storage_uri = os.getenv("LANCEDB_TEST_URL") + config.archival_storage_type = "lancedb" + config.recall_storage_type = "lancedb" + if storage_connector == "chroma": + config.archival_storage_type = "chroma" + config.recall_storage_type = "chroma" + config.recall_storage_path = "./test_chroma" + config.archival_storage_path = "./test_chroma" + + # get embedding model + embed_model = None + if os.getenv("OPENAI_API_KEY"): + config.embedding_endpoint_type = "openai" + config.embedding_endpoint = "https://api.openai.com/v1" + config.embedding_dim = 1536 + config.openai_key = os.getenv("OPENAI_API_KEY") + else: + config.embedding_endpoint_type = "local" + config.embedding_endpoint = None + config.embedding_dim = 384 config.save() - embed_model = embedding_model() - passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] - - db = ChromaStorageConnector(name="test-openai") - - for passage in passage: - db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) - - query = "why was she crying" - query_vec = embed_model.get_text_embedding(query) - res = db.query(query, query_vec, top_k=2) - - assert len(res) == 2, f"Expected 2 results, got {len(res)}" - assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" - - print(res[0].text) - - print("deleting") - db.delete() - - -@pytest.mark.skipif( - not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key" -) -def test_lancedb_openai(): - assert os.getenv("LANCEDB_TEST_URL") is not None - if os.getenv("OPENAI_API_KEY") is None: - return # soft pass - - config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL")) - print(config.config_path) - assert config.archival_storage_uri is not None - print(config) - - embed_model = embedding_model() - - passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] - - db = LanceDBConnector(name="test-openai") - - for passage in passage: - db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) - - print(db.get_all()) - - query = "why was she crying" - query_vec = embed_model.get_text_embedding(query) - res = db.query(None, query_vec, top_k=2) - - assert len(res) == 2, f"Expected 2 results, got {len(res)}" - assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" - - -@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI") -def test_postgres_local(): - if not os.getenv("PGVECTOR_TEST_DB_URL"): - return - # os.environ["MEMGPT_CONFIG_PATH"] = "./config" - - config = MemGPTConfig( - archival_storage_type="postgres", - archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), - embedding_endpoint_type="local", - embedding_dim=384, # use HF model + # create agent + agent_config = AgentConfig( + persona=DEFAULT_PERSONA, + human=DEFAULT_HUMAN, + model=DEFAULT_MEMGPT_MODEL, ) - print(config.config_path) - assert config.archival_storage_uri is not None - config.archival_storage_uri = config.archival_storage_uri.replace( - "postgres://", "postgresql://" - ) # https://stackoverflow.com/a/64698899 - config.save() - print(config) - embed_model = embedding_model() + # create storage connector + conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config) - passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] + # generate data + if table_type == TableType.ARCHIVAL_MEMORY: + records = generate_passages(embed_model) + elif table_type == TableType.RECALL_MEMORY: + records = generate_messages() + else: + raise NotImplementedError(f"Table type {table_type} not implemented") - db = PostgresStorageConnector(name="test-local") + # test: insert + conn.insert(records[0]) + assert conn.size() == 1, f"Expected 1 record, got {conn.size()}" - for passage in passage: - db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) + # test: insert_many + conn.insert_many(records[1:]) + assert conn.size() == 3, f"Expected 1 record, got {conn.size()}" - print(db.get_all()) + # test: list_loaded_data + if table_type == TableType.ARCHIVAL_MEMORY: + sources = StorageConnector.list_loaded_data(storage_type=storage_connector) + assert len(sources) == 1, f"Expected 1 source, got {len(sources)}" + assert sources[0] == "test_source", f"Expected 'test_source', got {sources[0]}" - query = "why was she crying" - query_vec = embed_model.get_text_embedding(query) - res = db.query(None, query_vec, top_k=2) + # test: get_all_paginated + paginated_total = 0 + for page in conn.get_all_paginated(page_size=1): + paginated_total += len(page) + assert paginated_total == 3, f"Expected 3 records, got {paginated_total}" - assert len(res) == 2, f"Expected 2 results, got {len(res)}" - assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" + # test: get_all + all_records = conn.get_all() + assert len(all_records) == 3, f"Expected 3 records, got {len(all_records)}" + all_records = conn.get_all(limit=2) + assert len(all_records) == 2, f"Expected 2 records, got {len(all_records)}" - # TODO fix (causes a hang for some reason) - # print("deleting...") - # db.delete() - # print("...finished") + # test: get + res = conn.get(id=ids[0]) + assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}" + + # test: size + assert conn.size() == 3, f"Expected 3 records, got {conn.size()}" + assert conn.size(filters={"agent_id", "agent1"}) == 2, f"Expected 2 records, got {conn.size(filters={'agent_id', 'agent1'})}" + if table_type == TableType.RECALL_MEMORY: + assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}" + + # test: query (vector) + if embed_model: + query = "why was she crying" + query_vec = embed_model.get_text_embedding(query) + res = conn.query(None, query_vec, top_k=2) + assert len(res) == 2, f"Expected 2 results, got {len(res)}" + assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" + + # test: query_text + query = "CindereLLa" + res = conn.query_text(query) + assert len(res) == 1, f"Expected 1 result, got {len(res)}" + assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}" + + # test: query_date (recall memory only) + if table_type == TableType.RECALL_MEMORY: + print("Testing recall memory date search") + start_date = start_date - timedelta(days=1) + end_date = start_date + timedelta(days=1) + res = conn.query_date(start_date=start_date, end_date=end_date) + assert len(res) == 1, f"Expected 1 result, got {len(res): {res}}" + + # test: delete + conn.delete({"id": ids[0]}) + assert conn.size() == 2, f"Expected 2 records, got {conn.size()}" + conn.delete() + assert conn.size() == 0, f"Expected 0 records, got {conn.size()}" -@pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI") -def test_lancedb_local(): - assert os.getenv("LANCEDB_TEST_URL") is not None - - config = MemGPTConfig( - archival_storage_type="lancedb", - archival_storage_uri=os.getenv("LANCEDB_TEST_URL"), - embedding_model="local", - embedding_dim=384, # use HF model - ) - print(config.config_path) - assert config.archival_storage_uri is not None - - embed_model = embedding_model() - - passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] - - db = LanceDBConnector(name="test-local") - - for passage in passage: - db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) - - print(db.get_all()) - - query = "why was she crying" - query_vec = embed_model.get_text_embedding(query) - res = db.query(None, query_vec, top_k=2) - - assert len(res) == 2, f"Expected 2 results, got {len(res)}" - assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" - - -test_recall_db() +# def test_recall_db(): +# # os.environ["MEMGPT_CONFIG_PATH"] = "./config" +# +# storage_type = "postgres" +# storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") +# config = MemGPTConfig( +# recall_storage_type=storage_type, +# recall_storage_uri=storage_uri, +# model_endpoint_type="openai", +# model_endpoint="https://api.openai.com/v1", +# model="gpt4", +# ) +# print(config.config_path) +# assert config.recall_storage_uri is not None +# config.save() +# print(config) +# +# agent_config = AgentConfig( +# persona=config.persona, +# human=config.human, +# model=config.model, +# ) +# +# conn = StorageConnector.get_recall_storage_connector(agent_config) +# +# # construct recall memory messages +# message1 = Message( +# agent_id=agent_config.name, +# role="agent", +# text="This is a test message", +# user_id=config.anon_clientid, +# model=agent_config.model, +# created_at=datetime.now(), +# ) +# message2 = Message( +# agent_id=agent_config.name, +# role="user", +# text="This is a test message", +# user_id=config.anon_clientid, +# model=agent_config.model, +# created_at=datetime.now(), +# ) +# print(vars(message1)) +# +# # test insert +# conn.insert(message1) +# conn.insert_many([message2]) +# +# # test size +# assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}" +# assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}' +# +# # test text query +# res = conn.query_text("test") +# print(res) +# assert len(res) >= 2, f"Expected 2 messages, got {len(res)}" +# +# # test date query +# current_time = datetime.now() +# ten_weeks_ago = current_time - timedelta(weeks=1) +# res = conn.query_date(start_date=ten_weeks_ago, end_date=current_time) +# print(res) +# assert len(res) >= 2, f"Expected 2 messages, got {len(res)}" +# +# print(conn.get_all()) +# +# +# @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key") +# def test_postgres_openai(): +# if not os.getenv("PGVECTOR_TEST_DB_URL"): +# return # soft pass +# if not os.getenv("OPENAI_API_KEY"): +# return # soft pass +# +# # os.environ["MEMGPT_CONFIG_PATH"] = "./config" +# config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL")) +# print(config.config_path) +# assert config.archival_storage_uri is not None +# config.archival_storage_uri = config.archival_storage_uri.replace( +# "postgres://", "postgresql://" +# ) # https://stackoverflow.com/a/64698899 +# config.save() +# print(config) +# +# embed_model = embedding_model() +# +# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] +# +# agent_config = AgentConfig( +# name="test_agent", +# persona=config.persona, +# human=config.human, +# model=config.model, +# ) +# +# db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) +# +# # db.delete() +# # return +# for passage in passage: +# db.insert( +# Passage( +# text=passage, +# embedding=embed_model.get_text_embedding(passage), +# user_id=config.anon_clientid, +# agent_id="test_agent", +# data_source="test", +# metadata={"test_metadata_key": "test_metadata_value"}, +# ) +# ) +# +# print(db.get_all()) +# +# query = "why was she crying" +# query_vec = embed_model.get_text_embedding(query) +# res = db.query(None, query_vec, top_k=2) +# +# assert len(res) == 2, f"Expected 2 results, got {len(res)}" +# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" +# +# # TODO fix (causes a hang for some reason) +# # print("deleting...") +# # db.delete() +# # print("...finished") +# +# +# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key") +# def test_chroma_openai(): +# if not os.getenv("OPENAI_API_KEY"): +# return # soft pass +# +# config = MemGPTConfig( +# archival_storage_type="chroma", +# archival_storage_path="./test_chroma", +# embedding_endpoint_type="openai", +# embedding_dim=1536, +# model="gpt4", +# model_endpoint_type="openai", +# model_endpoint="https://api.openai.com/v1", +# ) +# config.save() +# embed_model = embedding_model() +# +# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] +# +# db = ChromaStorageConnector(name="test-openai") +# +# for passage in passage: +# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) +# +# query = "why was she crying" +# query_vec = embed_model.get_text_embedding(query) +# res = db.query(query, query_vec, top_k=2) +# +# assert len(res) == 2, f"Expected 2 results, got {len(res)}" +# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" +# +# print(res[0].text) +# +# print("deleting") +# db.delete() +# +# +# @pytest.mark.skipif( +# not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key" +# ) +# def test_lancedb_openai(): +# assert os.getenv("LANCEDB_TEST_URL") is not None +# if os.getenv("OPENAI_API_KEY") is None: +# return # soft pass +# +# config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL")) +# print(config.config_path) +# assert config.archival_storage_uri is not None +# print(config) +# +# embed_model = embedding_model() +# +# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] +# +# db = LanceDBConnector(name="test-openai") +# +# for passage in passage: +# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) +# +# print(db.get_all()) +# +# query = "why was she crying" +# query_vec = embed_model.get_text_embedding(query) +# res = db.query(None, query_vec, top_k=2) +# +# assert len(res) == 2, f"Expected 2 results, got {len(res)}" +# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" +# +# +# @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI") +# def test_postgres_local(): +# if not os.getenv("PGVECTOR_TEST_DB_URL"): +# return +# # os.environ["MEMGPT_CONFIG_PATH"] = "./config" +# +# config = MemGPTConfig( +# archival_storage_type="postgres", +# archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), +# embedding_endpoint_type="local", +# embedding_dim=384, # use HF model +# ) +# print(config.config_path) +# assert config.archival_storage_uri is not None +# config.archival_storage_uri = config.archival_storage_uri.replace( +# "postgres://", "postgresql://" +# ) # https://stackoverflow.com/a/64698899 +# config.save() +# print(config) +# +# embed_model = embedding_model() +# +# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] +# +# db = PostgresStorageConnector(name="test-local") +# +# for passage in passage: +# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) +# +# print(db.get_all()) +# +# query = "why was she crying" +# query_vec = embed_model.get_text_embedding(query) +# res = db.query(None, query_vec, top_k=2) +# +# assert len(res) == 2, f"Expected 2 results, got {len(res)}" +# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" +# +# # TODO fix (causes a hang for some reason) +# # print("deleting...") +# # db.delete() +# # print("...finished") +# +# +# @pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI") +# def test_lancedb_local(): +# assert os.getenv("LANCEDB_TEST_URL") is not None +# +# config = MemGPTConfig( +# archival_storage_type="lancedb", +# archival_storage_uri=os.getenv("LANCEDB_TEST_URL"), +# embedding_model="local", +# embedding_dim=384, # use HF model +# ) +# print(config.config_path) +# assert config.archival_storage_uri is not None +# +# embed_model = embedding_model() +# +# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] +# +# db = LanceDBConnector(name="test-local") +# +# for passage in passage: +# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) +# +# print(db.get_all()) +# +# query = "why was she crying" +# query_vec = embed_model.get_text_embedding(query) +# res = db.query(None, query_vec, top_k=2) +# +# assert len(res) == 2, f"Expected 2 results, got {len(res)}" +# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" +#