diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 480fe138..05ba22bb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,6 +13,13 @@ jobs: test: runs-on: ubuntu-latest timeout-minutes: 15 + + services: + qdrant: + image: qdrant/qdrant + ports: + - 6333:6333 + steps: - name: Checkout uses: actions/checkout@v4 diff --git a/docs/storage.md b/docs/storage.md index cde9056c..acf5c960 100644 --- a/docs/storage.md +++ b/docs/storage.md @@ -36,19 +36,19 @@ To run the Postgres backend, you will need a URI to a Postgres database that sup 3. Configure the environment for `pgvector`. You can either: - Add the following line to your shell profile (e.g., `~/.bashrc`, `~/.zshrc`): - + ```sh export MEMGPT_PGURI=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt ``` - Or create a `.env` file in the root project directory with: - + ```sh MEMGPT_PGURI=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt ``` 4. Run the script from the root project directory: - + ```sh bash db/run_postgres.sh ``` @@ -105,6 +105,20 @@ memgpt configure and selecting `lancedb` for archival storage, and database URI (e.g. `./.lancedb`"), Empty archival uri is also handled and default uri is set at `./.lancedb`. For more checkout [lancedb docs](https://lancedb.github.io/lancedb/) +## Qdrant + +To enable the Qdrant backend, make sure to install the required dependencies with: + +```sh +pip install 'pymemgpt[qdrant]' +``` + +You can configure Qdrant with an in-memory instance or a server using the `memgpt configure` command. You can set an API key for authentication with a Qdrant server using the `QDRANT_API_KEY` environment variable. Learn more about setting up Qdrant [here](https://qdrant.tech/documentation/guides/installation/). + +```sh +? Select Qdrant backend: server +? 
Enter the Qdrant instance URI (Default: localhost:6333): localhost:6333 +``` ## Milvus diff --git a/memgpt/agent_store/qdrant.py b/memgpt/agent_store/qdrant.py new file mode 100644 index 00000000..640ad91a --- /dev/null +++ b/memgpt/agent_store/qdrant.py @@ -0,0 +1,201 @@ +import os +import uuid +from copy import deepcopy +from typing import Dict, Iterator, List, Optional, cast + +from memgpt.agent_store.storage import StorageConnector, TableType +from memgpt.config import MemGPTConfig +from memgpt.constants import MAX_EMBEDDING_DIM +from memgpt.data_types import Passage, Record, RecordType +from memgpt.utils import datetime_to_timestamp, timestamp_to_datetime + +TEXT_PAYLOAD_KEY = "text_content" +METADATA_PAYLOAD_KEY = "metadata" + + +class QdrantStorageConnector(StorageConnector): + """Storage via Qdrant""" + + def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): + super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) + try: + from qdrant_client import QdrantClient, models + except ImportError as e: + raise ImportError("'qdrant-client' not installed. 
Run `pip install qdrant-client`.") from e + assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Qdrant only supports archival memory" + if config.archival_storage_uri and len(config.archival_storage_uri.split(":")) == 2: + host, port = config.archival_storage_uri.split(":") + self.qdrant_client = QdrantClient(host=host, port=port, api_key=os.getenv("QDRANT_API_KEY")) + elif config.archival_storage_path: + self.qdrant_client = QdrantClient(path=config.archival_storage_path) + else: + raise ValueError("Qdrant storage requires either a URI or a path to the storage configured") + if not self.qdrant_client.collection_exists(self.table_name): + self.qdrant_client.create_collection( + collection_name=self.table_name, + vectors_config=models.VectorParams( + size=MAX_EMBEDDING_DIM, + distance=models.Distance.COSINE, + ), + ) + self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"] + + def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]: + from qdrant_client import grpc + + filters = self.get_qdrant_filters(filters) + next_offset = None + stop_scrolling = False + while not stop_scrolling: + results, next_offset = self.qdrant_client.scroll( + collection_name=self.table_name, + scroll_filter=filters, + limit=page_size, + offset=next_offset, + with_payload=True, + with_vectors=True, + ) + stop_scrolling = next_offset is None or ( + isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == "" + ) + yield self.to_records(results) + + def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]: + if self.size(filters) == 0: + return [] + filters = self.get_qdrant_filters(filters) + results, _ = self.qdrant_client.scroll( + self.table_name, + scroll_filter=filters, + limit=limit, + with_payload=True, + with_vectors=True, + ) + return self.to_records(results) + + def get(self, id: uuid.UUID) -> Optional[RecordType]: + results = 
self.qdrant_client.retrieve( + collection_name=self.table_name, + ids=[str(id)], + with_payload=True, + with_vectors=True, + ) + if not results: + return None + return self.to_records(results)[0] + + def insert(self, record: Record): + points = self.to_points([record]) + self.qdrant_client.upsert(self.table_name, points=points) + + def insert_many(self, records: List[RecordType], show_progress=False): + points = self.to_points(records) + self.qdrant_client.upsert(self.table_name, points=points) + + def delete(self, filters: Optional[Dict] = {}): + filters = self.get_qdrant_filters(filters) + self.qdrant_client.delete(self.table_name, points_selector=filters) + + def delete_table(self): + self.qdrant_client.delete_collection(self.table_name) + self.qdrant_client.close() + + def size(self, filters: Optional[Dict] = {}) -> int: + filters = self.get_qdrant_filters(filters) + return self.qdrant_client.count(collection_name=self.table_name, count_filter=filters).count + + def close(self): + self.qdrant_client.close() + + def query( + self, + query: str, + query_vec: List[float], + top_k: int = 10, + filters: Optional[Dict] = {}, + ) -> List[RecordType]: + filters = self.get_qdrant_filters(filters) + results = self.qdrant_client.search( + self.table_name, + query_vector=query_vec, + query_filter=filters, + limit=top_k, + with_payload=True, + with_vectors=True, + ) + return self.to_records(results) + + def to_records(self, records: list) -> List[RecordType]: + parsed_records = [] + for record in records: + record = deepcopy(record) + metadata = record.payload[METADATA_PAYLOAD_KEY] + text = record.payload[TEXT_PAYLOAD_KEY] + _id = metadata.pop("id") + embedding = record.vector + for key, value in metadata.items(): + if key in self.uuid_fields: + metadata[key] = uuid.UUID(value) + elif key == "created_at": + metadata[key] = timestamp_to_datetime(value) + parsed_records.append( + cast( + RecordType, + self.type( + text=text, + embedding=embedding, + id=uuid.UUID(_id), + **metadata, 
+ ), + ) + ) + return parsed_records + + def to_points(self, records: List[RecordType]): + from qdrant_client import models + + assert all(isinstance(r, Passage) for r in records) + points = [] + records = list(set(records)) + for record in records: + record = vars(record) + _id = record.pop("id") + text = record.pop("text", "") + embedding = record.pop("embedding", {}) + record_metadata = record.pop("metadata_", None) or {} + if "created_at" in record: + record["created_at"] = datetime_to_timestamp(record["created_at"]) + metadata = {key: value for key, value in record.items() if value is not None} + metadata = { + **metadata, + **record_metadata, + "id": str(_id), + } + for key, value in metadata.items(): + if key in self.uuid_fields: + metadata[key] = str(value) + points.append( + models.PointStruct( + id=str(_id), + vector=embedding, + payload={ + TEXT_PAYLOAD_KEY: text, + METADATA_PAYLOAD_KEY: metadata, + }, + ) + ) + return points + + def get_qdrant_filters(self, filters: Optional[Dict] = {}): + from qdrant_client import models + + filter_conditions = {**self.filters, **filters} if filters is not None else self.filters + must_conditions = [] + for key, value in filter_conditions.items(): + match_value = str(value) if key in self.uuid_fields else value + field_condition = models.FieldCondition( + key=f"{METADATA_PAYLOAD_KEY}.{key}", + match=models.MatchValue(value=match_value), + ) + must_conditions.append(field_condition) + return models.Filter(must=must_conditions) diff --git a/memgpt/agent_store/storage.py b/memgpt/agent_store/storage.py index 31d6e4d3..215cd80e 100644 --- a/memgpt/agent_store/storage.py +++ b/memgpt/agent_store/storage.py @@ -100,6 +100,10 @@ class StorageConnector: return ChromaStorageConnector(table_type, config, user_id, agent_id) + elif storage_type == "qdrant": + from memgpt.agent_store.qdrant import QdrantStorageConnector + + return QdrantStorageConnector(table_type, config, user_id, agent_id) # TODO: add back # elif storage_type == 
"lancedb": # from memgpt.agent_store.db import LanceDBConnector diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 6e320518..55d74ad8 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -913,7 +913,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredentials): # Configure archival storage backend - archival_storage_options = ["postgres", "chroma", "milvus"] + archival_storage_options = ["postgres", "chroma", "milvus", "qdrant"] archival_storage_type = questionary.select( "Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type ).ask() @@ -950,6 +950,19 @@ def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredenti if chroma_type == "persistent": archival_storage_path = os.path.join(MEMGPT_DIR, "chroma") + if archival_storage_type == "qdrant": + qdrant_type = questionary.select("Select Qdrant backend:", ["local", "server"], default="local").ask() + if qdrant_type is None: + raise KeyboardInterrupt + if qdrant_type == "server": + archival_storage_uri = questionary.text( + "Enter the Qdrant instance URI (Default: localhost:6333):", default="localhost:6333" + ).ask() + if archival_storage_uri is None: + raise KeyboardInterrupt + if qdrant_type == "local": + archival_storage_path = os.path.join(MEMGPT_DIR, "qdrant") + if archival_storage_type == "milvus": default_milvus_uri = archival_storage_path = os.path.join(MEMGPT_DIR, "milvus.db") archival_storage_uri = questionary.text( diff --git a/poetry.lock b/poetry.lock index 66f50255..3f6b50b8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1422,6 +1422,66 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.63.0)"] +[[package]] +name = "grpcio-tools" +version = "1.48.2" +description = "Protobuf code generator for gRPC" +optional = true +python-versions = ">=3.6" +files = [ + {file = 
"grpcio-tools-1.48.2.tar.gz", hash = "sha256:8902a035708555cddbd61b5467cea127484362decc52de03f061a1a520fe90cd"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:92acc3e10ba2b0dcb90a88ae9fe1cc0ffba6868545207e4ff20ca95284f8e3c9"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e5bb396d63495667d4df42e506eed9d74fc9a51c99c173c04395fe7604c848f1"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:84a84d601a238572d049d3108e04fe4c206536e81076d56e623bd525a1b38def"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70564521e86a0de35ea9ac6daecff10cb46860aec469af65869974807ce8e98b"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdbbe63f6190187de5946891941629912ac8196701ed2253fa91624a397822ec"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae56f133b05b7e5d780ef7e032dd762adad7f3dc8f64adb43ff5bfabd659f435"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f0feb4f2b777fa6377e977faa89c26359d4f31953de15e035505b92f41aa6906"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-win32.whl", hash = "sha256:80f450272316ca0924545f488c8492649ca3aeb7044d4bf59c426dcdee527f7c"}, + {file = "grpcio_tools-1.48.2-cp310-cp310-win_amd64.whl", hash = "sha256:21ff50e321736eba22210bf9b94e05391a9ac345f26e7df16333dc75d63e74fb"}, + {file = "grpcio_tools-1.48.2-cp36-cp36m-linux_armv7l.whl", hash = "sha256:d598ccde6338b2cfbb3124f34c95f03394209013f9b1ed4a5360a736853b1c27"}, + {file = "grpcio_tools-1.48.2-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:a43d26714933f23de93ea0bf9c86c66a6ede709b8ca32e357f9e2181703e64ae"}, + {file = "grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_aarch64.whl", hash = "sha256:55fdebc73fb580717656b1bafa4f8eca448726a7aa22726a6c0a7895d2f0f088"}, + {file = 
"grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8588819b22d0de3aa1951e1991cc3e4b9aa105eecf6e3e24eb0a2fc8ab958b3e"}, + {file = "grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9771d4d317dca029dfaca7ec9282d8afe731c18bc536ece37fd39b8a974cc331"}, + {file = "grpcio_tools-1.48.2-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d886a9e052a038642b3af5d18e6f2085d1656d9788e202dc23258cf3a751e7ca"}, + {file = "grpcio_tools-1.48.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d77e8b1613876e0d8fd17709509d4ceba13492816426bd156f7e88a4c47e7158"}, + {file = "grpcio_tools-1.48.2-cp36-cp36m-win32.whl", hash = "sha256:dcaaecdd5e847de5c1d533ea91522bf56c9e6b2dc98cdc0d45f0a1c26e846ea2"}, + {file = "grpcio_tools-1.48.2-cp36-cp36m-win_amd64.whl", hash = "sha256:0119aabd9ceedfdf41b56b9fdc8284dd85a7f589d087f2694d743f346a368556"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:189be2a9b672300ca6845d94016bdacc052fdbe9d1ae9e85344425efae2ff8ef"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:9443f5c30bac449237c3cf99da125f8d6e6c01e17972bc683ee73b75dea95573"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:e0403e095b343431195db1305248b50019ad55d3dd310254431af87e14ef83a2"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5410d6b601d1404835e34466bd8aee37213489b36ee1aad2276366e265ff29d4"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51be91b7c7056ff9ee48b1eccd4a2840b0126230803a5e09dfc082a5b16a91c1"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:516eedd5eb7af6326050bc2cfceb3a977b9cc1144f283c43cc4956905285c912"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:d18599ab572b2f15a8f3db49503272d1bb4fcabb4b4d1214ef03aca1816b20a0"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-win32.whl", hash = "sha256:d18ef2adc05a8ef9e58ac46357f6d4ce7e43e077c7eda0a4425773461f9d0e6e"}, + {file = "grpcio_tools-1.48.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6d9753944e5a6b6b78b76ce9d2ae0fe3f748008c1849deb7fadcb64489d6553b"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:3c8749dca04a8d302862ceeb1dfbdd071ee13b281395975f24405a347e5baa57"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:7307dd2408b82ea545ae63502ec03036b025f449568556ea9a056e06129a7a4e"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:072234859f6069dc43a6be8ad6b7d682f4ba1dc2e2db2ebf5c75f62eee0f6dfb"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6cc298fbfe584de8876a85355efbcf796dfbcfac5948c9560f5df82e79336e2a"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f75973a42c710999acd419968bc79f00327e03e855bbe82c6529e003e49af660"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f766050e491d0b3203b6b85638015f543816a2eb7d089fc04e86e00f6de0e31d"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8e0d74403484eb77e8df2566a64b8b0b484b5c87903678c381634dd72f252d5e"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-win32.whl", hash = "sha256:cb75bac0cd43858cb759ef103fe68f8c540cb58b63dda127e710228fec3007b8"}, + {file = "grpcio_tools-1.48.2-cp38-cp38-win_amd64.whl", hash = "sha256:cabc8b0905cedbc3b2b7b2856334fa35cce3d4bc79ae241cacd8cca8940a5c85"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:e712a6d00606ad19abdeae852a7e521d6f6d0dcea843708fecf3a38be16a851e"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-macosx_10_10_x86_64.whl", hash = 
"sha256:e7e7668f89fd598c5469bb58e16bfd12b511d9947ccc75aec94da31f62bc3758"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:a415fbec67d4ff7efe88794cbe00cf548d0f0a5484cceffe0a0c89d47694c491"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d96e96ae7361aa51c9cd9c73b677b51f691f98df6086860fcc3c45852d96b0b0"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e20d7885a40e68a2bda92908acbabcdf3c14dd386c3845de73ba139e9df1f132"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8a5614251c46da07549e24f417cf989710250385e9d80deeafc53a0ee7df6325"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ace0035766fe01a1b096aa050be9f0a9f98402317e7aeff8bfe55349be32a407"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-win32.whl", hash = "sha256:4fa4300b1be59b046492ed3c5fdb59760bc6433f44c08f50de900f9552ec7461"}, + {file = "grpcio_tools-1.48.2-cp39-cp39-win_amd64.whl", hash = "sha256:0fb6c1c1e56eb26b224adc028a4204b6ad0f8b292efa28067dff273bbc8b27c4"}, +] + +[package.dependencies] +grpcio = ">=1.48.2" +protobuf = ">=3.12.0,<4.0dev" +setuptools = "*" + [[package]] name = "h11" version = "0.14.0" @@ -1433,6 +1493,32 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "h2" +version = "4.1.0" +description = "HTTP/2 State-Machine based protocol implementation" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"}, + {file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"}, +] + +[package.dependencies] +hpack = ">=4.0,<5" +hyperframe = ">=6.0,<7" + +[[package]] +name = "hpack" +version = "4.0.0" +description = "Pure-Python 
HPACK header compression" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c"}, + {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"}, +] + [[package]] name = "html2text" version = "2020.1.16" @@ -1527,6 +1613,7 @@ files = [ [package.dependencies] anyio = "*" certifi = "*" +h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} httpcore = "==1.*" idna = "*" sniffio = "*" @@ -1598,6 +1685,17 @@ files = [ [package.dependencies] pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} +[[package]] +name = "hyperframe" +version = "6.0.1" +description = "HTTP/2 framing layer for Python" +optional = true +python-versions = ">=3.6.1" +files = [ + {file = "hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15"}, + {file = "hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914"}, +] + [[package]] name = "identify" version = "2.5.36" @@ -4210,7 +4308,7 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" -optional = true +optional = false python-versions = "*" files = [ {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, @@ -4254,6 +4352,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = 
"PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4288,6 +4387,32 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "qdrant-client" +version = "1.9.1" +description = "Client library for the Qdrant vector search engine" +optional = true +python-versions = ">=3.8" +files = [ + {file = "qdrant_client-1.9.1-py3-none-any.whl", hash = "sha256:b9b7e0e5c1a51410d8bb5106a869a51e12f92ab45a99030f27aba790553bd2c8"}, + {file = "qdrant_client-1.9.1.tar.gz", hash = "sha256:186b9c31d95aefe8f2db84b7746402d7365bd63b305550e530e31bde2002ce79"}, +] + +[package.dependencies] +grpcio = ">=1.41.0" +grpcio-tools = ">=1.41.0" +httpx = {version = ">=0.20.0", extras = ["http2"]} +numpy = [ + {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\""}, +] +portalocker = ">=2.7.0,<3.0.0" +pydantic = ">=1.10.8" +urllib3 = ">=1.26.14,<3" + +[package.extras] +fastembed = ["fastembed (==0.2.6)"] + [[package]] name = "questionary" version = "2.0.1" @@ -6209,9 +6334,10 @@ local = ["llama-index-embeddings-huggingface"] milvus = ["pymilvus"] ollama = ["llama-index-embeddings-ollama"] postgres = ["pg8000", "pgvector"] +qdrant = ["qdrant-client"] server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = 
"7e6032228cb8050f8d5d99c589a1e922dfd6e0d1626589671d415e3b82fba67e" +content-hash = "bfa14c084ae06f7d5ceb561406794d93f90808c20b098af13110f4ebe38c7928" diff --git a/pyproject.toml b/pyproject.toml index eacd8663..c67d52e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ docx2txt = "^0.8" sqlalchemy = "^2.0.25" pexpect = {version = "^4.9.0", optional = true} pyright = {version = "^1.1.347", optional = true} +qdrant-client = {version="^1.9.1", optional = true} pymilvus = {version ="^2.4.3", optional = true} python-box = "^7.1.1" sqlmodel = "^0.0.16" @@ -72,6 +73,7 @@ milvus = ["pymilvus"] dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort"] server = ["websockets", "fastapi", "uvicorn"] autogen = ["pyautogen"] +qdrant = ["qdrant-client"] ollama = ["llama-index-embeddings-ollama"] [tool.black] diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 53775098..85b866fa 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -17,7 +17,9 @@ from memgpt.settings import settings from memgpt.utils import get_human_text, get_persona_text from tests import TEST_MEMGPT_CONFIG -from .utils import create_config, wipe_config +from .utils import create_config, wipe_config, with_qdrant_storage + +GET_ALL_LIMIT = 1000 @pytest.fixture(autouse=True) @@ -39,7 +41,7 @@ def recreate_declarative_base(): @pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"]) -@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres", "milvus"]) +@pytest.mark.parametrize("passage_storage_connector", with_qdrant_storage(["chroma", "postgres", "milvus"])) def test_load_directory( metadata_storage_connector, passage_storage_connector, @@ -75,6 +77,10 @@ def test_load_directory( elif passage_storage_connector == "chroma": print("testing chroma passage storage") # nothing to do (should be config defaults) + elif passage_storage_connector 
== "qdrant": + print("Testing Qdrant passage storage") + TEST_MEMGPT_CONFIG.archival_storage_type = "qdrant" + TEST_MEMGPT_CONFIG.archival_storage_uri = "localhost:6333" elif passage_storage_connector == "milvus": print("Testing Milvus passage storage") TEST_MEMGPT_CONFIG.archival_storage_type = "milvus" @@ -157,7 +163,9 @@ def test_load_directory( passages_conn.delete_table() print("Re-creating tables...") passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id) - assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}" + assert ( + passages_conn.size() == 0 + ), f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all(limit=GET_ALL_LIMIT)]}" # test: load directory print("Loading directory") @@ -173,11 +181,12 @@ def test_load_directory( # test to see if contained in storage assert ( - len(passages_conn.get_all()) == passages_conn.size() - ), f"Expected {passages_conn.size()} passages, but got {len(passages_conn.get_all())}" - passages = passages_conn.get_all({"data_source": name}) + len(passages_conn.get_all(limit=GET_ALL_LIMIT)) == passages_conn.size() + ), f"Expected {passages_conn.size()} passages, but got {len(passages_conn.get_all(limit=GET_ALL_LIMIT))}" + passages = passages_conn.get_all({"data_source": name}, limit=GET_ALL_LIMIT) print("Source", [p.data_source for p in passages]) - print("All sources", [p.data_source for p in passages_conn.get_all()]) + print(passages_conn.get_all(limit=GET_ALL_LIMIT)) + print("All sources", [p.data_source for p in passages_conn.get_all(limit=GET_ALL_LIMIT)]) assert len(passages) > 0, f"Expected >0 passages, but got {len(passages)}" assert len(passages) == passages_conn.size(), f"Expected {passages_conn.size()} passages, but got {len(passages)}" assert [p.data_source == name for p in passages] @@ -198,7 +207,7 @@ def test_load_directory( # print("Deleting agent archival 
table...") # conn.delete_table() # conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id) - # assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}" + # assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all(limit=GET_ALL_LIMIT)]}" ## attach data # print("Attaching data...") @@ -206,11 +215,11 @@ def test_load_directory( ## test to see if contained in storage # assert len(passages) == conn.size() - # assert len(passages) == len(conn.get_all({"data_source": name})) + # assert len(passages) == len(conn.get_all({"data_source": name}, limit=GET_ALL_LIMIT)) ## test: delete source # passages_conn.delete({"data_source": name}) - # assert len(passages_conn.get_all({"data_source": name})) == 0 + # assert len(passages_conn.get_all({"data_source": name}, limit=GET_ALL_LIMIT)) == 0 # cleanup ms.delete_user(user.id) diff --git a/tests/test_storage.py b/tests/test_storage.py index 8b0ebf98..c5aa870d 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -17,6 +17,8 @@ from memgpt.utils import get_human_text, get_persona_text from tests import TEST_MEMGPT_CONFIG from tests.utils import create_config, wipe_config +from .utils import with_qdrant_storage + # Note: the database will filter out rows that do not correspond to agent1 and test_user by default. 
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"] start_date = datetime(2009, 10, 5, 18, 00) @@ -101,7 +103,7 @@ def recreate_declarative_base(): Base.metadata.clear() -@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqlite", "milvus"]) +@pytest.mark.parametrize("storage_connector", with_qdrant_storage(["postgres", "chroma", "sqlite", "milvus"])) # @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"]) # @pytest.mark.parametrize("storage_connector", ["postgres"]) @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY]) @@ -159,6 +161,12 @@ def test_storage( print("Skipping test, sqlite only supported for recall memory") return TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite" + if storage_connector == "qdrant": + if table_type == TableType.RECALL_MEMORY: + print("Skipping test, Qdrant only supports archival memory") + return + TEST_MEMGPT_CONFIG.archival_storage_type = "qdrant" + TEST_MEMGPT_CONFIG.archival_storage_uri = "localhost:6333" if storage_connector == "milvus": if table_type == TableType.RECALL_MEMORY: print("Skipping test, Milvus only supports archival memory") @@ -225,7 +233,7 @@ def test_storage( conn.insert_many(records[1:]) assert ( conn.size() == 2 - ), f"Expected 1 record, got {conn.size()}: {conn.get_all()}" # expect 2, since storage connector filters for agent1 + ), f"Expected 2 records, got {conn.size()}: {conn.get_all()}" # expect 2, since storage connector filters for agent1 # test: update # NOTE: only testing with messages diff --git a/tests/utils.py b/tests/utils.py index 35b4b085..c1c76f54 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,10 @@ import datetime import os +from importlib import util from typing import Dict, Iterator, List, Tuple +import requests + from memgpt.cli.cli import QuickstartChoice, quickstart from memgpt.data_sources.connectors import DataConnector from memgpt.data_types import Document @@ -121,3 
+124,24 @@ def configure_memgpt(enable_openai=False, enable_azure=False): raise NotImplementedError else: configure_memgpt_localllm() + + +def qdrant_server_running() -> bool: + """Check if Qdrant server is running.""" + + try: + response = requests.get("http://localhost:6333", timeout=10.0) + response_json = response.json() + return response_json.get("title") == "qdrant - vector search engine" + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + return False + + +def with_qdrant_storage(storage: list[str]): + """If Qdrant server is running and `qdrant_client` is installed, + append `'qdrant'` to the storage list""" + + if util.find_spec("qdrant_client") is not None and qdrant_server_running(): + storage.append("qdrant") + + return storage