From da7ecbf103246c3335caff293b33875a6c572727 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Thu, 8 Feb 2024 20:46:21 -0800 Subject: [PATCH] fix: Remove document truncation and replace DB inserts with upserts (#973) --- .github/workflows/tests.yml | 9 +++- memgpt/agent_store/chroma.py | 6 +-- memgpt/agent_store/db.py | 80 ++++++++++++++++++++++++++++++------ memgpt/cli/cli_load.py | 7 ++-- memgpt/data_types.py | 28 +++++++------ memgpt/utils.py | 10 +++++ 6 files changed, 109 insertions(+), 31 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8fd9e0d5..391c9918 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,7 @@ jobs: PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | - poetry run pytest -s -vv -k "not test_storage" tests + poetry run pytest -s -vv -k "not test_storage and not test_server" tests - name: Run storage tests env: @@ -41,3 +41,10 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | poetry run pytest -s -vv tests/test_storage.py + + - name: Run server tests + env: + PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + poetry run pytest -s -vv tests/test_server.py diff --git a/memgpt/agent_store/chroma.py b/memgpt/agent_store/chroma.py index f9943863..e6578fa9 100644 --- a/memgpt/agent_store/chroma.py +++ b/memgpt/agent_store/chroma.py @@ -140,9 +140,9 @@ class ChromaStorageConnector(StorageConnector): metadata.pop("embedding") if "created_at" in metadata: metadata["created_at"] = datetime_to_timestamp(metadata["created_at"]) - if "metadata" in metadata and metadata["metadata"] is not None: - record_metadata = dict(metadata["metadata"]) - metadata.pop("metadata") + if "metadata_" in metadata and metadata["metadata_"] is not None: + record_metadata = dict(metadata["metadata_"]) + metadata.pop("metadata_") else: record_metadata = {} metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index 38cfef8c..fea91bd3 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -136,7 +136,7 @@ def get_db_model( id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) # id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) user_id = Column(CommonUUID, nullable=False) - text = Column(String, nullable=False) + text = Column(String) doc_id = Column(CommonUUID) agent_id = Column(CommonUUID) data_source = Column(String) # agent_name if agent, data_source name if from data source @@ -167,7 +167,7 @@ def get_db_model( id=self.id, data_source=self.data_source, agent_id=self.agent_id, - metadata=self.metadata_, + metadata_=self.metadata_, ) """Create database model for table_name""" @@ -351,18 +351,10 @@ class SQLStorageConnector(StorageConnector): return session.query(self.db_model).filter(*filters).count() def insert(self, record: Record): - db_record = self.db_model(**vars(record)) - with self.session_maker() as session: - session.add(db_record) - session.commit() + raise NotImplementedError def insert_many(self, records: List[RecordType], show_progress=False): - iterable = tqdm(records) if show_progress else records - with self.session_maker() as session: - for record in iterable: - db_record = self.db_model(**vars(record)) - session.add(db_record) - session.commit() + raise NotImplementedError def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]: raise NotImplementedError("Vector query not implemented for SQLStorageConnector") @@ -466,6 +458,38 @@ class PostgresStorageConnector(SQLStorageConnector): records = [result.to_record() for result in results] return records + def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False): + from sqlalchemy.dialects.postgresql import insert + + # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel) + if len(records) == 0: + return + if isinstance(records[0], Passage): + with self.engine.connect() as conn: + db_records = [vars(record) for record in records] + # print("records", db_records) + stmt = insert(self.db_model.__table__).values(db_records) + # print(stmt) + if exists_ok: + upsert_stmt = stmt.on_conflict_do_update( + index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column + ) + print(upsert_stmt) + conn.execute(upsert_stmt) + else: + conn.execute(stmt) + conn.commit() + else: + with self.session_maker() as session: + iterable = tqdm(records) if show_progress else records + for record in iterable: + db_record = self.db_model(**vars(record)) + session.add(db_record) + session.commit() + + def insert(self, record: Record, exists_ok=True): + self.insert_many([record], exists_ok=exists_ok) + class SQLLiteStorageConnector(SQLStorageConnector): def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): @@ -494,3 +518,35 @@ class SQLLiteStorageConnector(SQLStorageConnector): sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le) sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b)) + + def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False): + from sqlalchemy.dialects.sqlite import insert + + # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel) + if len(records) == 0: + return + if isinstance(records[0], Passage): + with self.engine.connect() as conn: + db_records = [vars(record) for record in records] + # print("records", db_records) + stmt = insert(self.db_model.__table__).values(db_records) + # print(stmt) + if exists_ok: + upsert_stmt = stmt.on_conflict_do_update( + index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column + ) + print(upsert_stmt) + conn.execute(upsert_stmt) + else: + conn.execute(stmt) + conn.commit() + else: + with self.session_maker() as session: + iterable = tqdm(records) if show_progress else records + for record in iterable: + db_record = self.db_model(**vars(record)) + session.add(db_record) + session.commit() + + def insert(self, record: Record, exists_ok=True): + self.insert_many([record], exists_ok=exists_ok) diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index bd8811f1..1affa4ac 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -56,9 +56,10 @@ def insert_passages_into_source(passages: List[Passage], source_name: str, user_ # add and save all passages storage.insert_many(passages) - - assert orig_size + len(passages) == storage.size(), f"Expected {orig_size + len(passages)} passages, got {storage.size()}" storage.save() + num_new_passages = storage.size() - orig_size + print(f"Updated {len(passages)}, inserted {num_new_passages} new passages into {source_name}") + print("Total passages in source:", storage.size()) def store_docs(name, docs, user_id=None, show_progress=True): @@ -129,7 +130,7 @@ def store_docs(name, docs, user_id=None, show_progress=True): text=text, data_source=name, embedding=node.embedding, - metadata=None, + metadata_=None, embedding_dim=config.default_embedding_config.embedding_dim, embedding_model=config.default_embedding_config.embedding_model, ) diff --git a/memgpt/data_types.py b/memgpt/data_types.py index dfa48dda..8dda6f44 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -6,7 +6,8 @@ from typing import Optional, List, Dict, TypeVar import numpy as np from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM -from memgpt.utils import get_local_time, format_datetime, get_utc_time +from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string +from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string from memgpt.models import chat_completion_response @@ -118,9 +119,6 @@ class Message(Record): assert tool_call_id is None self.tool_call_id = tool_call_id - # def __repr__(self): - # pass - @staticmethod def dict_to_message( user_id: uuid.UUID, @@ -273,17 +271,18 @@ class Message(Record): class Document(Record): """A document represent a document loaded into MemGPT, which is broken down into passages.""" - def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None): + def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None): + if id is None: + # by default, generate ID as a hash of the text (avoid duplicates) + self.id = create_uuid_from_string("".join([text, str(user_id)])) + else: + self.id = id super().__init__(id) self.user_id = user_id self.text = text - self.document_id = document_id self.data_source = data_source # TODO: add optional embedding? - # def __repr__(self) -> str: - # pass - class Passage(Record): """A passage is a single unit of memory, and a standard format accross all storage backends. @@ -302,15 +301,20 @@ class Passage(Record): data_source: Optional[str] = None, # None if created by agent doc_id: Optional[uuid.UUID] = None, id: Optional[uuid.UUID] = None, - metadata: Optional[dict] = {}, + metadata_: Optional[dict] = {}, ): - super().__init__(id) + if id is None: + # by default, generate ID as a hash of the text (avoid duplicates) + self.id = create_uuid_from_string("".join([text, str(agent_id), str(user_id)])) + else: + self.id = id + super().__init__(self.id) self.user_id = user_id self.agent_id = agent_id self.text = text self.data_source = data_source self.doc_id = doc_id - self.metadata = metadata + self.metadata_ = metadata_ # pad and store embeddings if isinstance(embedding, list): diff --git a/memgpt/utils.py b/memgpt/utils.py index 8ca73a12..5931147c 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -10,6 +10,7 @@ import subprocess import uuid import sys import io +import hashlib from typing import List import inspect from functools import wraps @@ -1009,3 +1010,12 @@ def extract_date_from_timestamp(timestamp): # Extracts the date (ignoring the time and timezone) match = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp) return match.group(1) if match else None + + +def create_uuid_from_string(val: str): + """ + Generate consistent UUID from a string + from: https://samos-it.com/posts/python-create-uuid-from-random-string-of-words.html + """ + hex_string = hashlib.md5(val.encode("UTF-8")).hexdigest() + return uuid.UUID(hex=hex_string)