Add skeleton code for lance integration
This commit is contained in:
199
memgpt/connectors/lancedb.py
Normal file
199
memgpt/connectors/lancedb.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import lancedb
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from tqdm import tqdm
|
||||
from typing import Optional, List, Iterator, Dict
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.utils import printd
|
||||
from memgpt.data_types import Record, Message, Passage, Source
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from lancedb.pydantic import Vector, LanceModel
|
||||
|
||||
""" Initial implementation - not complete """
|
||||
|
||||
|
||||
def get_db_model(table_name: str, table_type: TableType):
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
# create schema for archival memory
|
||||
class PassageModel(LanceModel):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
|
||||
id: uuid.UUID
|
||||
user_id: str
|
||||
text: str
|
||||
doc_id: str
|
||||
agent_id: str
|
||||
data_source: str
|
||||
embedding: Vector(config.embedding_dim)
|
||||
metadata_: Dict
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Passage(
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
doc_id=self.doc_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
data_source=self.data_source,
|
||||
agent_id=self.agent_id,
|
||||
metadata=self.metadata_,
|
||||
)
|
||||
|
||||
return PassageModel
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
|
||||
class MessageModel(LanceModel):
|
||||
"""Defines data model for storing Message objects"""
|
||||
|
||||
__abstract__ = True # this line is necessary
|
||||
|
||||
# Assuming message_id is the primary key
|
||||
id: uuid.UUID
|
||||
user_id: str
|
||||
agent_id: str
|
||||
|
||||
# openai info
|
||||
role: str
|
||||
text: str
|
||||
model: str
|
||||
user: str
|
||||
|
||||
# function info
|
||||
function_name: str
|
||||
function_args: str
|
||||
function_response: str
|
||||
|
||||
embedding = Vector(config.embedding_dim)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = datetime
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Message(
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
role=self.role,
|
||||
user=self.user,
|
||||
text=self.text,
|
||||
model=self.model,
|
||||
function_name=self.function_name,
|
||||
function_args=self.function_args,
|
||||
function_response=self.function_response,
|
||||
embedding=self.embedding,
|
||||
created_at=self.created_at,
|
||||
id=self.id,
|
||||
)
|
||||
|
||||
"""Create database model for table_name"""
|
||||
return MessageModel
|
||||
elif table_type == TableType.DATA_SOURCES:
|
||||
|
||||
class SourceModel(LanceModel):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
|
||||
# Assuming passage_id is the primary key
|
||||
id: uuid.UUID
|
||||
user_id: str
|
||||
name: str
|
||||
created_at: datetime
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Source(passage_id='{self.id}', name='{self.name}')>"
|
||||
|
||||
def to_record(self):
|
||||
return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at)
|
||||
|
||||
"""Create database model for table_name"""
|
||||
return SourceModel
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
|
||||
class LanceDBConnector(StorageConnector):
|
||||
"""Storage via LanceDB"""
|
||||
|
||||
# TODO: this should probably eventually be moved into a parent DB class
|
||||
|
||||
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
def generate_where_filter(self, filters: Dict) -> str:
|
||||
where_filters = []
|
||||
for key, value in filters.items():
|
||||
where_filters.append(f"{key}={value}")
|
||||
return where_filters.join(" AND ")
|
||||
|
||||
@abstractmethod
|
||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str) -> Optional[Record]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, record: Record):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert_many(self, records: List[Record], show_progress=True):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_date(self, start_date, end_date):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_text(self, query):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_table(self):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
# TODO
|
||||
pass
|
||||
@@ -190,15 +190,23 @@ class StorageConnector:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
|
||||
def get_filters(self, filters: Optional[Dict] = {}):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]:
|
||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str) -> Record:
|
||||
def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str) -> Optional[Record]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -206,7 +214,7 @@ class StorageConnector:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert_many(self, records: List[Record]):
|
||||
def insert_many(self, records: List[Record], show_progress=True):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -214,11 +222,21 @@ class StorageConnector:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
"""Save state of storage connector"""
|
||||
def query_date(self, start_date, end_date):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
"""Get number of passages (text/embedding pairs) in storage"""
|
||||
def query_text(self, query):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_table(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user