Add skeleton code for lance integration

This commit is contained in:
Sarah Wooders
2023-12-26 12:16:09 +04:00
parent d4ddf549e3
commit 2a4df2263f
2 changed files with 225 additions and 8 deletions

View File

@@ -0,0 +1,199 @@
import lancedb
import uuid
from datetime import datetime
from tqdm import tqdm
from typing import Optional, List, Iterator, Dict
from memgpt.config import MemGPTConfig
from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.utils import printd
from memgpt.data_types import Record, Message, Passage, Source
from datetime import datetime
from lancedb.pydantic import Vector, LanceModel
""" Initial implementation - not complete """
def get_db_model(table_name: str, table_type: TableType):
config = MemGPTConfig.load()
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
# create schema for archival memory
class PassageModel(LanceModel):
"""Defines data model for storing Passages (consisting of text, embedding)"""
id: uuid.UUID
user_id: str
text: str
doc_id: str
agent_id: str
data_source: str
embedding: Vector(config.embedding_dim)
metadata_: Dict
def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
def to_record(self):
return Passage(
text=self.text,
embedding=self.embedding,
doc_id=self.doc_id,
user_id=self.user_id,
id=self.id,
data_source=self.data_source,
agent_id=self.agent_id,
metadata=self.metadata_,
)
return PassageModel
elif table_type == TableType.RECALL_MEMORY:
class MessageModel(LanceModel):
"""Defines data model for storing Message objects"""
__abstract__ = True # this line is necessary
# Assuming message_id is the primary key
id: uuid.UUID
user_id: str
agent_id: str
# openai info
role: str
text: str
model: str
user: str
# function info
function_name: str
function_args: str
function_response: str
embedding = Vector(config.embedding_dim)
# Add a datetime column, with default value as the current time
created_at = datetime
def __repr__(self):
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
def to_record(self):
return Message(
user_id=self.user_id,
agent_id=self.agent_id,
role=self.role,
user=self.user,
text=self.text,
model=self.model,
function_name=self.function_name,
function_args=self.function_args,
function_response=self.function_response,
embedding=self.embedding,
created_at=self.created_at,
id=self.id,
)
"""Create database model for table_name"""
return MessageModel
elif table_type == TableType.DATA_SOURCES:
class SourceModel(LanceModel):
"""Defines data model for storing Passages (consisting of text, embedding)"""
# Assuming passage_id is the primary key
id: uuid.UUID
user_id: str
name: str
created_at: datetime
def __repr__(self):
return f"<Source(passage_id='{self.id}', name='{self.name}')>"
def to_record(self):
return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at)
"""Create database model for table_name"""
return SourceModel
else:
raise ValueError(f"Table type {table_type} not implemented")
class LanceDBConnector(StorageConnector):
"""Storage via LanceDB"""
# TODO: this should probably eventually be moved into a parent DB class
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
# TODO
pass
def generate_where_filter(self, filters: Dict) -> str:
where_filters = []
for key, value in filters.items():
where_filters.append(f"{key}={value}")
return where_filters.join(" AND ")
@abstractmethod
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
# TODO
pass
@abstractmethod
def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
# TODO
pass
@abstractmethod
def get(self, id: str) -> Optional[Record]:
# TODO
pass
@abstractmethod
def size(self, filters: Optional[Dict] = {}) -> int:
# TODO
pass
@abstractmethod
def insert(self, record: Record):
# TODO
pass
@abstractmethod
def insert_many(self, records: List[Record], show_progress=True):
# TODO
pass
@abstractmethod
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
# TODO
pass
@abstractmethod
def query_date(self, start_date, end_date):
# TODO
pass
@abstractmethod
def query_text(self, query):
# TODO
pass
@abstractmethod
def delete_table(self):
# TODO
pass
@abstractmethod
def delete(self, filters: Optional[Dict] = {}):
# TODO
pass
@abstractmethod
def save(self):
# TODO
pass

View File

@@ -190,15 +190,23 @@ class StorageConnector:
raise NotImplementedError(f"Storage type {storage_type} not implemented")
@abstractmethod
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
def get_filters(self, filters: Optional[Dict] = {}):
pass
@abstractmethod
def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]:
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
pass
@abstractmethod
def get(self, id: str) -> Record:
def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
pass
@abstractmethod
def get(self, id: str) -> Optional[Record]:
pass
@abstractmethod
def size(self, filters: Optional[Dict] = {}) -> int:
pass
@abstractmethod
@@ -206,7 +214,7 @@ class StorageConnector:
pass
@abstractmethod
def insert_many(self, records: List[Record]):
def insert_many(self, records: List[Record], show_progress=True):
pass
@abstractmethod
@@ -214,11 +222,21 @@ class StorageConnector:
pass
@abstractmethod
def save(self):
"""Save state of storage connector"""
def query_date(self, start_date, end_date):
pass
@abstractmethod
def size(self, filters: Optional[Dict] = {}) -> int:
"""Get number of passages (text/embedding pairs) in storage"""
def query_text(self, query):
pass
@abstractmethod
def delete_table(self):
pass
@abstractmethod
def delete(self, filters: Optional[Dict] = {}):
pass
@abstractmethod
def save(self):
pass