Add more comprehensive tests, make row IDs be strings (not integers)

This commit is contained in:
Sarah Wooders
2023-12-11 16:59:21 -08:00
parent 453a7c0c3e
commit 0e935d3ebd
9 changed files with 519 additions and 369 deletions

View File

@@ -10,7 +10,7 @@ from memgpt import utils
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.connectors.storage import StorageConnector
from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.constants import LLM_MAX_TOKENS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
@@ -601,3 +601,31 @@ def add(
# write text to file
with open(os.path.join(directory, name), "w") as f:
f.write(text)
@app.command()
def delete(
    option: str,
    name: str = typer.Option(help="Name of human/persona/agent/source to delete"),
):
    """Delete an agent (state, config, and memory) or a loaded data source by name."""
    if option == "agent":
        # delete state/config
        # TODO: this will eventually need to go through the storage connector
        agent_config = AgentConfig.load(name)
        # remove the agent's on-disk directory (config + saved state)
        shutil.rmtree(agent_config.save_dir())
        # delete recall and archival memory rows belonging to this agent
        recall_storage = StorageConnector.get_recall_storage_connector(agent_config)
        recall_storage.delete()
        archival_storage = StorageConnector.get_archival_storage_connector(agent_config)
        archival_storage.delete()
    elif option == "source":
        # TODO: also delete document store
        # TODO: remove data from any agents that have loaded it in (?)
        storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES)
        storage.delete({"data_source": name})
    else:
        # surface the unsupported option in the error instead of a bare NotImplementedError
        raise NotImplementedError(f"Cannot delete option '{option}' (supported: agent, source)")

View File

@@ -303,50 +303,6 @@ class AgentConfig:
os.path.join(MEMGPT_DIR, "agents", self.name, "config.json") if agent_config_path is None else agent_config_path
)
def link_functions(self, function_schemas):
    """Re-link saved function schemas to their Python implementations.

    agent.functions only persists JSON schemas (OpenAI function-call style, see
    https://platform.openai.com/docs/api-reference/chat/create); the callables
    must be looked up in the function library at load time.

    Returns:
        dict mapping function_name -> {"json_schema": ..., "python_function": ...}

    Raises:
        ValueError: if a schema has no name, or its name is not in the library.
    """
    available_functions = load_all_function_sets()
    linked_function_set = {}
    for f_schema in function_schemas:
        # Attempt to find the function in the existing function library
        f_name = f_schema.get("name")
        if f_name is None:
            raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{f_schema}")
        linked_function = available_functions.get(f_name)
        if linked_function is None:
            raise ValueError(
                f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}"
            )
        # Once we find a matching function, make sure the schema is identical
        if json.dumps(f_schema) != json.dumps(linked_function["json_schema"]):
            schema_diff = get_schema_diff(f_schema, linked_function["json_schema"])
            error_message = (
                f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different.\n"
                + "".join(schema_diff)
            )
            # NOTE to handle old configs, instead of erroring here let's just warn
            utils.printd(error_message)
        linked_function_set[f_name] = linked_function
    return linked_function_set
def generate_agent_id(self, length=6):
## random character based
# characters = string.ascii_lowercase + string.digits
@@ -362,6 +318,9 @@ class AgentConfig:
self.data_sources.append(data_source)
self.save()
def save_dir(self):
    """Return the base directory holding this agent's files."""
    agents_root = os.path.join(MEMGPT_DIR, "agents")
    return os.path.join(agents_root, self.name)
def save_state_dir(self):
    """Return the directory where this agent's serialized state is saved."""
    state_dir = os.path.join(MEMGPT_DIR, "agents", self.name, "agent_state")
    return state_dir

View File

