diff --git a/docs/storage.md b/docs/storage.md index ea143e46..72bfbac8 100644 --- a/docs/storage.md +++ b/docs/storage.md @@ -18,5 +18,22 @@ pip install 'pymemgpt[postgres]' You will need to have a URI to a Postgres database which support [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation). +## LanceDB +In order to use the LanceDB backend: + + You have to enable the LanceDB backend by running + + ``` + memgpt configure + ``` + and selecting `lancedb` for archival storage, and a database URI (e.g. `./.lancedb`). An empty archival URI is also handled, and the default URI is set to `./.lancedb`. + +To enable the LanceDB backend, make sure to install the required dependencies with: +``` +pip install 'pymemgpt[lancedb]' +``` +For more information, check out the [LanceDB docs](https://lancedb.github.io/lancedb/) + + ## Chroma (Coming soon) diff --git a/memgpt/autogen/examples/memgpt_coder_autogen.ipynb b/memgpt/autogen/examples/memgpt_coder_autogen.ipynb index 6719b739..28b174f4 100644 --- a/memgpt/autogen/examples/memgpt_coder_autogen.ipynb +++ b/memgpt/autogen/examples/memgpt_coder_autogen.ipynb @@ -38,7 +38,8 @@ "outputs": [], "source": [ "import openai\n", - "openai.api_key=\"YOUR_API_KEY\"" + "\n", + "openai.api_key = \"YOUR_API_KEY\"" ] }, { diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index c57a4002..6c05bf84 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -210,7 +210,7 @@ def configure_cli(config: MemGPTConfig): def configure_archival_storage(config: MemGPTConfig): # Configure archival storage backend - archival_storage_options = ["local", "postgres"] + archival_storage_options = ["local", "lancedb", "postgres"] archival_storage_type = questionary.select( "Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type ).ask() @@ -220,8 +220,17 @@
def configure_archival_storage(config: MemGPTConfig): "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):", default=config.archival_storage_uri if config.archival_storage_uri else "", ).ask() + + if archival_storage_type == "lancedb": + archival_storage_uri = questionary.text( + "Enter lancedb connection string (e.g. ./.lancedb):", + default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb", + ).ask() + return archival_storage_type, archival_storage_uri + # TODO: allow configuring embedding model + @app.command() def configure(): diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 9ae873b8..6d2bb3e2 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -13,6 +13,7 @@ from tqdm import tqdm from typing import Optional, List, Iterator import numpy as np from tqdm import tqdm +import pandas as pd from memgpt.config import MemGPTConfig from memgpt.connectors.storage import StorageConnector, Passage @@ -181,3 +182,139 @@ class PostgresStorageConnector(StorageConnector): def generate_table_name(self, name: str): return f"memgpt_{self.sanitize_table_name(name)}" + + +class LanceDBConnector(StorageConnector): + """Storage via LanceDB""" + + # TODO: this should probably eventually be moved into a parent DB class + + def __init__(self, name: Optional[str] = None): + config = MemGPTConfig.load() + + # determine table name + if name: + self.table_name = self.generate_table_name(name) + else: + self.table_name = "lancedb_tbl" + + printd(f"Using table name {self.table_name}") + + # create table + self.uri = config.archival_storage_uri + if config.archival_storage_uri is None: + raise ValueError(f"Must specify archival_storage_uri in config {config.config_path}") + import lancedb + + self.db = lancedb.connect(self.uri) + self.table = None + + def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]: + # NOTE(review): offset is tracked below but never applied to the query, so each iteration re-reads the first page + offset = 0 + while True: + # 
Retrieve a chunk of records with the given page_size + db_passages_chunk = self.table.search().limit(page_size).to_list() + + # If the chunk is empty, we've retrieved all records + if not db_passages_chunk: + break + + # Yield a list of Passage objects converted from the chunk + yield [ + Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in db_passages_chunk + ] + + # Increment the offset to get the next chunk in the next iteration + offset += page_size + + def get_all(self, limit=10) -> List[Passage]: + db_passages = self.table.search().limit(limit).to_list() + return [Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in db_passages] + + def get(self, id: str) -> Optional[Passage]: + db_passage = self.table.search().where(f"passage_id={id}").to_list() + if len(db_passage) == 0: + return None + return Passage( + text=db_passage[0]["text"], embedding=db_passage[0]["vector"], doc_id=db_passage[0]["doc_id"], passage_id=db_passage[0]["passage_id"] + ) + + def size(self) -> int: + # return size of table + if self.table: + return len(self.table.search().to_list()) + else: + print(f"Table with name {self.table_name} not present") + return 0 + + def insert(self, passage: Passage): + data = [{"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}] + + if self.table: + self.table.add(data) + else: + self.table = self.db.create_table(self.table_name, data=data, mode="overwrite") + + def insert_many(self, passages: List[Passage], show_progress=True): + data = [] + iterable = tqdm(passages) if show_progress else passages + for passage in iterable: + temp_dict = {"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding} + data.append(temp_dict) + + if self.table: + self.table.add(data) + else: + self.table = self.db.create_table(self.table_name, data=data, mode="overwrite") + + def 
query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]: + # Assuming query_vec is of same length as embeddings inside table + results = self.table.search(query_vec).limit(top_k).to_list() + + # Convert the results into Passage objects + passages = [ + Passage(text=result["text"], embedding=result["vector"], doc_id=result["doc_id"], passage_id=result["passage_id"]) + for result in results + ] + return passages + + def delete(self): + """Drop the passage table from the database.""" + # Drop the table specified by the PassageModel class + self.db.drop_table(self.table_name) + + def save(self): + return + + @staticmethod + def list_loaded_data(): + config = MemGPTConfig.load() + import lancedb + + db = lancedb.connect(config.archival_storage_uri) + + tables = db.table_names() + tables = [table for table in tables if table.startswith("memgpt_")] + tables = [table.replace("memgpt_", "") for table in tables] + return tables + + def sanitize_table_name(self, name: str) -> str: + # Remove leading and trailing whitespace + name = name.strip() + + # Replace spaces and invalid characters with underscores + name = re.sub(r"\s+|\W+", "_", name) + + # Truncate to the maximum identifier length + max_length = 63 + if len(name) > max_length: + name = name[:max_length].rstrip("_") + + # Convert to lowercase + name = name.lower() + + return name + + def generate_table_name(self, name: str): + return f"memgpt_{self.sanitize_table_name(name)}" diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py index 21a69cb8..0852f79a 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/connectors/storage.py @@ -48,6 +48,11 @@ class StorageConnector: return PostgresStorageConnector(name=name, agent_config=agent_config) + elif storage_type == "lancedb": + from memgpt.connectors.db import LanceDBConnector + + return LanceDBConnector(name=name) + else: raise NotImplementedError(f"Storage type {storage_type} not implemented") @@ -62,6 +67,11 @@ class 
StorageConnector: from memgpt.connectors.db import PostgresStorageConnector return PostgresStorageConnector.list_loaded_data() + + elif storage_type == "lancedb": + from memgpt.connectors.db import LanceDBConnector + + return LanceDBConnector.list_loaded_data() else: raise NotImplementedError(f"Storage type {storage_type} not implemented") diff --git a/poetry.lock b/poetry.lock index f48f7354..9e2e3997 100644 --- a/poetry.lock +++ b/poetry.lock @@ -250,6 +250,17 @@ d = ["aiohttp (>=3.7.4)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "cachetools" +version = "5.3.2" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"}, + {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, +] + [[package]] name = "certifi" version = "2023.7.22" @@ -482,6 +493,17 @@ tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elast torch = ["torch"] vision = ["Pillow (>=6.2.1)"] +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "demjson3" version = "3.0.6" @@ -509,6 +531,20 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +optional = false +python-versions = "*" +files = [ + {file = 
"deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + [[package]] name = "dill" version = "0.3.7" @@ -910,6 +946,39 @@ files = [ {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] +[[package]] +name = "lancedb" +version = "0.3.3" +description = "lancedb" +optional = false +python-versions = ">=3.8" +files = [ + {file = "lancedb-0.3.3-py3-none-any.whl", hash = "sha256:67ccea22a6cb39c688041f7469be778a2e64b141db80866f6f0dec25a3122aff"}, + {file = "lancedb-0.3.3.tar.gz", hash = "sha256:8d8a9c2b107154ee57f6f75957d215719a204cd64c9efbe7095eaf41b43c2a29"}, +] + +[package.dependencies] +aiohttp = "*" +attrs = ">=21.3.0" +cachetools = "*" +click = ">=8.1.7" +deprecation = "*" +pydantic = ">=1.10" +pylance = "0.8.10" +pyyaml = ">=6.0" +ratelimiter = ">=1.0,<2.0" +requests = ">=2.31.0" +retry = ">=0.9.2" +semver = ">=3.0" +tqdm = ">=4.1.0" + +[package.extras] +clip = ["open-clip", "pillow", "torch"] +dev = ["black", "pre-commit", "ruff"] +docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] +embeddings = ["cohere", "open-clip-torch", "openai", "pillow", "sentence-transformers", "torch"] +tests = ["pandas (>=1.4)", "pytest", "pytest-asyncio", "pytest-mock", "requests"] + [[package]] name = "langchain" version = "0.0.333" @@ -1936,11 +2005,22 @@ files = [ {file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"}, ] +[[package]] +name = "py" +version = "1.11.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = 
"py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, + {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, +] + [[package]] name = "pyarrow" version = "14.0.1" description = "Python library for Apache Arrow" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "pyarrow-14.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:96d64e5ba7dceb519a955e5eeb5c9adcfd63f73a56aea4722e2cc81364fc567a"}, @@ -2135,6 +2215,28 @@ files = [ [package.extras] plugins = ["importlib-metadata"] +[[package]] +name = "pylance" +version = "0.8.10" +description = "python wrapper for Lance columnar format" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pylance-0.8.10-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:aecf053e12f13a1810a70c786c1e73bcf3ffe7287c0bfe2cc5df77a91f0a084c"}, + {file = "pylance-0.8.10-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:b778fbcfae2e9186053292b7bd3fcd28efc92bd0471f733f8dbf4a1f840c9ce4"}, + {file = "pylance-0.8.10-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ea617723593d4cc0d2faaaf4a861e31ae3c8657517b83e2fb99e5f68c0c1481"}, + {file = "pylance-0.8.10-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b1bf9cc33d7095196931f96588d733e80d69a4c312b5352d9dab9a0d5a84c8f"}, + {file = "pylance-0.8.10-cp38-abi3-win_amd64.whl", hash = "sha256:d700f874710c6f1a5567c6e4f98426c22aebf9937dcdbd7305573f519712b683"}, +] + +[package.dependencies] +numpy = ">=1.22" +pyarrow = ">=10" + +[package.extras] +benchmarks = ["pytest-benchmark"] +tests = ["duckdb", "ml_dtypes", "pandas (>=1.4,<2.1)", "polars[pandas,pyarrow]", "pytest", "semver", "tensorflow", "tqdm"] + [[package]] name = "pymupdf" version = "1.23.6" @@ -2328,6 +2430,20 @@ files = [ [package.dependencies] prompt_toolkit = ">=2.0,<=3.0.36" +[[package]] +name = "ratelimiter" +version = "1.2.0.post0" 
+description = "Simple python rate limiting object" +optional = false +python-versions = "*" +files = [ + {file = "ratelimiter-1.2.0.post0-py3-none-any.whl", hash = "sha256:a52be07bc0bb0b3674b4b304550f10c769bbb00fead3072e035904474259809f"}, + {file = "ratelimiter-1.2.0.post0.tar.gz", hash = "sha256:5c395dcabdbbde2e5178ef3f89b568a3066454a6ddc223b76473dac22f89b4f7"}, +] + +[package.extras] +test = ["pytest (>=3.0)", "pytest-asyncio"] + [[package]] name = "regex" version = "2023.10.3" @@ -2446,6 +2562,21 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "retry" +version = "0.9.2" +description = "Easy to use retry decorator." +optional = false +python-versions = "*" +files = [ + {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"}, + {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"}, +] + +[package.dependencies] +decorator = ">=3.4.2" +py = ">=1.4.26,<2.0.0" + [[package]] name = "rich" version = "13.6.0" @@ -2597,6 +2728,17 @@ files = [ [package.dependencies] asn1crypto = ">=1.5.1" +[[package]] +name = "semver" +version = "3.0.2" +description = "Python helper for Semantic Versioning (https://semver.org)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "semver-3.0.2-py3-none-any.whl", hash = "sha256:b1ea4686fe70b981f85359eda33199d60c53964284e0cfb4977d243e37cf4bf4"}, + {file = "semver-3.0.2.tar.gz", hash = "sha256:6253adb39c70f6e51afed2fa7152bcd414c411286088fb4b9effb133885ab4cc"}, +] + [[package]] name = "setuptools" version = "68.2.2" @@ -3601,6 +3743,7 @@ multidict = ">=4.0" [extras] dev = ["black", "datasets", "pre-commit", "pytest"] +lancedb = [] legacy = ["faiss-cpu", "numpy"] local = ["huggingface-hub", "torch", "transformers"] postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"] @@ -3608,4 +3751,4 @@ 
postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary" [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.9" -content-hash = "0fa0b65ce00550c139abcf5b4134e9e5b19b277930782ffe8421afec9d2743e2" +content-hash = "130c4da6c4b59aeb80aecf9549f75bed28123c275e30f159232e491d726034d5" diff --git a/pyproject.toml b/pyproject.toml index da735001..803a569d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,10 +47,12 @@ pg8000 = {version = "^1.30.3", optional = true} torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true} websockets = "^12.0" docstring-parser = "^0.15" +lancedb = {version = "^0.3.3", optional = true} [tool.poetry.extras] legacy = ["faiss-cpu", "numpy"] local = ["torch", "huggingface-hub", "transformers"] +lancedb = ["lancedb"] postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"] dev = ["pytest", "black", "pre-commit", "datasets"] diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index f58bc0c7..a303279f 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -46,6 +46,39 @@ def test_postgres(): ) +def test_lancedb(): + return + + subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"]) + import lancedb # Try to import again after installing + + # override config path with enviornment variable + # TODO: make into temporary file + os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg" + print("env", os.getenv("MEMGPT_CONFIG_PATH")) + config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb", config_path=os.getenv("MEMGPT_CONFIG_PATH")) + print(config) + config.save() + + # loading dataset from hugging face + name = "tmp_hf_dataset" + + dataset = load_dataset("MemGPT/example_short_stories") + + cache_dir = os.getenv("HF_DATASETS_CACHE") + if cache_dir is None: + # Construct the default path if the environment variable is not set. 
+ cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") + + config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb") + + load_directory( + name=name, + input_dir=cache_dir, + recursive=True, + ) + + def test_chroma(): return diff --git a/tests/test_storage.py b/tests/test_storage.py index efecfedc..3ace6650 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -6,10 +6,12 @@ import pytest subprocess.check_call( [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"] ) # , "psycopg_binary"]) # "psycopg", "libpq-dev"]) + +subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"]) import pgvector # Try to import again after installing from memgpt.connectors.storage import StorageConnector, Passage -from memgpt.connectors.db import PostgresStorageConnector +from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector from memgpt.embeddings import embedding_model from memgpt.config import MemGPTConfig, AgentConfig @@ -57,6 +59,38 @@ def test_postgres_openai(): # print("...finished") +@pytest.mark.skipif( + not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key" +) +def test_lancedb_openai(): + assert os.getenv("LANCEDB_TEST_URL") is not None + if os.getenv("OPENAI_API_KEY") is None: + return # soft pass + + config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL")) + print(config.config_path) + assert config.archival_storage_uri is not None + print(config) + + embed_model = embedding_model() + + passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] + + db = LanceDBConnector(name="test-openai") + + for passage in passage: + db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) + + print(db.get_all()) + + query = "why was she crying" + query_vec = embed_model.get_text_embedding(query) 
+ res = db.query(None, query_vec, top_k=2) + + assert len(res) == 2, f"Expected 2 results, got {len(res)}" + assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" + + @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI") def test_postgres_local(): if not os.getenv("PGVECTOR_TEST_DB_URL"): @@ -101,4 +135,33 @@ def test_postgres_local(): # print("...finished") -# test_postgres() +@pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI") +def test_lancedb_local(): + assert os.getenv("LANCEDB_TEST_URL") is not None + + config = MemGPTConfig( + archival_storage_type="lancedb", + archival_storage_uri=os.getenv("LANCEDB_TEST_URL"), + embedding_model="local", + embedding_dim=384, # use HF model + ) + print(config.config_path) + assert config.archival_storage_uri is not None + + embed_model = embedding_model() + + passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] + + db = LanceDBConnector(name="test-local") + + for passage in passage: + db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) + + print(db.get_all()) + + query = "why was she crying" + query_vec = embed_model.get_text_embedding(query) + res = db.query(None, query_vec, top_k=2) + + assert len(res) == 2, f"Expected 2 results, got {len(res)}" + assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"