feat: Qdrant storage connector (#1023)

This commit is contained in:
Anush
2024-06-05 11:54:25 +05:30
committed by GitHub
parent 16ec0abb4a
commit ab0e6e5805
10 changed files with 426 additions and 18 deletions

View File

@@ -13,6 +13,13 @@ jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 15
services:
qdrant:
image: qdrant/qdrant
ports:
- 6333:6333
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -36,19 +36,19 @@ To run the Postgres backend, you will need a URI to a Postgres database that sup
3. Configure the environment for `pgvector`. You can either:
- Add the following line to your shell profile (e.g., `~/.bashrc`, `~/.zshrc`):
```sh
export MEMGPT_PGURI=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
```
- Or create a `.env` file in the root project directory with:
```sh
MEMGPT_PGURI=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
```
4. Run the script from the root project directory:
```sh
bash db/run_postgres.sh
```
@@ -105,6 +105,20 @@ memgpt configure
and selecting `lancedb` for archival storage, and a database URI (e.g. `./.lancedb`). An empty archival URI is also handled, with the default URI set to `./.lancedb`. For more, check out the [LanceDB docs](https://lancedb.github.io/lancedb/)
## Qdrant
To enable the Qdrant backend, make sure to install the required dependencies with:
```sh
pip install 'pymemgpt[qdrant]'
```
You can configure Qdrant with an in-memory instance or a server using the `memgpt configure` command. You can set an API key for authentication with a Qdrant server using the `QDRANT_API_KEY` environment variable. Learn more about setting up Qdrant [here](https://qdrant.tech/documentation/guides/installation/).
```sh
? Select Qdrant backend: server
? Enter the Qdrant instance URI (Default: localhost:6333): localhost:6333
```
## Milvus

View File

@@ -0,0 +1,201 @@
import os
import uuid
from copy import deepcopy
from typing import Dict, Iterator, List, Optional, cast
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.data_types import Passage, Record, RecordType
from memgpt.utils import datetime_to_timestamp, timestamp_to_datetime
TEXT_PAYLOAD_KEY = "text_content"
METADATA_PAYLOAD_KEY = "metadata"
class QdrantStorageConnector(StorageConnector):
    """Archival/passage storage backed by a Qdrant collection.

    Each record is stored as one Qdrant point: the vector holds the passage
    embedding (collection is sized to ``MAX_EMBEDDING_DIM``, cosine distance)
    and the payload carries the raw text under ``TEXT_PAYLOAD_KEY`` plus a
    metadata dict under ``METADATA_PAYLOAD_KEY``. UUID-valued fields are
    serialized as strings and ``created_at`` as an integer timestamp; both are
    rehydrated on read.
    """

    def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        try:
            from qdrant_client import QdrantClient, models
        except ImportError as e:
            raise ImportError("'qdrant-client' not installed. Run `pip install qdrant-client`.") from e
        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Qdrant only supports archival memory"
        # A "host:port" URI selects a remote server; otherwise fall back to a
        # local on-disk instance at archival_storage_path.
        if config.archival_storage_uri and len(config.archival_storage_uri.split(":")) == 2:
            host, port = config.archival_storage_uri.split(":")
            self.qdrant_client = QdrantClient(host=host, port=port, api_key=os.getenv("QDRANT_API_KEY"))
        elif config.archival_storage_path:
            self.qdrant_client = QdrantClient(path=config.archival_storage_path)
        else:
            raise ValueError("Qdrant storage requires either a URI or a path to the storage configured")
        if not self.qdrant_client.collection_exists(self.table_name):
            self.qdrant_client.create_collection(
                collection_name=self.table_name,
                vectors_config=models.VectorParams(
                    size=MAX_EMBEDDING_DIM,
                    distance=models.Distance.COSINE,
                ),
            )
        # Payload fields holding UUIDs; stored as strings, parsed back on read.
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

    def get_all_paginated(self, filters: Optional[Dict] = None, page_size: int = 10) -> Iterator[List[RecordType]]:
        """Yield pages of up to *page_size* records matching *filters*."""
        from qdrant_client import grpc

        scroll_filter = self.get_qdrant_filters(filters)
        next_offset = None
        stop_scrolling = False
        while not stop_scrolling:
            results, next_offset = self.qdrant_client.scroll(
                collection_name=self.table_name,
                scroll_filter=scroll_filter,
                limit=page_size,
                offset=next_offset,
                with_payload=True,
                with_vectors=True,
            )
            # The gRPC client signals the end of a scroll with a zero-valued
            # PointId rather than None.
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == ""
            )
            yield self.to_records(results)

    def get_all(self, filters: Optional[Dict] = None, limit=10) -> List[RecordType]:
        """Return up to *limit* records matching *filters* (empty list if none)."""
        if self.size(filters) == 0:
            return []
        scroll_filter = self.get_qdrant_filters(filters)
        results, _ = self.qdrant_client.scroll(
            self.table_name,
            scroll_filter=scroll_filter,
            limit=limit,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def get(self, id: uuid.UUID) -> Optional[RecordType]:
        """Return the record with the given *id*, or None if it does not exist."""
        results = self.qdrant_client.retrieve(
            collection_name=self.table_name,
            ids=[str(id)],
            with_payload=True,
            with_vectors=True,
        )
        if not results:
            return None
        return self.to_records(results)[0]

    def insert(self, record: Record):
        """Upsert a single record (insert or overwrite by id)."""
        points = self.to_points([record])
        self.qdrant_client.upsert(self.table_name, points=points)

    def insert_many(self, records: List[RecordType], show_progress=False):
        """Upsert a batch of records.

        Note: *show_progress* is accepted for interface compatibility but is
        currently ignored.
        """
        points = self.to_points(records)
        self.qdrant_client.upsert(self.table_name, points=points)

    def delete(self, filters: Optional[Dict] = None):
        """Delete all points matching *filters* (merged with the connector's own filters)."""
        selector = self.get_qdrant_filters(filters)
        self.qdrant_client.delete(self.table_name, points_selector=selector)

    def delete_table(self):
        """Drop the backing collection and close the client connection."""
        self.qdrant_client.delete_collection(self.table_name)
        self.qdrant_client.close()

    def size(self, filters: Optional[Dict] = None) -> int:
        """Return the number of points matching *filters*."""
        count_filter = self.get_qdrant_filters(filters)
        return self.qdrant_client.count(collection_name=self.table_name, count_filter=count_filter).count

    def close(self):
        """Close the underlying Qdrant client."""
        self.qdrant_client.close()

    def query(
        self,
        query: str,
        query_vec: List[float],
        top_k: int = 10,
        filters: Optional[Dict] = None,
    ) -> List[RecordType]:
        """Return the *top_k* records most similar to *query_vec*.

        The *query* text is unused; similarity is computed on the vector only.
        """
        # BUGFIX: previously called self.get_filters(), which builds the base
        # class's filter format; Qdrant search requires a models.Filter.
        query_filter = self.get_qdrant_filters(filters)
        results = self.qdrant_client.search(
            self.table_name,
            query_vector=query_vec,
            query_filter=query_filter,
            limit=top_k,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def to_records(self, records: list) -> List[RecordType]:
        """Convert Qdrant points/results back into MemGPT record objects."""
        parsed_records = []
        for record in records:
            record = deepcopy(record)
            metadata = record.payload[METADATA_PAYLOAD_KEY]
            text = record.payload[TEXT_PAYLOAD_KEY]
            _id = metadata.pop("id")
            embedding = record.vector
            # Undo the serialization applied in to_points().
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = uuid.UUID(value)
                elif key == "created_at":
                    metadata[key] = timestamp_to_datetime(value)
            parsed_records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=embedding,
                        id=uuid.UUID(_id),
                        **metadata,
                    ),
                )
            )
        return parsed_records

    def to_points(self, records: List[RecordType]):
        """Convert MemGPT Passage records into Qdrant PointStructs.

        Duplicate records are collapsed; UUIDs become strings and
        ``created_at`` becomes an integer timestamp so the payload is
        JSON-serializable.
        """
        from qdrant_client import models

        assert all(isinstance(r, Passage) for r in records)
        points = []
        records = list(set(records))  # deduplicate before upserting
        for record in records:
            record = vars(record)
            _id = record.pop("id")
            text = record.pop("text", "")
            embedding = record.pop("embedding", {})
            record_metadata = record.pop("metadata_", None) or {}
            if "created_at" in record:
                record["created_at"] = datetime_to_timestamp(record["created_at"])
            metadata = {key: value for key, value in record.items() if value is not None}
            metadata = {
                **metadata,
                **record_metadata,
                "id": str(_id),
            }
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = str(value)
            points.append(
                models.PointStruct(
                    id=str(_id),
                    vector=embedding,
                    payload={
                        TEXT_PAYLOAD_KEY: text,
                        METADATA_PAYLOAD_KEY: metadata,
                    },
                )
            )
        return points

    def get_qdrant_filters(self, filters: Optional[Dict] = None):
        """Build a qdrant models.Filter from the connector's base filters merged with *filters*."""
        from qdrant_client import models

        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        must_conditions = []
        for key, value in filter_conditions.items():
            # UUIDs are stored as strings in the payload, so match on str form.
            match_value = str(value) if key in self.uuid_fields else value
            field_condition = models.FieldCondition(
                key=f"{METADATA_PAYLOAD_KEY}.{key}",
                match=models.MatchValue(value=match_value),
            )
            must_conditions.append(field_condition)
        return models.Filter(must=must_conditions)

View File

@@ -100,6 +100,10 @@ class StorageConnector:
return ChromaStorageConnector(table_type, config, user_id, agent_id)
elif storage_type == "qdrant":
from memgpt.agent_store.qdrant import QdrantStorageConnector
return QdrantStorageConnector(table_type, config, user_id, agent_id)
# TODO: add back
# elif storage_type == "lancedb":
# from memgpt.agent_store.db import LanceDBConnector

View File

@@ -913,7 +913,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredentials):
# Configure archival storage backend
archival_storage_options = ["postgres", "chroma", "milvus"]
archival_storage_options = ["postgres", "chroma", "milvus", "qdrant"]
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
@@ -950,6 +950,19 @@ def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredenti
if chroma_type == "persistent":
archival_storage_path = os.path.join(MEMGPT_DIR, "chroma")
if archival_storage_type == "qdrant":
qdrant_type = questionary.select("Select Qdrant backend:", ["local", "server"], default="local").ask()
if qdrant_type is None:
raise KeyboardInterrupt
if qdrant_type == "server":
archival_storage_uri = questionary.text(
"Enter the Qdrant instance URI (Default: localhost:6333):", default="localhost:6333"
).ask()
if archival_storage_uri is None:
raise KeyboardInterrupt
if qdrant_type == "local":
archival_storage_path = os.path.join(MEMGPT_DIR, "qdrant")
if archival_storage_type == "milvus":
default_milvus_uri = archival_storage_path = os.path.join(MEMGPT_DIR, "milvus.db")
archival_storage_uri = questionary.text(

130
poetry.lock generated
View File

@@ -1422,6 +1422,66 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.63.0)"]
[[package]]
name = "grpcio-tools"
version = "1.48.2"
description = "Protobuf code generator for gRPC"
optional = true
python-versions = ">=3.6"
files = [
{file = "grpcio-tools-1.48.2.tar.gz", hash = "sha256:8902a035708555cddbd61b5467cea127484362decc52de03f061a1a520fe90cd"},
{file = "grpcio_tools-1.48.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:92acc3e10ba2b0dcb90a88ae9fe1cc0ffba6868545207e4ff20ca95284f8e3c9"},
{file = "grpcio_tools-1.48.2-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e5bb396d63495667d4df42e506eed9d74fc9a51c99c173c04395fe7604c848f1"},
{file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:84a84d601a238572d049d3108e04fe4c206536e81076d56e623bd525a1b38def"},
{file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70564521e86a0de35ea9ac6daecff10cb46860aec469af65869974807ce8e98b"},
{file = "grpcio_tools-1.48.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdbbe63f6190187de5946891941629912ac8196701ed2253fa91624a397822ec"},
{file = "grpcio_tools-1.48.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae56f133b05b7e5d780ef7e032dd762adad7f3dc8f64adb43ff5bfabd659f435"},
{file = "grpcio_tools-1.48.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f0feb4f2b777fa6377e977faa89c26359d4f31953de15e035505b92f41aa6906"},
{file = "grpcio_tools-1.48.2-cp310-cp310-win32.whl", hash = "sha256:80f450272316ca0924545f488c8492649ca3aeb7044d4bf59c426dcdee527f7c"},
{file = "grpcio_tools-1.48.2-cp310-cp310-win_amd64.whl", hash = "sha256:21ff50e321736eba22210bf9b94e05391a9ac345f26e7df16333dc75d63e74fb"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-linux_armv7l.whl", hash = "sha256:d598ccde6338b2cfbb3124f34c95f03394209013f9b1ed4a5360a736853b1c27"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:a43d26714933f23de93ea0bf9c86c66a6ede709b8ca32e357f9e2181703e64ae"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_aarch64.whl", hash = "sha256:55fdebc73fb580717656b1bafa4f8eca448726a7aa22726a6c0a7895d2f0f088"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8588819b22d0de3aa1951e1991cc3e4b9aa105eecf6e3e24eb0a2fc8ab958b3e"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9771d4d317dca029dfaca7ec9282d8afe731c18bc536ece37fd39b8a974cc331"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d886a9e052a038642b3af5d18e6f2085d1656d9788e202dc23258cf3a751e7ca"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d77e8b1613876e0d8fd17709509d4ceba13492816426bd156f7e88a4c47e7158"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-win32.whl", hash = "sha256:dcaaecdd5e847de5c1d533ea91522bf56c9e6b2dc98cdc0d45f0a1c26e846ea2"},
{file = "grpcio_tools-1.48.2-cp36-cp36m-win_amd64.whl", hash = "sha256:0119aabd9ceedfdf41b56b9fdc8284dd85a7f589d087f2694d743f346a368556"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:189be2a9b672300ca6845d94016bdacc052fdbe9d1ae9e85344425efae2ff8ef"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:9443f5c30bac449237c3cf99da125f8d6e6c01e17972bc683ee73b75dea95573"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:e0403e095b343431195db1305248b50019ad55d3dd310254431af87e14ef83a2"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5410d6b601d1404835e34466bd8aee37213489b36ee1aad2276366e265ff29d4"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51be91b7c7056ff9ee48b1eccd4a2840b0126230803a5e09dfc082a5b16a91c1"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:516eedd5eb7af6326050bc2cfceb3a977b9cc1144f283c43cc4956905285c912"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d18599ab572b2f15a8f3db49503272d1bb4fcabb4b4d1214ef03aca1816b20a0"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-win32.whl", hash = "sha256:d18ef2adc05a8ef9e58ac46357f6d4ce7e43e077c7eda0a4425773461f9d0e6e"},
{file = "grpcio_tools-1.48.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6d9753944e5a6b6b78b76ce9d2ae0fe3f748008c1849deb7fadcb64489d6553b"},
{file = "grpcio_tools-1.48.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:3c8749dca04a8d302862ceeb1dfbdd071ee13b281395975f24405a347e5baa57"},
{file = "grpcio_tools-1.48.2-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:7307dd2408b82ea545ae63502ec03036b025f449568556ea9a056e06129a7a4e"},
{file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:072234859f6069dc43a6be8ad6b7d682f4ba1dc2e2db2ebf5c75f62eee0f6dfb"},
{file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6cc298fbfe584de8876a85355efbcf796dfbcfac5948c9560f5df82e79336e2a"},
{file = "grpcio_tools-1.48.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f75973a42c710999acd419968bc79f00327e03e855bbe82c6529e003e49af660"},
{file = "grpcio_tools-1.48.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f766050e491d0b3203b6b85638015f543816a2eb7d089fc04e86e00f6de0e31d"},
{file = "grpcio_tools-1.48.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8e0d74403484eb77e8df2566a64b8b0b484b5c87903678c381634dd72f252d5e"},
{file = "grpcio_tools-1.48.2-cp38-cp38-win32.whl", hash = "sha256:cb75bac0cd43858cb759ef103fe68f8c540cb58b63dda127e710228fec3007b8"},
{file = "grpcio_tools-1.48.2-cp38-cp38-win_amd64.whl", hash = "sha256:cabc8b0905cedbc3b2b7b2856334fa35cce3d4bc79ae241cacd8cca8940a5c85"},
{file = "grpcio_tools-1.48.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:e712a6d00606ad19abdeae852a7e521d6f6d0dcea843708fecf3a38be16a851e"},
{file = "grpcio_tools-1.48.2-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:e7e7668f89fd598c5469bb58e16bfd12b511d9947ccc75aec94da31f62bc3758"},
{file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:a415fbec67d4ff7efe88794cbe00cf548d0f0a5484cceffe0a0c89d47694c491"},
{file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d96e96ae7361aa51c9cd9c73b677b51f691f98df6086860fcc3c45852d96b0b0"},
{file = "grpcio_tools-1.48.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e20d7885a40e68a2bda92908acbabcdf3c14dd386c3845de73ba139e9df1f132"},
{file = "grpcio_tools-1.48.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8a5614251c46da07549e24f417cf989710250385e9d80deeafc53a0ee7df6325"},
{file = "grpcio_tools-1.48.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ace0035766fe01a1b096aa050be9f0a9f98402317e7aeff8bfe55349be32a407"},
{file = "grpcio_tools-1.48.2-cp39-cp39-win32.whl", hash = "sha256:4fa4300b1be59b046492ed3c5fdb59760bc6433f44c08f50de900f9552ec7461"},
{file = "grpcio_tools-1.48.2-cp39-cp39-win_amd64.whl", hash = "sha256:0fb6c1c1e56eb26b224adc028a4204b6ad0f8b292efa28067dff273bbc8b27c4"},
]
[package.dependencies]
grpcio = ">=1.48.2"
protobuf = ">=3.12.0,<4.0dev"
setuptools = "*"
[[package]]
name = "h11"
version = "0.14.0"
@@ -1433,6 +1493,32 @@ files = [
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
]
[[package]]
name = "h2"
version = "4.1.0"
description = "HTTP/2 State-Machine based protocol implementation"
optional = true
python-versions = ">=3.6.1"
files = [
{file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"},
{file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"},
]
[package.dependencies]
hpack = ">=4.0,<5"
hyperframe = ">=6.0,<7"
[[package]]
name = "hpack"
version = "4.0.0"
description = "Pure-Python HPACK header compression"
optional = true
python-versions = ">=3.6.1"
files = [
{file = "hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c"},
{file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"},
]
[[package]]
name = "html2text"
version = "2020.1.16"
@@ -1527,6 +1613,7 @@ files = [
[package.dependencies]
anyio = "*"
certifi = "*"
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
httpcore = "==1.*"
idna = "*"
sniffio = "*"
@@ -1598,6 +1685,17 @@ files = [
[package.dependencies]
pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}
[[package]]
name = "hyperframe"
version = "6.0.1"
description = "HTTP/2 framing layer for Python"
optional = true
python-versions = ">=3.6.1"
files = [
{file = "hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15"},
{file = "hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914"},
]
[[package]]
name = "identify"
version = "2.5.36"
@@ -4210,7 +4308,7 @@ files = [
name = "pywin32"
version = "306"
description = "Python for Window Extensions"
optional = true
optional = false
python-versions = "*"
files = [
{file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"},
@@ -4254,6 +4352,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -4288,6 +4387,32 @@ files = [
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
]
[[package]]
name = "qdrant-client"
version = "1.9.1"
description = "Client library for the Qdrant vector search engine"
optional = true
python-versions = ">=3.8"
files = [
{file = "qdrant_client-1.9.1-py3-none-any.whl", hash = "sha256:b9b7e0e5c1a51410d8bb5106a869a51e12f92ab45a99030f27aba790553bd2c8"},
{file = "qdrant_client-1.9.1.tar.gz", hash = "sha256:186b9c31d95aefe8f2db84b7746402d7365bd63b305550e530e31bde2002ce79"},
]
[package.dependencies]
grpcio = ">=1.41.0"
grpcio-tools = ">=1.41.0"
httpx = {version = ">=0.20.0", extras = ["http2"]}
numpy = [
{version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""},
{version = ">=1.26", markers = "python_version >= \"3.12\""},
]
portalocker = ">=2.7.0,<3.0.0"
pydantic = ">=1.10.8"
urllib3 = ">=1.26.14,<3"
[package.extras]
fastembed = ["fastembed (==0.2.6)"]
[[package]]
name = "questionary"
version = "2.0.1"
@@ -6209,9 +6334,10 @@ local = ["llama-index-embeddings-huggingface"]
milvus = ["pymilvus"]
ollama = ["llama-index-embeddings-ollama"]
postgres = ["pg8000", "pgvector"]
qdrant = ["qdrant-client"]
server = ["fastapi", "uvicorn", "websockets"]
[metadata]
lock-version = "2.0"
python-versions = "<3.13,>=3.10"
content-hash = "7e6032228cb8050f8d5d99c589a1e922dfd6e0d1626589671d415e3b82fba67e"
content-hash = "bfa14c084ae06f7d5ceb561406794d93f90808c20b098af13110f4ebe38c7928"

View File

@@ -46,6 +46,7 @@ docx2txt = "^0.8"
sqlalchemy = "^2.0.25"
pexpect = {version = "^4.9.0", optional = true}
pyright = {version = "^1.1.347", optional = true}
qdrant-client = {version="^1.9.1", optional = true}
pymilvus = {version ="^2.4.3", optional = true}
python-box = "^7.1.1"
sqlmodel = "^0.0.16"
@@ -72,6 +73,7 @@ milvus = ["pymilvus"]
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort"]
server = ["websockets", "fastapi", "uvicorn"]
autogen = ["pyautogen"]
qdrant = ["qdrant-client"]
ollama = ["llama-index-embeddings-ollama"]
[tool.black]

View File

@@ -17,7 +17,9 @@ from memgpt.settings import settings
from memgpt.utils import get_human_text, get_persona_text
from tests import TEST_MEMGPT_CONFIG
from .utils import create_config, wipe_config
from .utils import create_config, wipe_config, with_qdrant_storage
GET_ALL_LIMIT = 1000
@pytest.fixture(autouse=True)
@@ -39,7 +41,7 @@ def recreate_declarative_base():
@pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres", "milvus"])
@pytest.mark.parametrize("passage_storage_connector", with_qdrant_storage(["chroma", "postgres", "milvus"]))
def test_load_directory(
metadata_storage_connector,
passage_storage_connector,
@@ -75,6 +77,10 @@ def test_load_directory(
elif passage_storage_connector == "chroma":
print("testing chroma passage storage")
# nothing to do (should be config defaults)
elif passage_storage_connector == "qdrant":
print("Testing Qdrant passage storage")
TEST_MEMGPT_CONFIG.archival_storage_type = "qdrant"
TEST_MEMGPT_CONFIG.archival_storage_uri = "localhost:6333"
elif passage_storage_connector == "milvus":
print("Testing Milvus passage storage")
TEST_MEMGPT_CONFIG.archival_storage_type = "milvus"
@@ -157,7 +163,9 @@ def test_load_directory(
passages_conn.delete_table()
print("Re-creating tables...")
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)
assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"
assert (
passages_conn.size() == 0
), f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all(limit=GET_ALL_LIMIT)]}"
# test: load directory
print("Loading directory")
@@ -173,11 +181,12 @@ def test_load_directory(
# test to see if contained in storage
assert (
len(passages_conn.get_all()) == passages_conn.size()
), f"Expected {passages_conn.size()} passages, but got {len(passages_conn.get_all())}"
passages = passages_conn.get_all({"data_source": name})
len(passages_conn.get_all(limit=GET_ALL_LIMIT)) == passages_conn.size()
), f"Expected {passages_conn.size()} passages, but got {len(passages_conn.get_all(limit=GET_ALL_LIMIT))}"
passages = passages_conn.get_all({"data_source": name}, limit=GET_ALL_LIMIT)
print("Source", [p.data_source for p in passages])
print("All sources", [p.data_source for p in passages_conn.get_all()])
print(passages_conn.get_all(limit=GET_ALL_LIMIT))
print("All sources", [p.data_source for p in passages_conn.get_all(limit=GET_ALL_LIMIT)])
assert len(passages) > 0, f"Expected >0 passages, but got {len(passages)}"
assert len(passages) == passages_conn.size(), f"Expected {passages_conn.size()} passages, but got {len(passages)}"
assert [p.data_source == name for p in passages]
@@ -198,7 +207,7 @@ def test_load_directory(
# print("Deleting agent archival table...")
# conn.delete_table()
# conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
# assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"
# assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all(limit=GET_ALL_LIMIT)]}"
## attach data
# print("Attaching data...")
@@ -206,11 +215,11 @@ def test_load_directory(
## test to see if contained in storage
# assert len(passages) == conn.size()
# assert len(passages) == len(conn.get_all({"data_source": name}))
# assert len(passages) == len(conn.get_all({"data_source": name}, limit=GET_ALL_LIMIT))
## test: delete source
# passages_conn.delete({"data_source": name})
# assert len(passages_conn.get_all({"data_source": name})) == 0
# assert len(passages_conn.get_all({"data_source": name}, limit=GET_ALL_LIMIT)) == 0
# cleanup
ms.delete_user(user.id)

View File

@@ -17,6 +17,8 @@ from memgpt.utils import get_human_text, get_persona_text
from tests import TEST_MEMGPT_CONFIG
from tests.utils import create_config, wipe_config
from .utils import with_qdrant_storage
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
start_date = datetime(2009, 10, 5, 18, 00)
@@ -101,7 +103,7 @@ def recreate_declarative_base():
Base.metadata.clear()
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqlite", "milvus"])
@pytest.mark.parametrize("storage_connector", with_qdrant_storage(["postgres", "chroma", "sqlite", "milvus"]))
# @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"])
# @pytest.mark.parametrize("storage_connector", ["postgres"])
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
@@ -159,6 +161,12 @@ def test_storage(
print("Skipping test, sqlite only supported for recall memory")
return
TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite"
if storage_connector == "qdrant":
if table_type == TableType.RECALL_MEMORY:
print("Skipping test, Qdrant only supports archival memory")
return
TEST_MEMGPT_CONFIG.archival_storage_type = "qdrant"
TEST_MEMGPT_CONFIG.archival_storage_uri = "localhost:6333"
if storage_connector == "milvus":
if table_type == TableType.RECALL_MEMORY:
print("Skipping test, Milvus only supports archival memory")
@@ -225,7 +233,7 @@ def test_storage(
conn.insert_many(records[1:])
assert (
conn.size() == 2
), f"Expected 1 record, got {conn.size()}: {conn.get_all()}" # expect 2, since storage connector filters for agent1
), f"Expected 2 records, got {conn.size()}: {conn.get_all()}" # expect 2, since storage connector filters for agent1
# test: update
# NOTE: only testing with messages

View File

@@ -1,7 +1,10 @@
import datetime
import os
from importlib import util
from typing import Dict, Iterator, List, Tuple
import requests
from memgpt.cli.cli import QuickstartChoice, quickstart
from memgpt.data_sources.connectors import DataConnector
from memgpt.data_types import Document
@@ -121,3 +124,24 @@ def configure_memgpt(enable_openai=False, enable_azure=False):
raise NotImplementedError
else:
configure_memgpt_localllm()
def qdrant_server_running() -> bool:
    """Return True if a Qdrant server is reachable on localhost:6333.

    Best-effort probe: any request failure (connection refused, timeout,
    HTTP error) or a non-JSON response body is treated as "not running"
    rather than raising.
    """
    try:
        response = requests.get("http://localhost:6333", timeout=10.0)
        response_json = response.json()
        return response_json.get("title") == "qdrant - vector search engine"
    # RequestException covers ConnectionError/Timeout and other transport
    # failures; ValueError covers a non-JSON body from response.json().
    except (requests.exceptions.RequestException, ValueError):
        return False
def with_qdrant_storage(storage: list[str]) -> list[str]:
    """Return a copy of *storage* with "qdrant" appended when usable.

    "qdrant" is added only if the ``qdrant_client`` package is installed and
    a local Qdrant server is reachable. Returns a new list rather than
    mutating the caller's argument (the previous in-place append could
    surprise callers sharing the list, e.g. parametrize decorators).
    """
    storage = list(storage)
    if util.find_spec("qdrant_client") is not None and qdrant_server_running():
        storage.append("qdrant")
    return storage