@@ -29,21 +29,37 @@ class ChromaStorageConnector(StorageConnector):
# get a collection or create if it doesn't exist already
self.collection = self.client.get_or_create_collection(self.table_name)
self.include = ["id", "documents", "embeddings", "metadatas"]
self.include = ["documents", "embeddings", "metadatas"]
def get_filters(self, filters: Optional[Dict] = None):
    """Merge the connector's base filters with *filters* and convert to a
    Chroma ``where`` clause.

    Chroma's filter grammar rejects an ``$and`` with fewer than two operands,
    so: zero conditions return ``{}``, one condition is returned directly,
    and only two or more are wrapped in ``$and``.
    """
    if filters is not None:
        filter_conditions = {**self.filters, **filters}
    else:
        filter_conditions = dict(self.filters)
    # convert to chroma format
    clauses = [{key: {"$eq": value}} for key, value in filter_conditions.items()]
    if not clauses:
        return {}
    if len(clauses) == 1:
        return clauses[0]
    return {"$and": clauses}
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = None) -> Iterator[List[Record]]:
    """Yield all matching records in pages of *page_size*.

    Stops when a page comes back empty. Debug prints from development
    have been removed.
    """
    offset = 0
    filters = self.get_filters(filters)
    while True:
        # Retrieve a chunk of records with the given page_size
        results = self.collection.get(offset=offset, limit=page_size, include=self.include, where=filters)
        # If the chunk is empty, we've retrieved all records
        if len(results["embeddings"]) == 0:
            break
        # Yield a list of Record objects converted from the chunk
        yield self.results_to_records(results)
        # Increment the offset to get the next chunk in the next iteration
        offset += page_size
@@ -54,8 +70,8 @@ class ChromaStorageConnector(StorageConnector):
if "created_at" in metadata:
metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
return [
self.type(id=id, text=text, embedding=embedding, **metadatas)
for (id, text, embedding, metadatas) in zip(results["ids"], results["documents"], results["embeddings"], results["metadatas"])
self.type(text=text, embedding=embedding, **metadatas)
for (text, embedding, metadatas) in zip(results["documents"], results["embeddings"], results["metadatas"])
]
def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]:
@@ -68,26 +84,12 @@ class ChromaStorageConnector(StorageConnector):
results = self.collection.get(ids=[id])
return self.results_to_records(results)
def insert(self, record: Record):
if record.id is None:
record.id = str(self.collection.count())
metadata = vars(record)
metadata.pop("id")
metadata.pop("text")
metadata.pop("embedding")
self.collection.add(documents=[record.text], embeddings=[record.embedding], ids=[record.id], metadatas=[metadata])
def insert_many(self, records: List[Record], show_progress=True):
count = self.collection.count()
def format_records(self, records: List[Record]):
metadatas = []
ids = []
ids = [str(record.id) for record in records]
documents = [record.text for record in records]
embeddings = [record.embedding for record in records]
for record in records:
if record.id is None:
count += 1
ids.append(str(count))
# TODO: ensure that other record.ids dont match
metadata = vars(record)
metadata.pop("id")
metadata.pop("text")
@@ -96,6 +98,17 @@ class ChromaStorageConnector(StorageConnector):
metadata["created_at"] = datetime_to_timestamp(metadata["created_at"])
metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed
metadatas.append(metadata)
return ids, documents, embeddings, metadatas
def insert(self, record: Record):
    """Add a single record to the Chroma collection."""
    ids, documents, embeddings, metadatas = self.format_records([record])
    # only pass the embeddings kwarg through when the record actually has an embedding
    if not any(embeddings):
        self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
    else:
        self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
def insert_many(self, records: List[Record], show_progress=True):
ids, documents, embeddings, metadatas = self.format_records(records)
if not any(embeddings):
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
else:
@@ -110,8 +123,11 @@ class ChromaStorageConnector(StorageConnector):
pass
def size(self, filters: Optional[Dict] = {}) -> int:
    """Return the number of records matching *filters*.

    Chroma's ``collection.count()`` cannot apply metadata filters, so we
    page through the filtered results and count them. The raw filter dict
    is passed through unchanged: ``get_all_paginated`` converts it to
    Chroma format itself (converting here as well would corrupt the
    ``where`` clause by applying the conversion twice).
    """
    count = 0
    for records in self.get_all_paginated(page_size=100, filters=filters):
        count += len(records)
    return count
def list_data_sources(self):
    # Listing loaded data sources is not supported by the Chroma connector.
    raise NotImplementedError
