import uuid
from abc import abstractmethod
from datetime import datetime
from typing import Optional, List, Iterator, Dict

import lancedb
from lancedb.pydantic import Vector, LanceModel
from tqdm import tqdm

from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.constants import MEMGPT_DIR
from memgpt.utils import printd
from memgpt.data_types import Record, Message, Passage, Source

""" Initial implementation - not complete """


def get_db_model(table_name: str, table_type: TableType):
    """Build the LanceDB model class for a table of the given type.

    Args:
        table_name: Name of the table. Not read by the code in this function.
        table_type: Kind of table. Only ``TableType.ARCHIVAL_MEMORY`` and
            ``TableType.PASSAGES`` have a branch here.

    Raises:
        ValueError: If ``table_type`` is any other value.
    """
    # NOTE(review): this function looks truncated. The passage branch below
    # returns `SourceModel`, which is never defined here, and
    # `PassageModel.to_record` builds a `Source` from `name` / `created_at`,
    # which `PassageModel` does not declare. The code is kept as found rather
    # than reconstructed from a guess -- restore the missing model
    # definitions before relying on it.
    config = MemGPTConfig.load()

    if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
        # create schema for archival memory
        class PassageModel(LanceModel):
            """Defines data model for storing Passages (consisting of text, embedding)"""

            id: uuid.UUID
            user_id: str
            text: str
            doc_id: str
            agent_id: str
            data_source: str
            # Vector width comes from the loaded MemGPT config.
            embedding: Vector(config.embedding_dim)
            metadata_: Dict

            def __repr__(self):
                # NOTE(review): empty repr -- the format string appears to be lost.
                return f""

            def to_record(self):
                # NOTE(review): `name` and `created_at` are not fields of this model.
                return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at)

        """Create database model for table_name"""
        # NOTE(review): `SourceModel` is undefined in this scope (NameError when reached).
        return SourceModel
    else:
        raise ValueError(f"Table type {table_type} not implemented")


class LanceDBConnector(StorageConnector):
    """Storage via LanceDB"""

    # TODO: this should probably eventually be moved into a parent DB class

    # NOTE(review): every storage method below is an unimplemented stub that
    # returns None, and each is still decorated with @abstractmethod. If
    # StorageConnector uses an ABC metaclass this class cannot be
    # instantiated until the decorators are dropped -- confirm when
    # implementing. The `filters={}` defaults are kept as found; the stubs
    # never mutate them.

    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
        """Accept the connector arguments; nothing is stored or opened yet."""
        # TODO
        pass

    def generate_where_filter(self, filters: Dict) -> str:
        """Render ``filters`` as one ``key=value AND key=value`` clause.

        Args:
            filters: Mapping of column name to the value it must equal.

        Returns:
            The clauses joined with ``" AND "``, in the mapping's iteration
            order; an empty string when ``filters`` is empty.
        """
        # Values are interpolated verbatim (no quoting or escaping), exactly
        # as the original clause format did.
        return " AND ".join(f"{key}={value}" for key, value in filters.items())

    @abstractmethod
    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
        """Not implemented yet; annotated to iterate over lists of records."""
        # TODO
        pass

    @abstractmethod
    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
        """Not implemented yet; annotated to return a list of records."""
        # TODO
        pass

    @abstractmethod
    def get(self, id: str) -> Optional[Record]:
        """Not implemented yet; annotated to return one record or None."""
        # TODO
        pass

    @abstractmethod
    def size(self, filters: Optional[Dict] = {}) -> int:
        """Not implemented yet; annotated to return an int."""
        # TODO
        pass

    @abstractmethod
    def insert(self, record: Record):
        """Not implemented yet; takes a single record."""
        # TODO
        pass

    @abstractmethod
    def insert_many(self, records: List[Record], show_progress=False):
        """Not implemented yet; takes a list of records."""
        # TODO
        pass

    @abstractmethod
    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
        """Not implemented yet; annotated to return a list of records."""
        # TODO
        pass

    @abstractmethod
    def query_date(self, start_date, end_date):
        """Not implemented yet."""
        # TODO
        pass

    @abstractmethod
    def query_text(self, query):
        """Not implemented yet."""
        # TODO
        pass

    @abstractmethod
    def delete_table(self):
        """Not implemented yet."""
        # TODO
        pass

    @abstractmethod
    def delete(self, filters: Optional[Dict] = {}):
        """Not implemented yet."""
        # TODO
        pass

    @abstractmethod
    def save(self):
        """Not implemented yet."""
        # TODO
        pass