Files
letta-server/tests/test_storage.py
Sarah Wooders 8ae1e64987 chore: migrate package name to letta (#1775)
Co-authored-by: Charles Packer <packercharles@gmail.com>
Co-authored-by: Shubham Naik <shubham.naik10@gmail.com>
Co-authored-by: Shubham Naik <shub@memgpt.ai>
2024-09-23 09:15:18 -07:00

319 lines
12 KiB
Python

# TODO: add back post DB refactor
# import os
# import uuid
# from datetime import datetime, timedelta
#
# import pytest
# from sqlalchemy.ext.declarative import declarative_base
#
# from letta.agent_store.storage import StorageConnector, TableType
# from letta.config import LettaConfig
# from letta.constants import BASE_TOOLS, MAX_EMBEDDING_DIM
# from letta.credentials import LettaCredentials
# from letta.embeddings import embedding_model, query_embedding
# from letta.metadata import MetadataStore
# from letta.settings import settings
# from tests import TEST_MEMGPT_CONFIG
# from tests.utils import create_config, wipe_config
#
# from .utils import with_qdrant_storage
#
# from letta.schemas.agent import AgentState
# from letta.schemas.message import Message
# from letta.schemas.passage import Passage
# from letta.schemas.user import User
#
#
## Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
# texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
# start_date = datetime(2009, 10, 5, 18, 00)
# dates = [start_date, start_date - timedelta(weeks=1), start_date + timedelta(weeks=1)]
# roles = ["user", "assistant", "assistant"]
# agent_1_id = uuid.uuid4()
# agent_2_id = uuid.uuid4()
# agent_ids = [agent_1_id, agent_2_id, agent_1_id]
# ids = [uuid.uuid4(), uuid.uuid4(), uuid.uuid4()]
# user_id = uuid.uuid4()
#
#
## Data generation functions: Passages
# def generate_passages(embed_model):
# """Generate list of 3 Passage objects"""
# # embeddings: use openai if env is set, otherwise local
# passages = []
# for text, _, _, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
# embedding, embedding_model, embedding_dim = None, None, None
# if embed_model:
# embedding = embed_model.get_text_embedding(text)
# embedding_model = "gpt-4"
# embedding_dim = len(embedding)
# passages.append(
# Passage(
# user_id=user_id,
# text=text,
# agent_id=agent_id,
# embedding=embedding,
# data_source="test_source",
# id=id,
# embedding_dim=embedding_dim,
# embedding_model=embedding_model,
# )
# )
# return passages
#
#
## Data generation functions: Messages
# def generate_messages(embed_model):
# """Generate list of 3 Message objects"""
# messages = []
# for text, date, role, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
# embedding, embedding_model, embedding_dim = None, None, None
# if embed_model:
# embedding = embed_model.get_text_embedding(text)
# embedding_model = "gpt-4"
# embedding_dim = len(embedding)
# messages.append(
# Message(
# user_id=user_id,
# text=text,
# agent_id=agent_id,
# role=role,
# created_at=date,
# id=id,
# model="gpt-4",
# embedding=embedding,
# embedding_model=embedding_model,
# embedding_dim=embedding_dim,
# )
# )
# print(messages[-1].text)
# return messages
#
#
# @pytest.fixture(autouse=True)
# def clear_dynamically_created_models():
# """Wipe globals for SQLAlchemy"""
# yield
# for key in list(globals().keys()):
# if key.endswith("Model"):
# del globals()[key]
#
#
# @pytest.fixture(autouse=True)
# def recreate_declarative_base():
# """Recreate the declarative base before each test"""
# global Base
# Base = declarative_base()
# yield
# Base.metadata.clear()
#
#
# @pytest.mark.parametrize("storage_connector", with_qdrant_storage(["postgres", "chroma", "sqlite", "milvus"]))
## @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"])
## @pytest.mark.parametrize("storage_connector", ["postgres"])
# @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
# def test_storage(
# storage_connector,
# table_type,
# clear_dynamically_created_models,
# recreate_declarative_base,
# ):
# # setup letta config
# # TODO: set env for different config path
#
# # hacky way to cleanup globals that screw up tests
# # for table_name in ['Message']:
# # if 'Message' in globals():
# # print("Removing messages", globals()['Message'])
# # del globals()['Message']
#
# wipe_config()
# if os.getenv("OPENAI_API_KEY"):
# create_config("openai")
# credentials = LettaCredentials(
# openai_key=os.getenv("OPENAI_API_KEY"),
# )
# else: # hosted
# create_config("letta_hosted")
# LettaCredentials()
#
# config = LettaConfig.load()
# TEST_MEMGPT_CONFIG.default_embedding_config = config.default_embedding_config
# TEST_MEMGPT_CONFIG.default_llm_config = config.default_llm_config
#
# if storage_connector == "postgres":
# TEST_MEMGPT_CONFIG.archival_storage_uri = settings.letta_pg_uri
# TEST_MEMGPT_CONFIG.recall_storage_uri = settings.letta_pg_uri
# TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
# TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
# if storage_connector == "lancedb":
# # TODO: complete lancedb implementation
# if not os.getenv("LANCEDB_TEST_URL"):
# print("Skipping test, missing LanceDB URI")
# return
# TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["LANCEDB_TEST_URL"]
# TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["LANCEDB_TEST_URL"]
# TEST_MEMGPT_CONFIG.archival_storage_type = "lancedb"
# TEST_MEMGPT_CONFIG.recall_storage_type = "lancedb"
# if storage_connector == "chroma":
# if table_type == TableType.RECALL_MEMORY:
# print("Skipping test, chroma only supported for archival memory")
# return
# TEST_MEMGPT_CONFIG.archival_storage_type = "chroma"
# TEST_MEMGPT_CONFIG.archival_storage_path = "./test_chroma"
# if storage_connector == "sqlite":
# if table_type == TableType.ARCHIVAL_MEMORY:
# print("Skipping test, sqlite only supported for recall memory")
# return
# TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite"
# if storage_connector == "qdrant":
# if table_type == TableType.RECALL_MEMORY:
# print("Skipping test, Qdrant only supports archival memory")
# return
# TEST_MEMGPT_CONFIG.archival_storage_type = "qdrant"
# TEST_MEMGPT_CONFIG.archival_storage_uri = "localhost:6333"
# if storage_connector == "milvus":
# if table_type == TableType.RECALL_MEMORY:
# print("Skipping test, Milvus only supports archival memory")
# return
# TEST_MEMGPT_CONFIG.archival_storage_type = "milvus"
# TEST_MEMGPT_CONFIG.archival_storage_uri = "./milvus.db"
# # get embedding model
# embedding_config = TEST_MEMGPT_CONFIG.default_embedding_config
# embed_model = embedding_model(TEST_MEMGPT_CONFIG.default_embedding_config)
#
# # create user
# ms = MetadataStore(TEST_MEMGPT_CONFIG)
# ms.delete_user(user_id)
# user = User(id=user_id)
# agent = AgentState(
# user_id=user_id,
# name="agent_1",
# id=agent_1_id,
# llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
# embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
# system="",
# tools=BASE_TOOLS,
# state={
# "persona": "",
# "human": "",
# "messages": None,
# },
# )
# ms.create_user(user)
# ms.create_agent(agent)
#
# # create storage connector
# conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
# # conn.client.delete_collection(conn.collection.name) # clear out data
# conn.delete_table()
# conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
#
# # generate data
# if table_type == TableType.ARCHIVAL_MEMORY:
# records = generate_passages(embed_model)
# elif table_type == TableType.RECALL_MEMORY:
# records = generate_messages(embed_model)
# else:
# raise NotImplementedError(f"Table type {table_type} not implemented")
#
# # check record dimensions
# print("TABLE TYPE", table_type, type(records[0]), len(records[0].embedding))
# if embed_model:
# assert len(records[0].embedding) == MAX_EMBEDDING_DIM, f"Expected {MAX_EMBEDDING_DIM}, got {len(records[0].embedding)}"
# assert (
# records[0].embedding_dim == embedding_config.embedding_dim
# ), f"Expected {embedding_config.embedding_dim}, got {records[0].embedding_dim}"
#
# # test: insert
# conn.insert(records[0])
# assert conn.size() == 1, f"Expected 1 record, got {conn.size()}: {conn.get_all()}"
#
# # test: insert_many
# conn.insert_many(records[1:])
# assert (
# conn.size() == 2
# ), f"Expected 2 records, got {conn.size()}: {conn.get_all()}" # expect 2, since storage connector filters for agent1
#
# # test: update
# # NOTE: only testing with messages
# if table_type == TableType.RECALL_MEMORY:
# TEST_STRING = "hello world"
#
# updated_record = records[1]
# updated_record.text = TEST_STRING
#
# current_record = conn.get(id=updated_record.id)
# assert current_record is not None, f"Couldn't find {updated_record.id}"
# assert current_record.text != TEST_STRING, (current_record.text, TEST_STRING)
#
# conn.update(updated_record)
# new_record = conn.get(id=updated_record.id)
# assert new_record is not None, f"Couldn't find {updated_record.id}"
# assert new_record.text == TEST_STRING, (new_record.text, TEST_STRING)
#
# # test: list_loaded_data
# # TODO: add back
# # if table_type == TableType.ARCHIVAL_MEMORY:
# # sources = StorageConnector.list_loaded_data(storage_type=storage_connector)
# # assert len(sources) == 1, f"Expected 1 source, got {len(sources)}"
# # assert sources[0] == "test_source", f"Expected 'test_source', got {sources[0]}"
#
# # test: get_all_paginated
# paginated_total = 0
# for page in conn.get_all_paginated(page_size=1):
# paginated_total += len(page)
# assert paginated_total == 2, f"Expected 2 records, got {paginated_total}"
#
# # test: get_all
# all_records = conn.get_all()
# assert len(all_records) == 2, f"Expected 2 records, got {len(all_records)}"
# all_records = conn.get_all(limit=1)
# assert len(all_records) == 1, f"Expected 1 records, got {len(all_records)}"
#
# # test: get
# print("GET ID", ids[0], records)
# res = conn.get(id=ids[0])
# assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}"
#
# # test: size
# assert conn.size() == 2, f"Expected 2 records, got {conn.size()}"
# assert conn.size(filters={"agent_id": agent.id}) == 2, f"Expected 2 records, got {conn.size(filters={'agent_id': agent.id})}"
# if table_type == TableType.RECALL_MEMORY:
# assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}"
#
# # test: query (vector)
# if table_type == TableType.ARCHIVAL_MEMORY:
# query = "why was she crying"
# query_vec = query_embedding(embed_model, query)
# res = conn.query(None, query_vec, top_k=2)
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
# print("Archival memory results", res)
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
#
# # test optional query functions for recall memory
# if table_type == TableType.RECALL_MEMORY:
# # test: query_text
# query = "CindereLLa"
# res = conn.query_text(query)
# assert len(res) == 1, f"Expected 1 result, got {len(res)}"
# assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
#
# # test: query_date (recall memory only)
# print("Testing recall memory date search")
# start_date = datetime(2009, 10, 5, 18, 00)
# start_date = start_date - timedelta(days=1)
# end_date = start_date + timedelta(days=1)
# res = conn.query_date(start_date=start_date, end_date=end_date)
# print("DATE", res)
# assert len(res) == 1, f"Expected 1 result, got {len(res)}: {res}"
#
# # test: delete
# conn.delete({"id": ids[0]})
# assert conn.size() == 1, f"Expected 1 record, got {conn.size()}"
#
# # cleanup
# ms.delete_user(user_id)
#