@@ -123,6 +139,7 @@ class ChromaStorageConnector(StorageConnector):
def query_date(self, start_date, end_date, start=None, count=None):
# TODO: no idea if this is correct
# TODO: convert start/end_date into timestamp
filters = self.get_filters(filters)
filters["created_at"] = {
"$gte": start_date,

View File

@@ -8,8 +8,9 @@ from sqlalchemy.orm import sessionmaker, mapped_column
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
from sqlalchemy import Column, BIGINT, String, DateTime
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy_json import mutable_json_type
import uuid
import re
from tqdm import tqdm
@@ -41,7 +42,7 @@ def get_db_model(table_name: str, table_type: TableType):
__abstract__ = True # this line is necessary
# Assuming passage_id is the primary key
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(String, nullable=False)
text = Column(String, nullable=False)
doc_id = Column(String)
@@ -77,7 +78,7 @@ def get_db_model(table_name: str, table_type: TableType):
__abstract__ = True # this line is necessary
# Assuming message_id is the primary key
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
role = Column(String, nullable=False)

View File

@@ -70,7 +70,6 @@ class StorageConnector:
filter_conditions = self.filters
print("FILTERS", filter_conditions)
return filter_conditions
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
def generate_table_name(self, agent_config: AgentConfig, table_type: TableType):
@@ -99,59 +98,42 @@ class StorageConnector:
raise ValueError(f"Table type {table_type} not implemented")
@staticmethod
def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None):
storage_type = MemGPTConfig.load().archival_storage_type
def get_storage_connector(table_type: TableType, storage_type: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
if storage_type == "local":
from memgpt.connectors.local import VectorIndexStorageConnector
# read from config if not provided
if storage_type is None:
storage_type = MemGPTConfig.load().archival_storage_type
return VectorIndexStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
elif storage_type == "postgres":
if storage_type == "postgres":
from memgpt.connectors.db import PostgresStorageConnector
return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
return PostgresStorageConnector(agent_config=agent_config, table_type=table_type)
elif storage_type == "chroma":
from memgpt.connectors.chroma import ChromaStorageConnector
return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
return ChromaStorageConnector(agent_config=agent_config, table_type=table_type)
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector
return LanceDBConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
return LanceDBConnector(agent_config=agent_config, table_type=table_type)
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")
@staticmethod
def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None):
    """Convenience wrapper: connector for the archival-memory table."""
    return StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, agent_config=agent_config)
@staticmethod
def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None):
    """Convenience wrapper: connector for the recall-memory table.

    Delegates to get_storage_connector, which reads the storage type from
    the MemGPT config; the old per-backend branching (and its debug print)
    lived here and is now removed.
    """
    return StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, agent_config=agent_config)
@staticmethod
def list_loaded_data():
def list_loaded_data(storage_type: Optional[str] = None):
# TODO: modify this to simply list loaded data from a given user
storage_type = MemGPTConfig.load().archival_storage_type
if storage_type is None:
storage_type = MemGPTConfig.load().archival_storage_type
if storage_type == "local":
from memgpt.connectors.local import VectorIndexStorageConnector

View File

@@ -1,4 +1,5 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
import uuid
from abc import abstractmethod
from typing import Optional
import numpy as np
@@ -18,7 +19,10 @@ class Record:
self.user_id = user_id
self.agent_id = agent_id
self.text = text
self.id = id
if id is None:
self.id = uuid.uuid4()
else:
self.id = id
# todo: generate unique uuid
# todo: self.role = role (?)

View File

@@ -97,14 +97,14 @@ def embedding_model():
# load config
config = MemGPTConfig.load()
endpoint_type = config.embedding_endpoint_type
endpoint = config.embedding_endpoint_type
if endpoint == "openai":
if endpoint_type == "openai":
model = OpenAIEmbedding(
api_base=config.embedding_endpoint, api_key=config.openai_key, additional_kwargs={"user": config.anon_clientid}
)
return model
elif endpoint == "azure":
elif endpoint_type == "azure":
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
model = "text-embedding-ada-002"
deployment = config.azure_embedding_deployment if config.azure_embedding_deployment is not None else model
@@ -115,7 +115,7 @@ def embedding_model():
azure_endpoint=config.azure_endpoint,
api_version=config.azure_version,
)
elif endpoint == "hugging-face":
elif endpoint_type == "hugging-face":
embed_model = EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=config.anon_clientid)
return embed_model
else:

View File

@@ -21,9 +21,8 @@ from memgpt.interface import CLIInterface as interface # for printing to termin
import memgpt.agent as agent
import memgpt.system as system
import memgpt.constants as constants
import memgpt.errors as errors
from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart
from memgpt.cli.cli_config import configure, list, add
from memgpt.cli.cli import run, attach, version
from memgpt.cli.cli_config import configure, list, add, delete
from memgpt.cli.cli_load import app as load_app
from memgpt.connectors.storage import StorageConnector
@@ -34,9 +33,7 @@ app.command(name="attach")(attach)
app.command(name="configure")(configure)
app.command(name="list")(list)
app.command(name="add")(add)
app.command(name="server")(server)
app.command(name="folder")(open_folder)
app.command(name="quickstart")(quickstart)
app.command(name="delete")(delete)
# load data commands
app.add_typer(load_app, name="load")

