Bugfixes and test updates for passing tests for both postgres and chroma
This commit is contained in:
@@ -19,6 +19,8 @@ class ChromaStorageConnector(StorageConnector):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
assert table_type == TableType.ARCHIVAL_MEMORY, "Chroma only supports archival memory"
|
||||
|
||||
# create chroma client
|
||||
if config.archival_storage_path:
|
||||
self.client = chromadb.PersistentClient(config.archival_storage_path)
|
||||
@@ -34,7 +36,6 @@ class ChromaStorageConnector(StorageConnector):
|
||||
|
||||
def get_filters(self, filters: Optional[Dict] = {}):
|
||||
# get all filters for query
|
||||
print("GET FILTER", filters)
|
||||
if filters is not None:
|
||||
filter_conditions = {**self.filters, **filters}
|
||||
else:
|
||||
@@ -53,12 +54,9 @@ class ChromaStorageConnector(StorageConnector):
|
||||
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
|
||||
offset = 0
|
||||
ids, filters = self.get_filters(filters)
|
||||
print("FILTERS", filters)
|
||||
while True:
|
||||
# Retrieve a chunk of records with the given page_size
|
||||
print("querying...", self.collection.count(), "offset", offset, "page", page_size)
|
||||
results = self.collection.get(ids=ids, offset=offset, limit=page_size, include=self.include, where=filters)
|
||||
print(len(results["embeddings"]))
|
||||
|
||||
# If the chunk is empty, we've retrieved all records
|
||||
if len(results["embeddings"]) == 0:
|
||||
@@ -72,9 +70,6 @@ class ChromaStorageConnector(StorageConnector):
|
||||
|
||||
def results_to_records(self, results):
|
||||
# convert timestamps to datetime
|
||||
print("ID", results["ids"])
|
||||
print("ID TYPE", type(results["ids"][0]))
|
||||
print(uuid.UUID(results["ids"][0]))
|
||||
for metadata in results["metadatas"]:
|
||||
if "created_at" in metadata:
|
||||
metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
|
||||
@@ -126,13 +121,11 @@ class ChromaStorageConnector(StorageConnector):
|
||||
record_metadata = {}
|
||||
metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed
|
||||
metadata = {**metadata, **record_metadata} # merge with metadata
|
||||
print("m", metadata)
|
||||
metadatas.append(metadata)
|
||||
return ids, documents, embeddings, metadatas
|
||||
|
||||
def insert(self, record: Record):
|
||||
ids, documents, embeddings, metadatas = self.format_records([record])
|
||||
print("metadata", record, metadatas)
|
||||
if not any(embeddings):
|
||||
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
|
||||
else:
|
||||
@@ -149,10 +142,14 @@ class ChromaStorageConnector(StorageConnector):
|
||||
ids, filters = self.get_filters(filters)
|
||||
self.collection.delete(ids=ids, where=filters)
|
||||
|
||||
def delete_table(self):
|
||||
# drop collection
|
||||
self.client.delete_collection(self.collection.name)
|
||||
|
||||
def save(self):
|
||||
# save to persistence file (nothing needs to be done)
|
||||
printd("Saving chroma")
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
# unfortuantely, need to use pagination to get filtering
|
||||
|
||||
@@ -4,6 +4,7 @@ import psycopg
|
||||
|
||||
|
||||
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import sessionmaker, mapped_column
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
@@ -81,12 +82,18 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(String, nullable=False)
|
||||
agent_id = Column(String, nullable=False)
|
||||
|
||||
# openai info
|
||||
role = Column(String, nullable=False)
|
||||
text = Column(String, nullable=False)
|
||||
model = Column(String, nullable=False)
|
||||
user = Column(String) # optional: multi-agent only
|
||||
|
||||
# function info
|
||||
function_name = Column(String)
|
||||
function_args = Column(String)
|
||||
function_response = Column(String)
|
||||
|
||||
embedding = mapped_column(Vector(config.embedding_dim))
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
@@ -100,6 +107,7 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
role=self.role,
|
||||
user=self.user,
|
||||
text=self.text,
|
||||
model=self.model,
|
||||
function_name=self.function_name,
|
||||
@@ -118,7 +126,7 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
|
||||
class PostgresStorageConnector(StorageConnector):
|
||||
class SQLStorageConnector(StorageConnector):
|
||||
"""Storage via Postgres"""
|
||||
|
||||
# TODO: this should probably eventually be moved into a parent DB class
|
||||
@@ -127,6 +135,8 @@ class PostgresStorageConnector(StorageConnector):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# TODO: only support recall memory (need postgres for archival)
|
||||
|
||||
# get storage URI
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
self.uri = config.archival_storage_uri
|
||||
@@ -155,20 +165,20 @@ class PostgresStorageConnector(StorageConnector):
|
||||
|
||||
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
|
||||
|
||||
def get_all_paginated(self, page_size: int, filters: Optional[Dict]) -> Iterator[List[Record]]:
|
||||
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
|
||||
session = self.Session()
|
||||
offset = 0
|
||||
filters = self.get_filters(filters)
|
||||
while True:
|
||||
# Retrieve a chunk of records with the given page_size
|
||||
db_passages_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
|
||||
db_record_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
|
||||
|
||||
# If the chunk is empty, we've retrieved all records
|
||||
if not db_passages_chunk:
|
||||
if not db_record_chunk:
|
||||
break
|
||||
|
||||
# Yield a list of Record objects converted from the chunk
|
||||
yield [self.type(**p.to_dict()) for p in db_passages_chunk]
|
||||
yield [record.to_record() for record in db_record_chunk]
|
||||
|
||||
# Increment the offset to get the next chunk in the next iteration
|
||||
offset += page_size
|
||||
@@ -179,10 +189,9 @@ class PostgresStorageConnector(StorageConnector):
|
||||
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
|
||||
return [record.to_record() for record in db_records]
|
||||
|
||||
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Record]:
|
||||
def get(self, id: str) -> Optional[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
db_record = session.query(self.db_model).filter(*filters).get(id)
|
||||
db_record = session.query(self.db_model).get(id)
|
||||
if db_record is None:
|
||||
return None
|
||||
return db_record.to_record()
|
||||
@@ -209,15 +218,7 @@ class PostgresStorageConnector(StorageConnector):
|
||||
session.commit()
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
results = session.scalars(
|
||||
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
||||
).all()
|
||||
|
||||
# Convert the results into Passage objects
|
||||
records = [result.to_record() for result in results]
|
||||
return records
|
||||
raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
|
||||
|
||||
def save(self):
|
||||
return
|
||||
@@ -255,11 +256,70 @@ class PostgresStorageConnector(StorageConnector):
|
||||
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
|
||||
session = self.Session()
|
||||
filters = self.get_filters({})
|
||||
results = session.query(self.db_model).filter(*filters).filter(self.db_model.text.contains(query)).all()
|
||||
print(results)
|
||||
results = session.query(self.db_model).filter(*filters).filter(func.lower(self.db_model.text).contains(func.lower(query))).all()
|
||||
# return [self.type(**vars(result)) for result in results]
|
||||
return [result.to_record() for result in results]
|
||||
|
||||
def delete_table(self):
|
||||
session = self.Session()
|
||||
self.db_model.__table__.drop(session.bind)
|
||||
session.commit()
|
||||
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
session.query(self.db_model).filter(*filters).delete()
|
||||
session.commit()
|
||||
|
||||
|
||||
class PostgresStorageConnector(SQLStorageConnector):
|
||||
"""Storage via Postgres"""
|
||||
|
||||
# TODO: this should probably eventually be moved into a parent DB class
|
||||
|
||||
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
results = session.scalars(
|
||||
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
||||
).all()
|
||||
|
||||
# Convert the results into Passage objects
|
||||
records = [result.to_record() for result in results]
|
||||
return records
|
||||
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
session.query(self.db_model).filter(*filters).delete()
|
||||
session.commit()
|
||||
|
||||
|
||||
class PostgresStorageConnector(SQLStorageConnector):
|
||||
"""Storage via Postgres"""
|
||||
|
||||
# TODO: this should probably eventually be moved into a parent DB class
|
||||
|
||||
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
results = session.scalars(
|
||||
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
||||
).all()
|
||||
|
||||
# Convert the results into Passage objects
|
||||
records = [result.to_record() for result in results]
|
||||
return records
|
||||
|
||||
|
||||
class LanceDBConnector(StorageConnector):
|
||||
"""Storage via LanceDB"""
|
||||
|
||||
@@ -277,8 +337,6 @@ class LanceDBConnector(StorageConnector):
|
||||
else:
|
||||
raise ValueError("Must specify either agent config or name")
|
||||
|
||||
printd(f"Using table name {self.table_name}")
|
||||
|
||||
# create table
|
||||
self.uri = config.archival_storage_uri
|
||||
if config.archival_storage_uri is None:
|
||||
@@ -326,7 +384,7 @@ class LanceDBConnector(StorageConnector):
|
||||
if self.table:
|
||||
return len(self.table)
|
||||
else:
|
||||
print(f"Table with name {self.table_name} not present")
|
||||
printd(f"Table with name {self.table_name} not present")
|
||||
return 0
|
||||
|
||||
def insert(self, passage: Passage):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional, List, Iterator
|
||||
import shutil
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
@@ -181,39 +182,56 @@ class InMemoryStorageConnector(StorageConnector):
|
||||
raise ValueError(f"Table type {table_type} not supported by InMemoryStorageConnector")
|
||||
|
||||
# TODO: load if exists
|
||||
self.agent_config = agent_config
|
||||
if agent_config is None:
|
||||
# is a data source
|
||||
raise ValueError("Cannot load data source from InMemoryStorageConnector")
|
||||
else:
|
||||
directory = agent_config.save_state_dir()
|
||||
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
|
||||
if not json_files:
|
||||
print(f"/load error: no .json checkpoint files found")
|
||||
raise ValueError(f"Cannot load {agent_config.name} - no saved checkpoints found in {directory}")
|
||||
if os.path.exists(directory):
|
||||
print(f"Loading saved agent {agent_config.name} from {directory}")
|
||||
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
|
||||
if not json_files:
|
||||
print(f"/load error: no .json checkpoint files found")
|
||||
raise ValueError(f"Cannot load {agent_config.name} - no saved checkpoints found in {directory}")
|
||||
|
||||
# Sort files based on modified timestamp, with the latest file being the first.
|
||||
filename = max(json_files, key=os.path.getmtime)
|
||||
state = json.load(open(filename, "r"))
|
||||
# Sort files based on modified timestamp, with the latest file being the first.
|
||||
filename = max(json_files, key=os.path.getmtime)
|
||||
state = json.load(open(filename, "r"))
|
||||
|
||||
# load persistence manager
|
||||
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
|
||||
directory = agent_config.save_persistence_manager_dir()
|
||||
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
|
||||
with open(filename, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
self.rows = data["all_messages"]
|
||||
# load persistence manager
|
||||
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
|
||||
directory = agent_config.save_persistence_manager_dir()
|
||||
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
|
||||
with open(filename, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
self.rows = data["all_messages"]
|
||||
else:
|
||||
print(f"Creating new agent {agent_config.name}")
|
||||
self.rows = []
|
||||
|
||||
# convert to Record class
|
||||
self.rows = [self.json_to_message(m) for m in self.rows]
|
||||
|
||||
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
|
||||
raise NotImplementedError
|
||||
offset = 0
|
||||
while True:
|
||||
yield self.rows[offset : offset + page_size]
|
||||
offset += page_size
|
||||
if offset >= len(self.rows):
|
||||
break
|
||||
|
||||
def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]:
|
||||
raise NotImplementedError
|
||||
def get_all(self, limit: Optional[int] = None, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
if limit:
|
||||
return self.rows[:limit]
|
||||
return self.rows
|
||||
|
||||
def get(self, id: str) -> Record:
|
||||
raise NotImplementedError
|
||||
match_row = [row for row in self.rows if row.id == id]
|
||||
if len(match_row) == 0:
|
||||
return None
|
||||
assert len(match_row) == 1, f"Expected 1 match, got {len(match_row)} matches"
|
||||
return match_row[0]
|
||||
|
||||
def insert(self, record: Record):
|
||||
self.rows.append(record)
|
||||
@@ -284,3 +302,12 @@ class InMemoryStorageConnector(StorageConnector):
|
||||
|
||||
def query_text(self, query: str) -> List[Record]:
|
||||
return [row for row in self.rows if row.role not in ["system", "function"] and query.lower() in row.text.lower()]
|
||||
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_table(self, filters: Optional[Dict] = {}):
|
||||
if os.path.exists(self.agent_config.save_state_dir()):
|
||||
shutil.rmtree(self.agent_config.save_state_dir())
|
||||
if os.path.exists(self.agent_config.save_persistence_manager_dir()):
|
||||
shutil.rmtree(self.agent_config.save_persistence_manager_dir())
|
||||
|
||||
@@ -117,6 +117,11 @@ class StorageConnector:
|
||||
|
||||
return LanceDBConnector(agent_config=agent_config, table_type=table_type)
|
||||
|
||||
elif storage_type == "local":
|
||||
from memgpt.connectors.local import InMemoryStorageConnector
|
||||
|
||||
return InMemoryStorageConnector(agent_config=agent_config, table_type=table_type)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
|
||||
@@ -134,6 +139,8 @@ class StorageConnector:
|
||||
if storage_type is None:
|
||||
storage_type = MemGPTConfig.load().archival_storage_type
|
||||
|
||||
return
|
||||
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import VectorIndexStorageConnector
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import uuid
|
||||
import subprocess
|
||||
import sys
|
||||
import pytest
|
||||
@@ -11,7 +12,7 @@ import pytest
|
||||
import pgvector # Try to import again after installing
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
|
||||
from memgpt.connectors.db import SQLStorageConnector, LanceDBConnector
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.data_types import Message, Passage
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
@@ -22,13 +23,13 @@ from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMA
|
||||
import argparse
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
|
||||
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
start_date = datetime(2009, 10, 5, 18, 00)
|
||||
dates = [start_date - timedelta(weeks=1), start_date, start_date + timedelta(weeks=1)]
|
||||
dates = [start_date, start_date - timedelta(weeks=1), start_date + timedelta(weeks=1)]
|
||||
roles = ["user", "agent", "agent"]
|
||||
agent_ids = ["agent1", "agent2", "agent1"]
|
||||
ids = ["test1", "test2", "test3"] # TODO: generate unique uuid
|
||||
ids = [uuid.uuid4(), uuid.uuid4(), uuid.uuid4()]
|
||||
user_id = "test_user"
|
||||
|
||||
|
||||
@@ -41,16 +42,7 @@ def generate_passages(embed_model):
|
||||
embedding = None
|
||||
if embed_model:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
passages.append(
|
||||
Passage(
|
||||
user_id=user_id,
|
||||
text=text,
|
||||
agent_id=agent_id,
|
||||
embedding=embedding,
|
||||
data_source="test_source",
|
||||
id=id,
|
||||
)
|
||||
)
|
||||
passages.append(Passage(user_id=user_id, text=text, agent_id=agent_id, embedding=embedding, data_source="test_source", id=id))
|
||||
return passages
|
||||
|
||||
|
||||
@@ -65,7 +57,8 @@ def generate_messages():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"])
|
||||
@pytest.mark.parametrize("table_type", [TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY])
|
||||
# @pytest.mark.parametrize("storage_connector", ["postgres"])
|
||||
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
|
||||
def test_storage(storage_connector, table_type):
|
||||
|
||||
# setup memgpt config
|
||||
@@ -88,10 +81,16 @@ def test_storage(storage_connector, table_type):
|
||||
config.archival_storage_type = "lancedb"
|
||||
config.recall_storage_type = "lancedb"
|
||||
if storage_connector == "chroma":
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
print("Skipping test, chroma only supported for archival memory")
|
||||
return
|
||||
config.archival_storage_type = "chroma"
|
||||
config.recall_storage_type = "chroma"
|
||||
config.recall_storage_path = "./test_chroma"
|
||||
config.archival_storage_path = "./test_chroma"
|
||||
if storage_connector == "local":
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
print("Skipping test, local only supported for recall memory")
|
||||
return
|
||||
config.recall_storage_type = "local"
|
||||
|
||||
# get embedding model
|
||||
embed_model = None
|
||||
@@ -116,7 +115,8 @@ def test_storage(storage_connector, table_type):
|
||||
|
||||
# create storage connector
|
||||
conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
|
||||
conn.delete() # clear out data
|
||||
# conn.client.delete_collection(conn.collection.name) # clear out data
|
||||
conn.delete_table()
|
||||
conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
|
||||
|
||||
# override filters
|
||||
@@ -161,6 +161,7 @@ def test_storage(storage_connector, table_type):
|
||||
assert len(all_records) == 1, f"Expected 1 records, got {len(all_records)}"
|
||||
|
||||
# test: get
|
||||
print("GET ID", ids[0], records)
|
||||
res = conn.get(id=ids[0])
|
||||
assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}"
|
||||
|
||||
@@ -178,8 +179,8 @@ def test_storage(storage_connector, table_type):
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
# test optional query functions
|
||||
if storage_connector != "chroma":
|
||||
# test optional query functions for recall memory
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
# test: query_text
|
||||
query = "CindereLLa"
|
||||
res = conn.query_text(query)
|
||||
@@ -187,12 +188,13 @@ def test_storage(storage_connector, table_type):
|
||||
assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
|
||||
|
||||
# test: query_date (recall memory only)
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
print("Testing recall memory date search")
|
||||
start_date = start_date - timedelta(days=1)
|
||||
end_date = start_date + timedelta(days=1)
|
||||
res = conn.query_date(start_date=start_date, end_date=end_date)
|
||||
assert len(res) == 1, f"Expected 1 result, got {len(res): {res}}"
|
||||
print("Testing recall memory date search")
|
||||
start_date = datetime(2009, 10, 5, 18, 00)
|
||||
start_date = start_date - timedelta(days=1)
|
||||
end_date = start_date + timedelta(days=1)
|
||||
res = conn.query_date(start_date=start_date, end_date=end_date)
|
||||
print("DATE", res)
|
||||
assert len(res) == 1, f"Expected 1 result, got {len(res)}: {res}"
|
||||
|
||||
# test: delete
|
||||
conn.delete({"id": ids[0]})
|
||||
|
||||
Reference in New Issue
Block a user