diff --git a/memgpt/connectors/lancedb.py b/memgpt/connectors/lancedb.py new file mode 100644 index 00000000..5664c581 --- /dev/null +++ b/memgpt/connectors/lancedb.py @@ -0,0 +1,199 @@ +import lancedb +import uuid +from datetime import datetime +from tqdm import tqdm +from typing import Optional, List, Iterator, Dict + +from memgpt.config import MemGPTConfig +from memgpt.connectors.storage import StorageConnector, TableType +from memgpt.config import AgentConfig, MemGPTConfig +from memgpt.constants import MEMGPT_DIR +from memgpt.utils import printd +from memgpt.data_types import Record, Message, Passage, Source + +from datetime import datetime + +from lancedb.pydantic import Vector, LanceModel + +""" Initial implementation - not complete """ + + +def get_db_model(table_name: str, table_type: TableType): + config = MemGPTConfig.load() + + if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: + # create schema for archival memory + class PassageModel(LanceModel): + """Defines data model for storing Passages (consisting of text, embedding)""" + + id: uuid.UUID + user_id: str + text: str + doc_id: str + agent_id: str + data_source: str + embedding: Vector(config.embedding_dim) + metadata_: Dict + + def __repr__(self): + return f"" + + def to_record(self): + return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at) + + """Create database model for table_name""" + return SourceModel + else: + raise ValueError(f"Table type {table_type} not implemented") + + +class LanceDBConnector(StorageConnector): + """Storage via LanceDB""" + + # TODO: this should probably eventually be moved into a parent DB class + + def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None): + # TODO + pass + + def generate_where_filter(self, filters: Dict) -> str: + where_filters = [] + for key, value in filters.items(): + where_filters.append(f"{key}={value}") + return where_filters.join(" AND ") + + @abstractmethod + def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]: + # TODO + pass + + @abstractmethod + def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]: + # TODO + pass + + @abstractmethod + def get(self, id: str) -> Optional[Record]: + # TODO + pass + + @abstractmethod + def size(self, filters: Optional[Dict] = {}) -> int: + # TODO + pass + + @abstractmethod + def insert(self, record: Record): + # TODO + pass + + @abstractmethod + def insert_many(self, records: List[Record], show_progress=True): + # TODO + pass + + @abstractmethod + def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]: + # TODO + pass + + @abstractmethod + def query_date(self, start_date, end_date): + # TODO + pass + + @abstractmethod + def query_text(self, query): + # TODO + pass + + @abstractmethod + def delete_table(self): + # TODO + pass + + @abstractmethod + def delete(self, filters: Optional[Dict] = {}): + # TODO + pass + + @abstractmethod + def save(self): + # TODO + pass diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py index 0564b12e..b2c0462c 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/connectors/storage.py @@ -190,15 +190,23 @@ class StorageConnector: raise NotImplementedError(f"Storage type {storage_type} not implemented") @abstractmethod - def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]: + def get_filters(self, filters: Optional[Dict] = {}): pass @abstractmethod - def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]: + def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]: pass @abstractmethod - def get(self, id: str) -> Record: + def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]: + pass + + @abstractmethod + def get(self, id: str) -> Optional[Record]: + pass + + @abstractmethod + def size(self, filters: Optional[Dict] = {}) -> int: pass @abstractmethod @@ -206,7 +214,7 @@ class StorageConnector: pass @abstractmethod - def insert_many(self, records: List[Record]): + def insert_many(self, records: List[Record], show_progress=True): pass @abstractmethod @@ -214,11 +222,21 @@ class StorageConnector: pass @abstractmethod - def save(self): - """Save state of storage connector""" + def query_date(self, start_date, end_date): pass @abstractmethod - def size(self, filters: Optional[Dict] = {}) -> int: - """Get number of passages (text/embedding pairs) in storage""" + def query_text(self, query): + pass + + @abstractmethod + def delete_table(self): + pass + + @abstractmethod + def delete(self, filters: Optional[Dict] = {}): + pass + + @abstractmethod + def save(self): pass