View File

@@ -17,280 +17,442 @@ from memgpt.embeddings import embedding_model
from memgpt.data_types import Message, Passage
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.utils import get_local_time
from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
import argparse
from datetime import datetime, timedelta
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_recall_db():
    """End-to-end test of the Postgres recall-memory connector.

    Covers insert/insert_many, size (with and without filters), text
    search, and date-range search. Guarded by skipif so it no longer runs
    (and fails) when the PG URI env var is unset, matching sibling tests.
    """
    # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
    storage_type = "postgres"
    storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
    config = MemGPTConfig(
        recall_storage_type=storage_type,
        recall_storage_uri=storage_uri,
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        model="gpt4",
    )
    print(config.config_path)
    assert config.recall_storage_uri is not None
    config.save()
    print(config)
    agent_config = AgentConfig(
        persona=config.persona,
        human=config.human,
        model=config.model,
    )
    conn = StorageConnector.get_recall_storage_connector(agent_config)
    # construct recall memory messages
    message1 = Message(
        agent_id=agent_config.name,
        role="agent",
        text="This is a test message",
        user_id=config.anon_clientid,
        model=agent_config.model,
        created_at=datetime.now(),
    )
    message2 = Message(
        agent_id=agent_config.name,
        role="user",
        text="This is a test message",
        user_id=config.anon_clientid,
        model=agent_config.model,
        created_at=datetime.now(),
    )
    print(vars(message1))
    # test insert
    conn.insert(message1)
    conn.insert_many([message2])
    # test size
    assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}"
    assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}'
    # test text query
    res = conn.query_text("test")
    print(res)
    assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
    # test date query over the past week (was misleadingly named "ten_weeks_ago")
    current_time = datetime.now()
    one_week_ago = current_time - timedelta(weeks=1)
    res = conn.query_date(start_date=one_week_ago, end_date=current_time)
    print(res)
    assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
    print(conn.get_all())
