Add more compehensive tests, make row ids be strings (not integers)
This commit is contained in:
@@ -10,7 +10,7 @@ from memgpt import utils
|
||||
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.constants import LLM_MAX_TOKENS
|
||||
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
|
||||
from memgpt.local_llm.utils import get_available_wrappers
|
||||
@@ -601,3 +601,31 @@ def add(
|
||||
# write text to file
|
||||
with open(os.path.join(directory, name), "w") as f:
|
||||
f.write(text)
|
||||
|
||||
|
||||
@app.command()
|
||||
def delete(
|
||||
option: str,
|
||||
name: str = typer.Option(help="Name of human/persona/agent/source to delete"),
|
||||
):
|
||||
if option == "agent":
|
||||
# delete state/config
|
||||
# TODO: this will eventually need to go through the storage connector
|
||||
agent_config = AgentConfig.load(name)
|
||||
# remove directory
|
||||
shutil.rmtree(agent_config.save_dir())
|
||||
|
||||
# delete memory
|
||||
recall_storage = StorageConnector.get_recall_storage_connector(agent_config)
|
||||
recall_storage.delete()
|
||||
archival_storage = StorageConnector.get_archival_storage_connector(agent_config)
|
||||
archival_storage.delete()
|
||||
|
||||
elif option == "source":
|
||||
# TODO: also delete document store
|
||||
# TODO: remove data from any agents that have loaded it in (?)
|
||||
storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES)
|
||||
storage.delete({"data_source": name})
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -303,50 +303,6 @@ class AgentConfig:
|
||||
os.path.join(MEMGPT_DIR, "agents", self.name, "config.json") if agent_config_path is None else agent_config_path
|
||||
)
|
||||
|
||||
def link_functions(self, function_schemas):
|
||||
|
||||
# need to dynamically link the functions
|
||||
# the saved agent.functions will just have the schemas, but we need to
|
||||
# go through the functions library and pull the respective python functions
|
||||
|
||||
# Available functions is a mapping from:
|
||||
# function_name -> {
|
||||
# json_schema: schema
|
||||
# python_function: function
|
||||
# }
|
||||
# agent.functions is a list of schemas (OpenAI kwarg functions style, see: https://platform.openai.com/docs/api-reference/chat/create)
|
||||
# [{'name': ..., 'description': ...}, {...}]
|
||||
available_functions = load_all_function_sets()
|
||||
linked_function_set = {}
|
||||
for f_schema in function_schemas:
|
||||
# Attempt to find the function in the existing function library
|
||||
f_name = f_schema.get("name")
|
||||
if f_name is None:
|
||||
raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{f_schema}")
|
||||
linked_function = available_functions.get(f_name)
|
||||
if linked_function is None:
|
||||
raise ValueError(
|
||||
f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}"
|
||||
)
|
||||
# Once we find a matching function, make sure the schema is identical
|
||||
if json.dumps(f_schema) != json.dumps(linked_function["json_schema"]):
|
||||
# error_message = (
|
||||
# f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different."
|
||||
# + f"\n>>>agent.state.functions\n{json.dumps(f_schema, indent=2)}"
|
||||
# + f"\n>>>function library\n{json.dumps(linked_function['json_schema'], indent=2)}"
|
||||
# )
|
||||
schema_diff = get_schema_diff(f_schema, linked_function["json_schema"])
|
||||
error_message = (
|
||||
f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different.\n"
|
||||
+ "".join(schema_diff)
|
||||
)
|
||||
|
||||
# NOTE to handle old configs, instead of erroring here let's just warn
|
||||
# raise ValueError(error_message)
|
||||
utils.printd(error_message)
|
||||
linked_function_set[f_name] = linked_function
|
||||
return linked_function_set
|
||||
|
||||
def generate_agent_id(self, length=6):
|
||||
## random character based
|
||||
# characters = string.ascii_lowercase + string.digits
|
||||
@@ -362,6 +318,9 @@ class AgentConfig:
|
||||
self.data_sources.append(data_source)
|
||||
self.save()
|
||||
|
||||
def save_dir(self):
|
||||
return os.path.join(MEMGPT_DIR, "agents", self.name)
|
||||
|
||||
def save_state_dir(self):
|
||||
# directory to save agent state
|
||||
return os.path.join(MEMGPT_DIR, "agents", self.name, "agent_state")
|
||||
|
||||
@@ -29,21 +29,37 @@ class ChromaStorageConnector(StorageConnector):
|
||||
|
||||
# get a collection or create if it doesn't exist already
|
||||
self.collection = self.client.get_or_create_collection(self.table_name)
|
||||
self.include = ["id", "documents", "embeddings", "metadatas"]
|
||||
self.include = ["documents", "embeddings", "metadatas"]
|
||||
|
||||
def get_filters(self, filters: Optional[Dict] = {}):
|
||||
# get all filters for query
|
||||
if filters is not None:
|
||||
filter_conditions = {**self.filters, **filters}
|
||||
else:
|
||||
filter_conditions = self.filters
|
||||
|
||||
# convert to chroma format
|
||||
chroma_filters = {"$and": []}
|
||||
for key, value in filter_conditions.items():
|
||||
chroma_filters["$and"].append({key: {"$eq": value}})
|
||||
return chroma_filters
|
||||
|
||||
def get_all_paginated(self, page_size: int, filters: Optional[Dict]) -> Iterator[List[Record]]:
|
||||
offset = 0
|
||||
filters = self.get_filters(filters)
|
||||
print(filters)
|
||||
while True:
|
||||
# Retrieve a chunk of records with the given page_size
|
||||
db_chunks = self.collection.get(offset=offset, limit=page_size, include=self.include, where=filters)
|
||||
print("querying...", self.collection.count(), offset, page_size)
|
||||
results = self.collection.get(offset=offset, limit=page_size, include=self.include, where=filters)
|
||||
print(len(results["embeddings"]))
|
||||
|
||||
# If the chunk is empty, we've retrieved all records
|
||||
if not db_chunks:
|
||||
if len(results["embeddings"]) == 0:
|
||||
break
|
||||
|
||||
# Yield a list of Record objects converted from the chunk
|
||||
yield self.results_to_records(db_chunks)
|
||||
yield self.results_to_records(results)
|
||||
|
||||
# Increment the offset to get the next chunk in the next iteration
|
||||
offset += page_size
|
||||
@@ -54,8 +70,8 @@ class ChromaStorageConnector(StorageConnector):
|
||||
if "created_at" in metadata:
|
||||
metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
|
||||
return [
|
||||
self.type(id=id, text=text, embedding=embedding, **metadatas)
|
||||
for (id, text, embedding, metadatas) in zip(results["ids"], results["documents"], results["embeddings"], results["metadatas"])
|
||||
self.type(text=text, embedding=embedding, **metadatas)
|
||||
for (text, embedding, metadatas) in zip(results["documents"], results["embeddings"], results["metadatas"])
|
||||
]
|
||||
|
||||
def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
@@ -68,26 +84,12 @@ class ChromaStorageConnector(StorageConnector):
|
||||
results = self.collection.get(ids=[id])
|
||||
return self.results_to_records(results)
|
||||
|
||||
def insert(self, record: Record):
|
||||
if record.id is None:
|
||||
record.id = str(self.collection.count())
|
||||
metadata = vars(record)
|
||||
metadata.pop("id")
|
||||
metadata.pop("text")
|
||||
metadata.pop("embedding")
|
||||
self.collection.add(documents=[record.text], embeddings=[record.embedding], ids=[record.id], metadatas=[metadata])
|
||||
|
||||
def insert_many(self, records: List[Record], show_progress=True):
|
||||
count = self.collection.count()
|
||||
def format_records(self, records: List[Record]):
|
||||
metadatas = []
|
||||
ids = []
|
||||
ids = [str(record.id) for record in records]
|
||||
documents = [record.text for record in records]
|
||||
embeddings = [record.embedding for record in records]
|
||||
for record in records:
|
||||
if record.id is None:
|
||||
count += 1
|
||||
ids.append(str(count))
|
||||
# TODO: ensure that other record.ids dont match
|
||||
metadata = vars(record)
|
||||
metadata.pop("id")
|
||||
metadata.pop("text")
|
||||
@@ -96,6 +98,17 @@ class ChromaStorageConnector(StorageConnector):
|
||||
metadata["created_at"] = datetime_to_timestamp(metadata["created_at"])
|
||||
metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed
|
||||
metadatas.append(metadata)
|
||||
return ids, documents, embeddings, metadatas
|
||||
|
||||
def insert(self, record: Record):
|
||||
ids, documents, embeddings, metadatas = self.format_records([record])
|
||||
if not any(embeddings):
|
||||
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
|
||||
else:
|
||||
self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
|
||||
|
||||
def insert_many(self, records: List[Record], show_progress=True):
|
||||
ids, documents, embeddings, metadatas = self.format_records(records)
|
||||
if not any(embeddings):
|
||||
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
|
||||
else:
|
||||
@@ -110,8 +123,11 @@ class ChromaStorageConnector(StorageConnector):
|
||||
pass
|
||||
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
filters = self.get_filters(filters)
|
||||
return self.collection.count()
|
||||
# unfortuantely, need to use pagination to get filtering
|
||||
count = 0
|
||||
for records in self.get_all_paginated(page_size=100, filters=filters):
|
||||
count += len(records)
|
||||
return count
|
||||
|
||||
def list_data_sources(self):
|
||||
raise NotImplementedError
|
||||
@@ -123,6 +139,7 @@ class ChromaStorageConnector(StorageConnector):
|
||||
|
||||
def query_date(self, start_date, end_date, start=None, count=None):
|
||||
# TODO: no idea if this is correct
|
||||
# TODO: convert start/end_date into timestamp
|
||||
filters = self.get_filters(filters)
|
||||
filters["created_at"] = {
|
||||
"$gte": start_date,
|
||||
|
||||
@@ -8,8 +8,9 @@ from sqlalchemy.orm import sessionmaker, mapped_column
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy import Column, BIGINT, String, DateTime
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy_json import mutable_json_type
|
||||
import uuid
|
||||
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
@@ -41,7 +42,7 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
__abstract__ = True # this line is necessary
|
||||
|
||||
# Assuming passage_id is the primary key
|
||||
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(String, nullable=False)
|
||||
text = Column(String, nullable=False)
|
||||
doc_id = Column(String)
|
||||
@@ -77,7 +78,7 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
__abstract__ = True # this line is necessary
|
||||
|
||||
# Assuming message_id is the primary key
|
||||
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(String, nullable=False)
|
||||
agent_id = Column(String, nullable=False)
|
||||
role = Column(String, nullable=False)
|
||||
|
||||
@@ -70,7 +70,6 @@ class StorageConnector:
|
||||
filter_conditions = self.filters
|
||||
print("FILTERS", filter_conditions)
|
||||
return filter_conditions
|
||||
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
|
||||
|
||||
def generate_table_name(self, agent_config: AgentConfig, table_type: TableType):
|
||||
|
||||
@@ -99,59 +98,42 @@ class StorageConnector:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
@staticmethod
|
||||
def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None):
|
||||
storage_type = MemGPTConfig.load().archival_storage_type
|
||||
def get_storage_connector(table_type: TableType, storage_type: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
||||
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import VectorIndexStorageConnector
|
||||
# read from config if not provided
|
||||
if storage_type is None:
|
||||
storage_type = MemGPTConfig.load().archival_storage_type
|
||||
|
||||
return VectorIndexStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
|
||||
elif storage_type == "postgres":
|
||||
if storage_type == "postgres":
|
||||
from memgpt.connectors.db import PostgresStorageConnector
|
||||
|
||||
return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
return PostgresStorageConnector(agent_config=agent_config, table_type=table_type)
|
||||
elif storage_type == "chroma":
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
|
||||
return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
return ChromaStorageConnector(agent_config=agent_config, table_type=table_type)
|
||||
elif storage_type == "lancedb":
|
||||
from memgpt.connectors.db import LanceDBConnector
|
||||
|
||||
return LanceDBConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
return LanceDBConnector(agent_config=agent_config, table_type=table_type)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
|
||||
@staticmethod
|
||||
def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None):
|
||||
return StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, agent_config=agent_config)
|
||||
|
||||
@staticmethod
|
||||
def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None):
|
||||
storage_type = MemGPTConfig.load().recall_storage_type
|
||||
|
||||
print("Recall storage type", storage_type)
|
||||
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import InMemoryStorageConnector
|
||||
|
||||
# maintains in-memory list for storage
|
||||
return InMemoryStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY)
|
||||
|
||||
elif storage_type == "postgres":
|
||||
from memgpt.connectors.db import PostgresStorageConnector
|
||||
|
||||
return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY)
|
||||
|
||||
elif storage_type == "chroma":
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
|
||||
return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
return StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, agent_config=agent_config)
|
||||
|
||||
@staticmethod
|
||||
def list_loaded_data():
|
||||
def list_loaded_data(storage_type: Optional[str] = None):
|
||||
# TODO: modify this to simply list loaded data from a given user
|
||||
storage_type = MemGPTConfig.load().archival_storage_type
|
||||
if storage_type is None:
|
||||
storage_type = MemGPTConfig.load().archival_storage_type
|
||||
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import VectorIndexStorageConnector
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
@@ -18,7 +19,10 @@ class Record:
|
||||
self.user_id = user_id
|
||||
self.agent_id = agent_id
|
||||
self.text = text
|
||||
self.id = id
|
||||
if id is None:
|
||||
self.id = uuid.uuid4()
|
||||
else:
|
||||
self.id = id
|
||||
# todo: generate unique uuid
|
||||
# todo: self.role = role (?)
|
||||
|
||||
|
||||
@@ -97,14 +97,14 @@ def embedding_model():
|
||||
|
||||
# load config
|
||||
config = MemGPTConfig.load()
|
||||
endpoint_type = config.embedding_endpoint_type
|
||||
|
||||
endpoint = config.embedding_endpoint_type
|
||||
if endpoint == "openai":
|
||||
if endpoint_type == "openai":
|
||||
model = OpenAIEmbedding(
|
||||
api_base=config.embedding_endpoint, api_key=config.openai_key, additional_kwargs={"user": config.anon_clientid}
|
||||
)
|
||||
return model
|
||||
elif endpoint == "azure":
|
||||
elif endpoint_type == "azure":
|
||||
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
|
||||
model = "text-embedding-ada-002"
|
||||
deployment = config.azure_embedding_deployment if config.azure_embedding_deployment is not None else model
|
||||
@@ -115,7 +115,7 @@ def embedding_model():
|
||||
azure_endpoint=config.azure_endpoint,
|
||||
api_version=config.azure_version,
|
||||
)
|
||||
elif endpoint == "hugging-face":
|
||||
elif endpoint_type == "hugging-face":
|
||||
embed_model = EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=config.anon_clientid)
|
||||
return embed_model
|
||||
else:
|
||||
|
||||
@@ -21,9 +21,8 @@ from memgpt.interface import CLIInterface as interface # for printing to termin
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as constants
|
||||
import memgpt.errors as errors
|
||||
from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart
|
||||
from memgpt.cli.cli_config import configure, list, add
|
||||
from memgpt.cli.cli import run, attach, version
|
||||
from memgpt.cli.cli_config import configure, list, add, delete
|
||||
from memgpt.cli.cli_load import app as load_app
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
|
||||
@@ -34,9 +33,7 @@ app.command(name="attach")(attach)
|
||||
app.command(name="configure")(configure)
|
||||
app.command(name="list")(list)
|
||||
app.command(name="add")(add)
|
||||
app.command(name="server")(server)
|
||||
app.command(name="folder")(open_folder)
|
||||
app.command(name="quickstart")(quickstart)
|
||||
app.command(name="delete")(delete)
|
||||
# load data commands
|
||||
app.add_typer(load_app, name="load")
|
||||
|
||||
|
||||
@@ -17,280 +17,442 @@ from memgpt.embeddings import embedding_model
|
||||
from memgpt.data_types import Message, Passage
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.utils import get_local_time
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
|
||||
|
||||
import argparse
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def test_recall_db():
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
|
||||
storage_type = "postgres"
|
||||
storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config = MemGPTConfig(
|
||||
recall_storage_type=storage_type,
|
||||
recall_storage_uri=storage_uri,
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt4",
|
||||
)
|
||||
print(config.config_path)
|
||||
assert config.recall_storage_uri is not None
|
||||
config.save()
|
||||
print(config)
|
||||
|
||||
agent_config = AgentConfig(
|
||||
persona=config.persona,
|
||||
human=config.human,
|
||||
model=config.model,
|
||||
)
|
||||
|
||||
conn = StorageConnector.get_recall_storage_connector(agent_config)
|
||||
|
||||
# construct recall memory messages
|
||||
message1 = Message(
|
||||
agent_id=agent_config.name,
|
||||
role="agent",
|
||||
text="This is a test message",
|
||||
user_id=config.anon_clientid,
|
||||
model=agent_config.model,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
message2 = Message(
|
||||
agent_id=agent_config.name,
|
||||
role="user",
|
||||
text="This is a test message",
|
||||
user_id=config.anon_clientid,
|
||||
model=agent_config.model,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
print(vars(message1))
|
||||
|
||||
# test insert
|
||||
conn.insert(message1)
|
||||
conn.insert_many([message2])
|
||||
|
||||
# test size
|
||||
assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}"
|
||||
assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}'
|
||||
|
||||
# test text query
|
||||
res = conn.query_text("test")
|
||||
print(res)
|
||||
assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
|
||||
|
||||
# test date query
|
||||
current_time = datetime.now()
|
||||
ten_weeks_ago = current_time - timedelta(weeks=1)
|
||||
res = conn.query_date(start_date=ten_weeks_ago, end_date=current_time)
|
||||
print(res)
|
||||
assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
|
||||
|
||||
print(conn.get_all())
|
||||
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
start_date = datetime(2009, 10, 5, 18, 00)
|
||||
dates = [start_date - timedelta(weeks=1), start_date, start_date + timedelta(weeks=1)]
|
||||
roles = ["user", "agent", "user"]
|
||||
agent_ids = ["agent1", "agent2", "agent1"]
|
||||
ids = ["test1", "test2", "test3"] # TODO: generate unique uuid
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
|
||||
def test_postgres_openai():
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
return # soft pass
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
return # soft pass
|
||||
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
|
||||
print(config.config_path)
|
||||
assert config.archival_storage_uri is not None
|
||||
config.archival_storage_uri = config.archival_storage_uri.replace(
|
||||
"postgres://", "postgresql://"
|
||||
) # https://stackoverflow.com/a/64698899
|
||||
config.save()
|
||||
print(config)
|
||||
|
||||
embed_model = embedding_model()
|
||||
|
||||
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
|
||||
agent_config = AgentConfig(
|
||||
name="test_agent",
|
||||
persona=config.persona,
|
||||
human=config.human,
|
||||
model=config.model,
|
||||
)
|
||||
|
||||
db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
|
||||
# db.delete()
|
||||
# return
|
||||
for passage in passage:
|
||||
db.insert(
|
||||
def generate_passages(embed_model):
|
||||
"""Generate list of 3 Passage objects"""
|
||||
# embeddings: use openai if env is set, otherwise local
|
||||
passages = []
|
||||
for (text, _, _, agent_id, id) in zip(texts, dates, roles, agent_ids, ids):
|
||||
embedding = None
|
||||
if embed_model:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
passages.append(
|
||||
Passage(
|
||||
text=passage,
|
||||
embedding=embed_model.get_text_embedding(passage),
|
||||
user_id=config.anon_clientid,
|
||||
agent_id="test_agent",
|
||||
data_source="test",
|
||||
metadata={"test_metadata_key": "test_metadata_value"},
|
||||
user_id="test",
|
||||
text=text,
|
||||
agent_id=agent_id,
|
||||
embedding=embedding,
|
||||
data_source="test_source",
|
||||
)
|
||||
)
|
||||
|
||||
print(db.get_all())
|
||||
|
||||
query = "why was she crying"
|
||||
query_vec = embed_model.get_text_embedding(query)
|
||||
res = db.query(None, query_vec, top_k=2)
|
||||
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
# TODO fix (causes a hang for some reason)
|
||||
# print("deleting...")
|
||||
# db.delete()
|
||||
# print("...finished")
|
||||
return passages
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
|
||||
def test_chroma_openai():
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
return # soft pass
|
||||
def generate_messages():
|
||||
"""Generate list of 3 Message objects"""
|
||||
messages = []
|
||||
for (text, date, role, agent_id, id) in zip(texts, dates, roles, agent_ids, ids):
|
||||
messages.append(Message(user_id="test", text=text, agent_id=agent_id, role=role, created_at=date, id=id, model="gpt4"))
|
||||
print(messages[-1].text)
|
||||
return messages
|
||||
|
||||
config = MemGPTConfig(
|
||||
archival_storage_type="chroma",
|
||||
archival_storage_path="./test_chroma",
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_dim=1536,
|
||||
model="gpt4",
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"])
|
||||
@pytest.mark.parametrize("table_type", [TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY])
|
||||
def test_storage(storage_connector, table_type):
|
||||
|
||||
# setup memgpt config
|
||||
# TODO: set env for different config path
|
||||
config = MemGPTConfig()
|
||||
if storage_connector == "postgres":
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
print("Skipping test, missing PG URI")
|
||||
return
|
||||
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.archival_storage_type = "postgres"
|
||||
config.recall_storage_type = "postgres"
|
||||
if storage_connector == "lancedb":
|
||||
if not os.getenv("LANCEDB_TEST_URL"):
|
||||
print("Skipping test, missing LanceDB URI")
|
||||
return
|
||||
config.archival_storage_uri = os.getenv("LANCEDB_TEST_URL")
|
||||
config.recall_storage_uri = os.getenv("LANCEDB_TEST_URL")
|
||||
config.archival_storage_type = "lancedb"
|
||||
config.recall_storage_type = "lancedb"
|
||||
if storage_connector == "chroma":
|
||||
config.archival_storage_type = "chroma"
|
||||
config.recall_storage_type = "chroma"
|
||||
config.recall_storage_path = "./test_chroma"
|
||||
config.archival_storage_path = "./test_chroma"
|
||||
|
||||
# get embedding model
|
||||
embed_model = None
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
config.embedding_endpoint_type = "openai"
|
||||
config.embedding_endpoint = "https://api.openai.com/v1"
|
||||
config.embedding_dim = 1536
|
||||
config.openai_key = os.getenv("OPENAI_API_KEY")
|
||||
else:
|
||||
config.embedding_endpoint_type = "local"
|
||||
config.embedding_endpoint = None
|
||||
config.embedding_dim = 384
|
||||
config.save()
|
||||
embed_model = embedding_model()
|
||||
|
||||
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
|
||||
db = ChromaStorageConnector(name="test-openai")
|
||||
|
||||
for passage in passage:
|
||||
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
||||
|
||||
query = "why was she crying"
|
||||
query_vec = embed_model.get_text_embedding(query)
|
||||
res = db.query(query, query_vec, top_k=2)
|
||||
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
print(res[0].text)
|
||||
|
||||
print("deleting")
|
||||
db.delete()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key"
|
||||
)
|
||||
def test_lancedb_openai():
|
||||
assert os.getenv("LANCEDB_TEST_URL") is not None
|
||||
if os.getenv("OPENAI_API_KEY") is None:
|
||||
return # soft pass
|
||||
|
||||
config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL"))
|
||||
print(config.config_path)
|
||||
assert config.archival_storage_uri is not None
|
||||
print(config)
|
||||
|
||||
embed_model = embedding_model()
|
||||
|
||||
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
|
||||
db = LanceDBConnector(name="test-openai")
|
||||
|
||||
for passage in passage:
|
||||
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
||||
|
||||
print(db.get_all())
|
||||
|
||||
query = "why was she crying"
|
||||
query_vec = embed_model.get_text_embedding(query)
|
||||
res = db.query(None, query_vec, top_k=2)
|
||||
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
|
||||
def test_postgres_local():
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
return
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
|
||||
config = MemGPTConfig(
|
||||
archival_storage_type="postgres",
|
||||
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
embedding_endpoint_type="local",
|
||||
embedding_dim=384, # use HF model
|
||||
# create agent
|
||||
agent_config = AgentConfig(
|
||||
persona=DEFAULT_PERSONA,
|
||||
human=DEFAULT_HUMAN,
|
||||
model=DEFAULT_MEMGPT_MODEL,
|
||||
)
|
||||
print(config.config_path)
|
||||
assert config.archival_storage_uri is not None
|
||||
config.archival_storage_uri = config.archival_storage_uri.replace(
|
||||
"postgres://", "postgresql://"
|
||||
) # https://stackoverflow.com/a/64698899
|
||||
config.save()
|
||||
print(config)
|
||||
|
||||
embed_model = embedding_model()
|
||||
# create storage connector
|
||||
conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
|
||||
|
||||
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
# generate data
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
records = generate_passages(embed_model)
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
records = generate_messages()
|
||||
else:
|
||||
raise NotImplementedError(f"Table type {table_type} not implemented")
|
||||
|
||||
db = PostgresStorageConnector(name="test-local")
|
||||
# test: insert
|
||||
conn.insert(records[0])
|
||||
assert conn.size() == 1, f"Expected 1 record, got {conn.size()}"
|
||||
|
||||
for passage in passage:
|
||||
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
||||
# test: insert_many
|
||||
conn.insert_many(records[1:])
|
||||
assert conn.size() == 3, f"Expected 1 record, got {conn.size()}"
|
||||
|
||||
print(db.get_all())
|
||||
# test: list_loaded_data
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
sources = StorageConnector.list_loaded_data(storage_type=storage_connector)
|
||||
assert len(sources) == 1, f"Expected 1 source, got {len(sources)}"
|
||||
assert sources[0] == "test_source", f"Expected 'test_source', got {sources[0]}"
|
||||
|
||||
query = "why was she crying"
|
||||
query_vec = embed_model.get_text_embedding(query)
|
||||
res = db.query(None, query_vec, top_k=2)
|
||||
# test: get_all_paginated
|
||||
paginated_total = 0
|
||||
for page in conn.get_all_paginated(page_size=1):
|
||||
paginated_total += len(page)
|
||||
assert paginated_total == 3, f"Expected 3 records, got {paginated_total}"
|
||||
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
# test: get_all
|
||||
all_records = conn.get_all()
|
||||
assert len(all_records) == 3, f"Expected 3 records, got {len(all_records)}"
|
||||
all_records = conn.get_all(limit=2)
|
||||
assert len(all_records) == 2, f"Expected 2 records, got {len(all_records)}"
|
||||
|
||||
# TODO fix (causes a hang for some reason)
|
||||
# print("deleting...")
|
||||
# db.delete()
|
||||
# print("...finished")
|
||||
# test: get
|
||||
res = conn.get(id=ids[0])
|
||||
assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}"
|
||||
|
||||
# test: size
|
||||
assert conn.size() == 3, f"Expected 3 records, got {conn.size()}"
|
||||
assert conn.size(filters={"agent_id", "agent1"}) == 2, f"Expected 2 records, got {conn.size(filters={'agent_id', 'agent1'})}"
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}"
|
||||
|
||||
# test: query (vector)
|
||||
if embed_model:
|
||||
query = "why was she crying"
|
||||
query_vec = embed_model.get_text_embedding(query)
|
||||
res = conn.query(None, query_vec, top_k=2)
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
# test: query_text
|
||||
query = "CindereLLa"
|
||||
res = conn.query_text(query)
|
||||
assert len(res) == 1, f"Expected 1 result, got {len(res)}"
|
||||
assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
|
||||
|
||||
# test: query_date (recall memory only)
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
print("Testing recall memory date search")
|
||||
start_date = start_date - timedelta(days=1)
|
||||
end_date = start_date + timedelta(days=1)
|
||||
res = conn.query_date(start_date=start_date, end_date=end_date)
|
||||
assert len(res) == 1, f"Expected 1 result, got {len(res): {res}}"
|
||||
|
||||
# test: delete
|
||||
conn.delete({"id": ids[0]})
|
||||
assert conn.size() == 2, f"Expected 2 records, got {conn.size()}"
|
||||
conn.delete()
|
||||
assert conn.size() == 0, f"Expected 0 records, got {conn.size()}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI")
def test_lancedb_local():
    """Smoke-test the LanceDB archival storage connector with a local embedding model.

    Inserts three passages into a LanceDB-backed connector, then runs a vector
    similarity query and checks that the most relevant passage
    ("Cinderella wept") is ranked first.

    Skipped unless LANCEDB_TEST_URL is set in the environment.
    """
    assert os.getenv("LANCEDB_TEST_URL") is not None

    config = MemGPTConfig(
        archival_storage_type="lancedb",
        archival_storage_uri=os.getenv("LANCEDB_TEST_URL"),
        # NOTE(review): sibling tests in this file configure the local embedder via
        # embedding_endpoint_type="local" — confirm "embedding_model" is the intended
        # keyword here and not a typo.
        embedding_model="local",
        embedding_dim=384,  # use HF model
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None

    embed_model = embedding_model()

    # Fix: plural name so the loop variable below no longer shadows the list itself
    # (the original `for passage in passage:` rebound the iterable mid-loop).
    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-local")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    # Vector search: the "wept" passage should be the closest match to the query.
    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
# Fix: the bare call ran the test at import time, so merely importing this module
# (e.g. during pytest collection) executed a database test. Guard it so it only
# runs when the file is invoked directly as a script.
if __name__ == "__main__":
    test_recall_db()
# def test_recall_db():
|
||||
# # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
#
|
||||
# storage_type = "postgres"
|
||||
# storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
# config = MemGPTConfig(
|
||||
# recall_storage_type=storage_type,
|
||||
# recall_storage_uri=storage_uri,
|
||||
# model_endpoint_type="openai",
|
||||
# model_endpoint="https://api.openai.com/v1",
|
||||
# model="gpt4",
|
||||
# )
|
||||
# print(config.config_path)
|
||||
# assert config.recall_storage_uri is not None
|
||||
# config.save()
|
||||
# print(config)
|
||||
#
|
||||
# agent_config = AgentConfig(
|
||||
# persona=config.persona,
|
||||
# human=config.human,
|
||||
# model=config.model,
|
||||
# )
|
||||
#
|
||||
# conn = StorageConnector.get_recall_storage_connector(agent_config)
|
||||
#
|
||||
# # construct recall memory messages
|
||||
# message1 = Message(
|
||||
# agent_id=agent_config.name,
|
||||
# role="agent",
|
||||
# text="This is a test message",
|
||||
# user_id=config.anon_clientid,
|
||||
# model=agent_config.model,
|
||||
# created_at=datetime.now(),
|
||||
# )
|
||||
# message2 = Message(
|
||||
# agent_id=agent_config.name,
|
||||
# role="user",
|
||||
# text="This is a test message",
|
||||
# user_id=config.anon_clientid,
|
||||
# model=agent_config.model,
|
||||
# created_at=datetime.now(),
|
||||
# )
|
||||
# print(vars(message1))
|
||||
#
|
||||
# # test insert
|
||||
# conn.insert(message1)
|
||||
# conn.insert_many([message2])
|
||||
#
|
||||
# # test size
|
||||
# assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}"
|
||||
# assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}'
|
||||
#
|
||||
# # test text query
|
||||
# res = conn.query_text("test")
|
||||
# print(res)
|
||||
# assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
|
||||
#
|
||||
# # test date query
|
||||
# current_time = datetime.now()
|
||||
# one_week_ago = current_time - timedelta(weeks=1)
|
||||
# res = conn.query_date(start_date=one_week_ago, end_date=current_time)
|
||||
# print(res)
|
||||
# assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
|
||||
#
|
||||
# print(conn.get_all())
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
|
||||
# def test_postgres_openai():
|
||||
# if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
# return # soft pass
|
||||
# if not os.getenv("OPENAI_API_KEY"):
|
||||
# return # soft pass
|
||||
#
|
||||
# # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
# config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
|
||||
# print(config.config_path)
|
||||
# assert config.archival_storage_uri is not None
|
||||
# config.archival_storage_uri = config.archival_storage_uri.replace(
|
||||
# "postgres://", "postgresql://"
|
||||
# ) # https://stackoverflow.com/a/64698899
|
||||
# config.save()
|
||||
# print(config)
|
||||
#
|
||||
# embed_model = embedding_model()
|
||||
#
|
||||
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
#
|
||||
# agent_config = AgentConfig(
|
||||
# name="test_agent",
|
||||
# persona=config.persona,
|
||||
# human=config.human,
|
||||
# model=config.model,
|
||||
# )
|
||||
#
|
||||
# db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
#
|
||||
# # db.delete()
|
||||
# # return
|
||||
# for passage in passage:
|
||||
# db.insert(
|
||||
# Passage(
|
||||
# text=passage,
|
||||
# embedding=embed_model.get_text_embedding(passage),
|
||||
# user_id=config.anon_clientid,
|
||||
# agent_id="test_agent",
|
||||
# data_source="test",
|
||||
# metadata={"test_metadata_key": "test_metadata_value"},
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# print(db.get_all())
|
||||
#
|
||||
# query = "why was she crying"
|
||||
# query_vec = embed_model.get_text_embedding(query)
|
||||
# res = db.query(None, query_vec, top_k=2)
|
||||
#
|
||||
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
#
|
||||
# # TODO fix (causes a hang for some reason)
|
||||
# # print("deleting...")
|
||||
# # db.delete()
|
||||
# # print("...finished")
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
|
||||
# def test_chroma_openai():
|
||||
# if not os.getenv("OPENAI_API_KEY"):
|
||||
# return # soft pass
|
||||
#
|
||||
# config = MemGPTConfig(
|
||||
# archival_storage_type="chroma",
|
||||
# archival_storage_path="./test_chroma",
|
||||
# embedding_endpoint_type="openai",
|
||||
# embedding_dim=1536,
|
||||
# model="gpt4",
|
||||
# model_endpoint_type="openai",
|
||||
# model_endpoint="https://api.openai.com/v1",
|
||||
# )
|
||||
# config.save()
|
||||
# embed_model = embedding_model()
|
||||
#
|
||||
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
#
|
||||
# db = ChromaStorageConnector(name="test-openai")
|
||||
#
|
||||
# for passage in passage:
|
||||
# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
||||
#
|
||||
# query = "why was she crying"
|
||||
# query_vec = embed_model.get_text_embedding(query)
|
||||
# res = db.query(query, query_vec, top_k=2)
|
||||
#
|
||||
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
#
|
||||
# print(res[0].text)
|
||||
#
|
||||
# print("deleting")
|
||||
# db.delete()
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skipif(
|
||||
# not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key"
|
||||
# )
|
||||
# def test_lancedb_openai():
|
||||
# assert os.getenv("LANCEDB_TEST_URL") is not None
|
||||
# if os.getenv("OPENAI_API_KEY") is None:
|
||||
# return # soft pass
|
||||
#
|
||||
# config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL"))
|
||||
# print(config.config_path)
|
||||
# assert config.archival_storage_uri is not None
|
||||
# print(config)
|
||||
#
|
||||
# embed_model = embedding_model()
|
||||
#
|
||||
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
#
|
||||
# db = LanceDBConnector(name="test-openai")
|
||||
#
|
||||
# for passage in passage:
|
||||
# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
||||
#
|
||||
# print(db.get_all())
|
||||
#
|
||||
# query = "why was she crying"
|
||||
# query_vec = embed_model.get_text_embedding(query)
|
||||
# res = db.query(None, query_vec, top_k=2)
|
||||
#
|
||||
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
|
||||
# def test_postgres_local():
|
||||
# if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
# return
|
||||
# # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
#
|
||||
# config = MemGPTConfig(
|
||||
# archival_storage_type="postgres",
|
||||
# archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
# embedding_endpoint_type="local",
|
||||
# embedding_dim=384, # use HF model
|
||||
# )
|
||||
# print(config.config_path)
|
||||
# assert config.archival_storage_uri is not None
|
||||
# config.archival_storage_uri = config.archival_storage_uri.replace(
|
||||
# "postgres://", "postgresql://"
|
||||
# ) # https://stackoverflow.com/a/64698899
|
||||
# config.save()
|
||||
# print(config)
|
||||
#
|
||||
# embed_model = embedding_model()
|
||||
#
|
||||
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
#
|
||||
# db = PostgresStorageConnector(name="test-local")
|
||||
#
|
||||
# for passage in passage:
|
||||
# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
||||
#
|
||||
# print(db.get_all())
|
||||
#
|
||||
# query = "why was she crying"
|
||||
# query_vec = embed_model.get_text_embedding(query)
|
||||
# res = db.query(None, query_vec, top_k=2)
|
||||
#
|
||||
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
#
|
||||
# # TODO fix (causes a hang for some reason)
|
||||
# # print("deleting...")
|
||||
# # db.delete()
|
||||
# # print("...finished")
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI")
|
||||
# def test_lancedb_local():
|
||||
# assert os.getenv("LANCEDB_TEST_URL") is not None
|
||||
#
|
||||
# config = MemGPTConfig(
|
||||
# archival_storage_type="lancedb",
|
||||
# archival_storage_uri=os.getenv("LANCEDB_TEST_URL"),
|
||||
# embedding_model="local",
|
||||
# embedding_dim=384, # use HF model
|
||||
# )
|
||||
# print(config.config_path)
|
||||
# assert config.archival_storage_uri is not None
|
||||
#
|
||||
# embed_model = embedding_model()
|
||||
#
|
||||
# passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
#
|
||||
# db = LanceDBConnector(name="test-local")
|
||||
#
|
||||
# for passage in passage:
|
||||
# db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
||||
#
|
||||
# print(db.get_all())
|
||||
#
|
||||
# query = "why was she crying"
|
||||
# query_vec = embed_model.get_text_embedding(query)
|
||||
# res = db.query(None, query_vec, top_k=2)
|
||||
#
|
||||
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
#
|
||||
|
||||