# Shared fixtures for the storage tests below: three records with known
# texts, timestamps, roles, agent ids, and string ids.
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
start_date = datetime(2009, 10, 5, 18, 00)
# one record a week before start_date, one at start_date, one a week after
dates = [start_date - timedelta(weeks=1), start_date, start_date + timedelta(weeks=1)]
roles = ["user", "agent", "user"]
agent_ids = ["agent1", "agent2", "agent1"]
ids = ["test1", "test2", "test3"] # TODO: generate unique uuid
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
def test_postgres_openai():
if not os.getenv("PGVECTOR_TEST_DB_URL"):
return # soft pass
if not os.getenv("OPENAI_API_KEY"):
return # soft pass
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
print(config.config_path)
assert config.archival_storage_uri is not None
config.archival_storage_uri = config.archival_storage_uri.replace(
"postgres://", "postgresql://"
) # https://stackoverflow.com/a/64698899
config.save()
print(config)
embed_model = embedding_model()
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
agent_config = AgentConfig(
name="test_agent",
persona=config.persona,
human=config.human,
model=config.model,
)
db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
# db.delete()
# return
for passage in passage:
db.insert(
def generate_passages(embed_model):
"""Generate list of 3 Passage objects"""
# embeddings: use openai if env is set, otherwise local
passages = []
for (text, _, _, agent_id, id) in zip(texts, dates, roles, agent_ids, ids):
embedding = None
if embed_model:
embedding = embed_model.get_text_embedding(text)
passages.append(
Passage(
text=passage,
embedding=embed_model.get_text_embedding(passage),
user_id=config.anon_clientid,
agent_id="test_agent",
data_source="test",
metadata={"test_metadata_key": "test_metadata_value"},
user_id="test",
text=text,
agent_id=agent_id,
embedding=embedding,
data_source="test_source",
)
)
print(db.get_all())
query = "why was she crying"
query_vec = embed_model.get_text_embedding(query)
res = db.query(None, query_vec, top_k=2)
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
# TODO fix (causes a hang for some reason)
# print("deleting...")
# db.delete()
# print("...finished")
return passages
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
def test_chroma_openai():
if not os.getenv("OPENAI_API_KEY"):
return # soft pass
def generate_messages():
    """Generate a list of 3 Message objects from the module-level fixtures.

    The leftover debug print of the last message's text has been removed.
    """
    messages = []
    for (text, date, role, agent_id, id) in zip(texts, dates, roles, agent_ids, ids):
        messages.append(Message(user_id="test", text=text, agent_id=agent_id, role=role, created_at=date, id=id, model="gpt4"))
    return messages
config = MemGPTConfig(
archival_storage_type="chroma",
archival_storage_path="./test_chroma",
embedding_endpoint_type="openai",
embedding_dim=1536,
model="gpt4",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
)
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"])
@pytest.mark.parametrize("table_type", [TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY])
def test_storage(storage_connector, table_type):
# setup memgpt config
# TODO: set env for different config path
config = MemGPTConfig()
if storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
if storage_connector == "lancedb":
if not os.getenv("LANCEDB_TEST_URL"):
print("Skipping test, missing LanceDB URI")
return
config.archival_storage_uri = os.getenv("LANCEDB_TEST_URL")
config.recall_storage_uri = os.getenv("LANCEDB_TEST_URL")
config.archival_storage_type = "lancedb"
config.recall_storage_type = "lancedb"
if storage_connector == "chroma":
config.archival_storage_type = "chroma"
config.recall_storage_type = "chroma"
config.recall_storage_path = "./test_chroma"
config.archival_storage_path = "./test_chroma"
# get embedding model
embed_model = None
if os.getenv("OPENAI_API_KEY"):
config.embedding_endpoint_type = "openai"
config.embedding_endpoint = "https://api.openai.com/v1"
config.embedding_dim = 1536
config.openai_key = os.getenv("OPENAI_API_KEY")
else:
config.embedding_endpoint_type = "local"
config.embedding_endpoint = None
config.embedding_dim = 384
config.save()
embed_model = embedding_model()
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
db = ChromaStorageConnector(name="test-openai")
for passage in passage:
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
query = "why was she crying"
query_vec = embed_model.get_text_embedding(query)
res = db.query(query, query_vec, top_k=2)
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
print(res[0].text)
print("deleting")
db.delete()
@pytest.mark.skipif(
not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key"
)
def test_lancedb_openai():
assert os.getenv("LANCEDB_TEST_URL") is not None
if os.getenv("OPENAI_API_KEY") is None:
return # soft pass
config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL"))
print(config.config_path)
assert config.archival_storage_uri is not None
print(config)
embed_model = embedding_model()
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
db = LanceDBConnector(name="test-openai")
for passage in passage:
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
print(db.get_all())
query = "why was she crying"
query_vec = embed_model.get_text_embedding(query)
res = db.query(None, query_vec, top_k=2)
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_postgres_local():
if not os.getenv("PGVECTOR_TEST_DB_URL"):
return
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
config = MemGPTConfig(
archival_storage_type="postgres",
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
embedding_endpoint_type="local",
embedding_dim=384, # use HF model
# create agent
agent_config = AgentConfig(
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
model=DEFAULT_MEMGPT_MODEL,
)
print(config.config_path)
assert config.archival_storage_uri is not None
config.archival_storage_uri = config.archival_storage_uri.replace(
"postgres://", "postgresql://"
) # https://stackoverflow.com/a/64698899
config.save()
print(config)
embed_model = embedding_model()
# create storage connector
conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
# generate data
if table_type == TableType.ARCHIVAL_MEMORY:
records = generate_passages(embed_model)
elif table_type == TableType.RECALL_MEMORY:
records = generate_messages()
else:
raise NotImplementedError(f"Table type {table_type} not implemented")
db = PostgresStorageConnector(name="test-local")
# test: insert
conn.insert(records[0])
assert conn.size() == 1, f"Expected 1 record, got {conn.size()}"
for passage in passage:
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
# test: insert_many
conn.insert_many(records[1:])
assert conn.size() == 3, f"Expected 1 record, got {conn.size()}"
print(db.get_all())
# test: list_loaded_data
if table_type == TableType.ARCHIVAL_MEMORY:
sources = StorageConnector.list_loaded_data(storage_type=storage_connector)
assert len(sources) == 1, f"Expected 1 source, got {len(sources)}"
assert sources[0] == "test_source", f"Expected 'test_source', got {sources[0]}"
query = "why was she crying"
query_vec = embed_model.get_text_embedding(query)
res = db.query(None, query_vec, top_k=2)
# test: get_all_paginated
paginated_total = 0
for page in conn.get_all_paginated(page_size=1):
paginated_total += len(page)
assert paginated_total == 3, f"Expected 3 records, got {paginated_total}"
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
# test: get_all
all_records = conn.get_all()
assert len(all_records) == 3, f"Expected 3 records, got {len(all_records)}"
all_records = conn.get_all(limit=2)
assert len(all_records) == 2, f"Expected 2 records, got {len(all_records)}"
# TODO fix (causes a hang for some reason)
# print("deleting...")
# db.delete()
# print("...finished")
# test: get
res = conn.get(id=ids[0])
assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}"
# test: size
assert conn.size() == 3, f"Expected 3 records, got {conn.size()}"
assert conn.size(filters={"agent_id", "agent1"}) == 2, f"Expected 2 records, got {conn.size(filters={'agent_id', 'agent1'})}"
if table_type == TableType.RECALL_MEMORY:
assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}"
# test: query (vector)
if embed_model:
query = "why was she crying"
query_vec = embed_model.get_text_embedding(query)
res = conn.query(None, query_vec, top_k=2)
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
# test: query_text
query = "CindereLLa"
res = conn.query_text(query)
assert len(res) == 1, f"Expected 1 result, got {len(res)}"
assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
# test: query_date (recall memory only)
if table_type == TableType.RECALL_MEMORY:
print("Testing recall memory date search")
start_date = start_date - timedelta(days=1)
end_date = start_date + timedelta(days=1)
res = conn.query_date(start_date=start_date, end_date=end_date)
assert len(res) == 1, f"Expected 1 result, got {len(res): {res}}"
# test: delete
conn.delete({"id": ids[0]})
assert conn.size() == 2, f"Expected 2 records, got {conn.size()}"
conn.delete()
assert conn.size() == 0, f"Expected 0 records, got {conn.size()}"
@pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI")
def test_lancedb_local():
    """Insert test passages into LanceDB (local embeddings) and verify vector search."""
    assert os.getenv("LANCEDB_TEST_URL") is not None
    config = MemGPTConfig(
        archival_storage_type="lancedb",
        archival_storage_uri=os.getenv("LANCEDB_TEST_URL"),
        embedding_model="local",
        embedding_dim=384,  # use HF model
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None
    embed_model = embedding_model()
    # renamed from `passage` so the loop variable no longer shadows the list
    test_passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]
    db = LanceDBConnector(name="test-local")
    for passage in test_passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
    print(db.get_all())
    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)
    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
if __name__ == "__main__":
    # Debug entry point: allow running the recall-storage test directly
    # (python this_file.py) without pytest. The original was a bare
    # module-level call, which executed the test as a side effect of every
    # import of this module — including pytest's own collection import.
    test_recall_db()
# def test_recall_db():
# # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
#
# storage_type = "postgres"
# storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
# config = MemGPTConfig(
# recall_storage_type=storage_type,
# recall_storage_uri=storage_uri,
# model_endpoint_type="openai",
# model_endpoint="https://api.openai.com/v1",
# model="gpt4",
# )
# print(config.config_path)
# assert config.recall_storage_uri is not None
# config.save()
# print(config)
#
# agent_config = AgentConfig(
# persona=config.persona,
# human=config.human,
# model=config.model,
# )
#
# conn = StorageConnector.get_recall_storage_connector(agent_config)
#
# # construct recall memory messages
# message1 = Message(
# agent_id=agent_config.name,
# role="agent",
# text="This is a test message",
# user_id=config.anon_clientid,
# model=agent_config.model,
# created_at=datetime.now(),
# )
# message2 = Message(
# agent_id=agent_config.name,
# role="user",
# text="This is a test message",
# user_id=config.anon_clientid,
# model=agent_config.model,
# created_at=datetime.now(),
# )
# print(vars(message1))
#
# # test insert
# conn.insert(message1)
# conn.insert_many([message2])
#
# # test size
# assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}"
# assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}'
#
# # test text query
# res = conn.query_text("test")
# print(res)
# assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
#
# # test date query
# current_time = datetime.now()
# ten_weeks_ago = current_time - timedelta(weeks=1)
# res = conn.query_date(start_date=ten_weeks_ago, end_date=current_time)
# print(res)
# assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
#
# print(conn.get_all())
#
#
# @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
# def test_postgres_openai():
# if not os.getenv("PGVECTOR_TEST_DB_URL"):
# return # soft pass
# if not os.getenv("OPENAI_API_KEY"):
# return # soft pass
#
# # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
# config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
# print(config.config_path)
# assert config.archival_storage_uri is not None
# config.archival_storage_uri = config.archival_storage_uri.replace(
# "postgres://", "postgresql://"
# ) # https://stackoverflow.com/a/64698899
# config.save()
# print(config)
#
# embed_model = embedding_model()
#
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
#
# agent_config = AgentConfig(
# name="test_agent",
# persona=config.persona,
# human=config.human,
# model=config.model,
# )
#
# db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
#
# # db.delete()
# # return
# for passage in passage:
# db.insert(
# Passage(
# text=passage,
# embedding=embed_model.get_text_embedding(passage),
# user_id=config.anon_clientid,
# agent_id="test_agent",
# data_source="test",
# metadata={"test_metadata_key": "test_metadata_value"},
# )
# )
#
# print(db.get_all())
#
# query = "why was she crying"
# query_vec = embed_model.get_text_embedding(query)
# res = db.query(None, query_vec, top_k=2)
#
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
#
# # TODO fix (causes a hang for some reason)
# # print("deleting...")
# # db.delete()
# # print("...finished")
#
#
# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
# def test_chroma_openai():
# if not os.getenv("OPENAI_API_KEY"):
# return # soft pass
#
# config = MemGPTConfig(
# archival_storage_type="chroma",
# archival_storage_path="./test_chroma",
# embedding_endpoint_type="openai",
# embedding_dim=1536,
# model="gpt4",
# model_endpoint_type="openai",
# model_endpoint="https://api.openai.com/v1",
# )
# config.save()
# embed_model = embedding_model()
#
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
#
# db = ChromaStorageConnector(name="test-openai")
#
# for passage in passage:
# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
#
# query = "why was she crying"
# query_vec = embed_model.get_text_embedding(query)
# res = db.query(query, query_vec, top_k=2)
#
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
#
# print(res[0].text)
#
# print("deleting")
# db.delete()
#
#
# @pytest.mark.skipif(
# not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key"
# )
# def test_lancedb_openai():
# assert os.getenv("LANCEDB_TEST_URL") is not None
# if os.getenv("OPENAI_API_KEY") is None:
# return # soft pass
#
# config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL"))
# print(config.config_path)
# assert config.archival_storage_uri is not None
# print(config)
#
# embed_model = embedding_model()
#
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
#
# db = LanceDBConnector(name="test-openai")
#
# for passage in passage:
# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
#
# print(db.get_all())
#
# query = "why was she crying"
# query_vec = embed_model.get_text_embedding(query)
# res = db.query(None, query_vec, top_k=2)
#
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
#
#
# @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
# def test_postgres_local():
# if not os.getenv("PGVECTOR_TEST_DB_URL"):
# return
# # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
#
# config = MemGPTConfig(
# archival_storage_type="postgres",
# archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
# embedding_endpoint_type="local",
# embedding_dim=384, # use HF model
# )
# print(config.config_path)
# assert config.archival_storage_uri is not None
# config.archival_storage_uri = config.archival_storage_uri.replace(
# "postgres://", "postgresql://"
# ) # https://stackoverflow.com/a/64698899
# config.save()
# print(config)
#
# embed_model = embedding_model()
#
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
#
# db = PostgresStorageConnector(name="test-local")
#
# for passage in passage:
# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
#
# print(db.get_all())
#
# query = "why was she crying"
# query_vec = embed_model.get_text_embedding(query)
# res = db.query(None, query_vec, top_k=2)
#
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
#
# # TODO fix (causes a hang for some reason)
# # print("deleting...")
# # db.delete()
# # print("...finished")
#
#
# @pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI")
# def test_lancedb_local():
# assert os.getenv("LANCEDB_TEST_URL") is not None
#
# config = MemGPTConfig(
# archival_storage_type="lancedb",
# archival_storage_uri=os.getenv("LANCEDB_TEST_URL"),
# embedding_model="local",
# embedding_dim=384, # use HF model
# )
# print(config.config_path)
# assert config.archival_storage_uri is not None
#
# embed_model = embedding_model()
#
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
#
# db = LanceDBConnector(name="test-local")
#
# for passage in passage:
# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
#
# print(db.get_all())
#
# query = "why was she crying"
# query_vec = embed_model.get_text_embedding(query)
# res = db.query(None, query_vec, top_k=2)
#
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
#