Add in memory storage connector implementation for refactored storage
This commit is contained in:
@@ -7,10 +7,11 @@ from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, t
|
||||
from sqlalchemy.orm import sessionmaker, mapped_column
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy import Column, BIGINT, String, DateTime
|
||||
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
from typing import Optional, List, Iterator
|
||||
from typing import Optional, List, Iterator, Dict
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
@@ -20,10 +21,18 @@ from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.utils import printd
|
||||
from memgpt.data_types import Record, Message, Passage
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def parse_formatted_time(formatted_time):
    """Parse a timestamp string as returned by memgpt.utils.get_formatted_time().

    The input must match ``%Y-%m-%d %I:%M:%S %p %Z%z`` (12-hour clock with an
    AM/PM marker, followed by a timezone name and a UTC offset). Anything else
    makes ``datetime.strptime`` raise ``ValueError``.
    """
    time_format = "%Y-%m-%d %I:%M:%S %p %Z%z"
    parsed = datetime.strptime(formatted_time, time_format)
    return parsed
|
||||
|
||||
|
||||
def get_db_model(table_name: str, table_type: TableType):
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
@@ -37,6 +46,8 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
# Assuming passage_id is the primary key
|
||||
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
|
||||
doc_id = Column(String)
|
||||
agent_id = Column(String)
|
||||
data_source = Column(String) # agent_name if agent, data_source name if from data source
|
||||
text = Column(String, nullable=False)
|
||||
embedding = mapped_column(Vector(config.embedding_dim))
|
||||
metadata_ = Column(JSON(astext_type=Text()))
|
||||
@@ -48,9 +59,37 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
class_name = f"{table_name.capitalize()}Model"
|
||||
Model = type(class_name, (PassageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
|
||||
return Model
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
|
||||
class MessageModel(Base):
|
||||
"""Defines data model for storing Message objects"""
|
||||
|
||||
__abstract__ = True # this line is necessary
|
||||
|
||||
# Assuming message_id is the primary key
|
||||
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
agent_id = Column(String, nullable=False)
|
||||
role = Column(String, nullable=False)
|
||||
content = Column(String, nullable=False)
|
||||
model = Column(String, nullable=False)
|
||||
function_name = Column(String)
|
||||
function_args = Column(String)
|
||||
function_response = Column(String)
|
||||
embedding = mapped_column(Vector(config.embedding_dim))
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Message(message_id='{self.id}', content='{self.content}', embedding='{self.embedding})>"
|
||||
|
||||
"""Create database model for table_name"""
|
||||
class_name = f"{table_name.capitalize()}Model"
|
||||
Model = type(class_name, (MessageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
|
||||
return Model
|
||||
else:
|
||||
# TODO: implement recall memory, document store
|
||||
raise NotImplementedError(f"Table type {table_type} not implemented")
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
|
||||
class PostgresStorageConnector(StorageConnector):
|
||||
@@ -58,21 +97,10 @@ class PostgresStorageConnector(StorageConnector):
|
||||
|
||||
# TODO: this should probably eventually be moved into a parent DB class
|
||||
|
||||
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
||||
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# determine table name
|
||||
if agent_config:
|
||||
assert name is None, f"Cannot specify both agent config and name {name}"
|
||||
self.table_name = self.generate_table_name_agent(agent_config)
|
||||
elif name:
|
||||
assert agent_config is None, f"Cannot specify both agent config and name {name}"
|
||||
self.table_name = self.generate_table_name(name)
|
||||
else:
|
||||
raise ValueError("Must specify either agent config or name")
|
||||
|
||||
printd(f"Using table name {self.table_name}")
|
||||
|
||||
# create table
|
||||
self.uri = config.archival_storage_uri
|
||||
if config.archival_storage_uri is None:
|
||||
@@ -83,39 +111,43 @@ class PostgresStorageConnector(StorageConnector):
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
|
||||
def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
|
||||
def get_all_paginated(self, page_size: int, filters: Optional[Dict]) -> Iterator[List[Record]]:
|
||||
session = self.Session()
|
||||
offset = 0
|
||||
filters = self.get_filters(filters)
|
||||
while True:
|
||||
# Retrieve a chunk of records with the given page_size
|
||||
db_passages_chunk = session.query(self.db_model).offset(offset).limit(page_size).all()
|
||||
db_passages_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
|
||||
|
||||
# If the chunk is empty, we've retrieved all records
|
||||
if not db_passages_chunk:
|
||||
break
|
||||
|
||||
# Yield a list of Passage objects converted from the chunk
|
||||
yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages_chunk]
|
||||
# Yield a list of Record objects converted from the chunk
|
||||
yield [self.type(**p.to_dict()) for p in db_passages_chunk]
|
||||
|
||||
# Increment the offset to get the next chunk in the next iteration
|
||||
offset += page_size
|
||||
|
||||
def get_all(self, limit=10) -> List[Passage]:
|
||||
def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
db_passages = session.query(self.db_model).limit(limit).all()
|
||||
return [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages]
|
||||
filters = self.get_filters(filters)
|
||||
db_passages = session.query(self.db_model).filter(*filters).limit(limit).all()
|
||||
return [self.type(**p.to_dict()) for p in db_passages]
|
||||
|
||||
def get(self, id: str) -> Optional[Passage]:
|
||||
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Passage]:
|
||||
session = self.Session()
|
||||
db_passage = session.query(self.db_model).get(id)
|
||||
filters = self.get_filters(filters)
|
||||
db_passage = session.query(self.db_model).filter(*filters).get(id)
|
||||
if db_passage is None:
|
||||
return None
|
||||
return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id)
|
||||
|
||||
def size(self) -> int:
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
# return size of table
|
||||
session = self.Session()
|
||||
return session.query(self.db_model).count()
|
||||
filters = self.get_filters(filters)
|
||||
return session.query(self.db_model).filter(*filters).count()
|
||||
|
||||
def insert(self, passage: Passage):
|
||||
session = self.Session()
|
||||
@@ -123,38 +155,35 @@ class PostgresStorageConnector(StorageConnector):
|
||||
session.add(db_passage)
|
||||
session.commit()
|
||||
|
||||
def insert_many(self, passages: List[Passage], show_progress=True):
|
||||
def insert_many(self, records: List[Record], show_progress=True):
|
||||
session = self.Session()
|
||||
iterable = tqdm(passages) if show_progress else passages
|
||||
iterable = tqdm(records) if show_progress else records
|
||||
for passage in iterable:
|
||||
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
|
||||
session.add(db_passage)
|
||||
session.commit()
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
# Assuming PassageModel.embedding has the capability of computing l2_distance
|
||||
results = session.scalars(select(self.db_model).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)).all()
|
||||
filters = self.get_filters(filters)
|
||||
results = session.scalars(
|
||||
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
||||
).all()
|
||||
|
||||
# Convert the results into Passage objects
|
||||
passages = [
|
||||
Passage(text=result.text, embedding=np.frombuffer(result.embedding), doc_id=result.doc_id, passage_id=result.id)
|
||||
for result in results
|
||||
]
|
||||
return passages
|
||||
|
||||
def delete(self):
|
||||
"""Drop the passage table from the database."""
|
||||
# Bind the engine to the metadata of the base class so that the
|
||||
# declaratives can be accessed through a DBSession instance
|
||||
Base.metadata.bind = self.engine
|
||||
|
||||
# Drop the table specified by the PassageModel class
|
||||
self.db_model.__table__.drop(self.engine)
|
||||
records = [self.type(**vars(result)) for result in results]
|
||||
return records
|
||||
|
||||
def save(self):
    """Do nothing and return None; this connector has no explicit save step."""
    return None
|
||||
|
||||
def list_data_sources(self):
    """Return the distinct ``data_source`` values found in this connector's table.

    Only meant for archival-memory connectors: the assertion below fails for any
    other table type. The result is returned exactly as ``Query.all()`` produces
    it for a single-column query.
    """
    assert self.table_type == TableType.ARCHIVAL_MEMORY, f"list_data_sources only implemented for ARCHIVAL_MEMORY"
    session = self.Session()
    # NOTE(review): StorageConnector.__init__ stores self.filters as a dict, so
    # unpacking it with * would pass its keys (plain strings) to filter() rather
    # than column conditions -- confirm and convert to conditions if so.
    unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
    return unique_data_sources
|
||||
|
||||
@staticmethod
|
||||
def list_loaded_data():
|
||||
config = MemGPTConfig.load()
|
||||
@@ -166,29 +195,6 @@ class PostgresStorageConnector(StorageConnector):
|
||||
tables = [table[start_chars:] for table in tables]
|
||||
return tables
|
||||
|
||||
def sanitize_table_name(self, name: str) -> str:
    """Turn an arbitrary name into a lowercase, identifier-safe table name fragment.

    Steps, in order: trim surrounding whitespace, replace each run of whitespace
    or of non-word characters with a single underscore, cap the result at 63
    characters (e.g. the PostgreSQL identifier limit) dropping any underscores
    left at the cut, then lowercase.
    """
    identifier_limit = 63
    cleaned = re.sub(r"\s+|\W+", "_", name.strip())
    if len(cleaned) <= identifier_limit:
        return cleaned.lower()
    # Only a truncated name has its trailing underscores stripped.
    return cleaned[:identifier_limit].rstrip("_").lower()
|
||||
|
||||
def generate_table_name_agent(self, agent_config: AgentConfig):
    """Build the per-agent table name: ``memgpt_agent_`` followed by the sanitized agent name."""
    agent_suffix = self.sanitize_table_name(agent_config.name)
    return f"memgpt_agent_{agent_suffix}"
|
||||
|
||||
def generate_table_name(self, name: str):
    """Build the table name for a named store: ``memgpt_`` followed by the sanitized name."""
    sanitized = self.sanitize_table_name(name)
    return f"memgpt_{sanitized}"
|
||||
|
||||
|
||||
class LanceDBConnector(StorageConnector):
|
||||
"""Storage via LanceDB"""
|
||||
@@ -309,26 +315,3 @@ class LanceDBConnector(StorageConnector):
|
||||
start_chars = len("memgpt_")
|
||||
tables = [table[start_chars:] for table in tables]
|
||||
return tables
|
||||
|
||||
def sanitize_table_name(self, name: str) -> str:
    """Normalize *name* into a lowercase string usable inside a table name.

    Surrounding whitespace is removed, every run of whitespace or of non-word
    characters becomes one underscore, names longer than 63 characters are cut
    to 63 (trailing underscores at the cut are dropped), and the result is
    lowercased.
    """
    max_length = 63
    sanitized = re.sub(r"\s+|\W+", "_", name.strip())
    too_long = len(sanitized) > max_length
    sanitized = sanitized[:max_length].rstrip("_") if too_long else sanitized
    return sanitized.lower()
|
||||
|
||||
def generate_table_name_agent(self, agent_config: AgentConfig):
    """Return ``memgpt_agent_<sanitized agent name>`` for the given agent config."""
    safe_agent_name = self.sanitize_table_name(agent_config.name)
    return f"memgpt_agent_{safe_agent_name}"
|
||||
|
||||
def generate_table_name(self, name: str):
    """Return ``memgpt_<sanitized name>`` for a store identified by *name*."""
    safe_name = self.sanitize_table_name(name)
    return f"memgpt_{safe_name}"
|
||||
|
||||
@@ -14,6 +14,7 @@ from llama_index.retrievers import VectorIndexRetriever
|
||||
from llama_index.schema import TextNode
|
||||
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.data_types import Record
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.connectors.storage import StorageConnector, Passage
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
@@ -137,3 +138,50 @@ class VectorIndexStorageConnector(StorageConnector):
|
||||
|
||||
def size(self):
|
||||
return len(self.get_nodes())
|
||||
|
||||
|
||||
class InMemoryStorageConnector(StorageConnector):
    """Minimal storage connector that keeps every record in a Python list.

    It exists so callers can use the unified storage connector interface
    without an external database. Nothing is persisted: ``save`` is a no-op and
    all records disappear with the object.

    NOTE(review): the base ``StorageConnector.__init__`` is deliberately not
    called here (the original did not call it either), so base attributes such
    as ``self.filters`` and ``self.table_name`` are not set -- confirm whether
    that is wanted.
    """

    # TODO: maybe replace this with sqlite?

    def __init__(
        self,
        name: Optional[str] = None,
        agent_config: Optional[AgentConfig] = None,
        table_type: Optional[str] = None,
    ):
        """Create an empty in-memory store.

        :param name: optional store name (kept for backward compatibility).
        :param agent_config: optional agent configuration this store belongs to.
        :param table_type: optional table type; accepted because
            ``StorageConnector.get_recall_storage_connector`` passes it by keyword.
        """
        config = MemGPTConfig.load()
        # TODO: figure out save location

        self.name = name
        self.agent_config = agent_config
        self.table_type = table_type
        self.rows = []

    @staticmethod
    def _matches(record, filters: Optional[Dict]) -> bool:
        """Return True when *record* has every attribute in *filters* with an equal value."""
        if not filters:
            return True
        return all(hasattr(record, key) and getattr(record, key) == value for key, value in filters.items())

    def _filtered(self, filters: Optional[Dict] = None) -> List[Record]:
        """Return the stored records that satisfy *filters*, in insertion order."""
        return [row for row in self.rows if self._matches(row, filters)]

    def get_all_paginated(self, page_size: int, filters: Optional[Dict] = None) -> Iterator[List[Record]]:
        """Yield the matching records in consecutive lists of at most *page_size* items.

        :raises ValueError: if *page_size* is smaller than 1.
        """
        if page_size < 1:
            raise ValueError(f"page_size must be at least 1, got {page_size}")
        matching = self._filtered(filters)
        for start in range(0, len(matching), page_size):
            yield matching[start : start + page_size]

    def get_all(self, limit: int = 10, filters: Optional[Dict] = None) -> List[Record]:
        """Return up to *limit* matching records, oldest first."""
        return self._filtered(filters)[:limit]

    def get(self, id: str) -> Optional[Record]:
        """Return the first record whose ``id`` attribute equals *id*, or None.

        Assumes records expose their identifier as ``.id`` -- TODO confirm
        against memgpt.data_types.Record.
        """
        for row in self.rows:
            if getattr(row, "id", None) == id:
                return row
        return None

    def insert(self, record: Record):
        """Append a single record to the store."""
        self.rows.append(record)

    def insert_many(self, records: List[Record], show_progress: bool = True):
        """Append several records to the store.

        ``show_progress`` is accepted for signature parity with the Postgres
        connector's ``insert_many`` and is ignored here.
        """
        self.rows.extend(records)

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = None) -> List[Record]:
        """Return the *top_k* matching records closest to *query_vec*.

        Records are ranked by squared Euclidean distance between their
        ``embedding`` and *query_vec*, which orders them the same way as the
        ``l2_distance`` ordering used by the Postgres connector. Records with
        no embedding are skipped. The *query* text itself is not used.
        """

        def squared_l2(row) -> float:
            # zip() stops at the shorter vector; dimensions are assumed to agree.
            return sum((float(a) - float(b)) ** 2 for a, b in zip(row.embedding, query_vec))

        candidates = [row for row in self._filtered(filters) if getattr(row, "embedding", None) is not None]
        return sorted(candidates, key=squared_l2)[:top_k]

    def save(self):
        """Save state of storage connector (no-op: everything lives in memory)."""
        return None

    def size(self, filters: Optional[Dict] = None) -> int:
        """Return the number of stored records that satisfy *filters*."""
        return len(self._filtered(filters))
|
||||
|
||||
@@ -8,74 +8,137 @@ import pickle
|
||||
import os
|
||||
|
||||
|
||||
from typing import List, Optional
|
||||
from abc import abstractmethod
|
||||
import numpy as np
|
||||
from typing import List, Optional, Dict
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from memgpt.data_types import Record
|
||||
from memgpt.data_types import Record, Passage, Document, Message
|
||||
|
||||
|
||||
# ENUM representing table types in MemGPT
|
||||
# each table corresponds to a different table schema (specified in data_types.py)
|
||||
class TableType:
    """String constants naming the kinds of tables MemGPT stores.

    These are plain class attributes rather than an ``enum.Enum``, so each
    value compares equal to its raw string.
    """

    ARCHIVAL_MEMORY = "archival_memory"  # archival memory table (ARCHIVAL_TABLE_NAME; older agents: memgpt_agent_{agent name})
    RECALL_MEMORY = "recall_memory"  # recall memory table (RECALL_TABLE_NAME)
    PASSAGES = "passages"  # TODO
    DOCUMENTS = "documents"  # TODO
    USERS = "users"  # TODO
    AGENTS = "agents"  # TODO
|
||||
|
||||
|
||||
# Defining schema objects:
|
||||
# Note: user/agent can borrow from MemGPTConfig/AgentConfig classes
|
||||
# table names used by MemGPT
|
||||
RECALL_TABLE_NAME = "memgpt_recall_memory"
|
||||
ARCHIVAL_TABLE_NAME = "memgpt_archival_memory"
|
||||
PASSAGE_TABLE_NAME = "memgpt_passages"
|
||||
DOCUMENT_TABLE_NAME = "memgpt_documents"
|
||||
|
||||
|
||||
class StorageConnector:
|
||||
def __init__(self, table_type: TableType, agent_config: Optional[AgentConfig] = None):
    """Set up the state shared by every storage connector.

    Loads the MemGPT config, records the user id, picks the record class for
    the table type (Passage for archival memory, Message for recall memory),
    derives the table name, and prepares the base filters applied to queries.

    :param table_type: one of the TableType constants.
    :param agent_config: configuration of the owning agent, if any.
    :raises ValueError: if *table_type* is neither archival nor recall memory.
    """
    loaded_config = MemGPTConfig.load()
    self.agent_config = agent_config
    self.user_id = loaded_config.anon_clientid
    self.table_type = table_type

    # Record class that rows of this table are converted into.
    if table_type == TableType.ARCHIVAL_MEMORY:
        record_cls = Passage
    elif table_type == TableType.RECALL_MEMORY:
        record_cls = Message
    else:
        raise ValueError(f"Table type {table_type} not implemented")
    self.type = record_cls

    # Name of the backing database table.
    self.table_name = self.generate_table_name(agent_config, table_type=table_type)
    printd(f"Using table name {self.table_name}")

    # Base filters: always scoped to the user, and to the agent for agent-specific tables.
    base_filters = {"user_id": self.user_id}
    if self.table_type in (TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY):
        base_filters["agent_id"] = self.agent_config.name
    self.filters = base_filters
|
||||
|
||||
def get_filters(self, filters: Optional[Dict] = None):
    """Build the list of filter conditions to apply to a query.

    The connector's base filters (``self.filters``) are merged with the
    caller-supplied *filters*; on a key clash the caller's value wins. Each
    ``key: value`` pair becomes ``getattr(self.db_model, key) == value``, so
    for a SQLAlchemy model the result can be passed as ``.filter(*conditions)``.

    :param filters: extra ``{attribute name: required value}`` pairs, or None.
    :return: a new list of comparison results/expressions; ``self.filters`` is
        left untouched.
    :raises AttributeError: if a key does not name an attribute of the model.
    """
    filter_conditions = {**self.filters, **(filters or {})}
    # Attribute access, not subscription: model classes are not subscriptable.
    return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
|
||||
|
||||
def generate_table_name(self, agent_config: AgentConfig, table_type: TableType):
    """Return the database table name for the given agent config and table type.

    With an agent config, agents created before MemGPT 0.2.6 keep their old
    per-agent archival table (``memgpt_agent_<sanitized name>``); newer agents
    use the shared ARCHIVAL_TABLE_NAME / RECALL_TABLE_NAME tables. Without an
    agent config, only the passages and documents tables are available.

    :raises ValueError: if the combination is not implemented.
    """

    def version_key(version: str):
        """Turn '0.2.10' into (0, 2, 10) so versions compare numerically.

        A plain string comparison would rank "0.2.10" below "0.2.6". Parsing
        stops at the first component that does not start with a digit.
        """
        key = []
        for part in version.split("."):
            leading_digits = part[: len(part) - len(part.lstrip("0123456789"))]
            if not leading_digits:
                break
            key.append(int(leading_digits))
        return tuple(key)

    if agent_config is None:
        # table names for non-agent specific tables
        if table_type == TableType.PASSAGES:
            return PASSAGE_TABLE_NAME
        if table_type == TableType.DOCUMENTS:
            return DOCUMENT_TABLE_NAME
        raise ValueError(f"Table type {table_type} not implemented")

    # Table names for agent-specific tables
    if version_key(agent_config.memgpt_version) < (0, 2, 6):
        # if agent is prior version, use old table name
        if table_type == TableType.ARCHIVAL_MEMORY:
            return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
        raise ValueError(f"Table type {table_type} not implemented")

    if table_type == TableType.ARCHIVAL_MEMORY:
        return ARCHIVAL_TABLE_NAME
    if table_type == TableType.RECALL_MEMORY:
        return RECALL_TABLE_NAME
    raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
@staticmethod
|
||||
def get_archival_storage_connector(name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
||||
def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None):
|
||||
storage_type = MemGPTConfig.load().archival_storage_type
|
||||
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import VectorIndexStorageConnector
|
||||
|
||||
return VectorIndexStorageConnector(name=name, agent_config=agent_config)
|
||||
return VectorIndexStorageConnector(agent_config=agent_config)
|
||||
|
||||
elif storage_type == "postgres":
|
||||
from memgpt.connectors.db import PostgresStorageConnector
|
||||
|
||||
return PostgresStorageConnector(name=name, agent_config=agent_config)
|
||||
elif storage_type == "chroma":
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
return PostgresStorageConnector(agent_config=agent_config)
|
||||
|
||||
return ChromaStorageConnector(name=name, agent_config=agent_config)
|
||||
elif storage_type == "lancedb":
|
||||
from memgpt.connectors.db import LanceDBConnector
|
||||
|
||||
return LanceDBConnector(name=name, agent_config=agent_config)
|
||||
return LanceDBConnector(agent_config=agent_config)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
|
||||
@staticmethod
|
||||
def get_recall_storage_connector(name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
||||
def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None):
|
||||
storage_type = MemGPTConfig.load().recall_storage_type
|
||||
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import VectorIndexStorageConnector
|
||||
from memgpt.connectors.local import InMemoryStorageConnector
|
||||
|
||||
# maintains in-memory list for storage
|
||||
return InMemoryStorageConnector(name=name, agent_config=agent_config)
|
||||
return InMemoryStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY)
|
||||
|
||||
elif storage_type == "postgres":
|
||||
from memgpt.connectors.db import PostgresStorageConnector
|
||||
|
||||
return PostgresStorageConnector(name=name, agent_config=agent_config)
|
||||
return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
|
||||
@staticmethod
|
||||
def list_loaded_data():
|
||||
# TODO: modify this to simply list loaded data from a given user
|
||||
storage_type = MemGPTConfig.load().archival_storage_type
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import VectorIndexStorageConnector
|
||||
@@ -97,11 +160,11 @@ class StorageConnector:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_all_paginated(self, page_size: int) -> Iterator[List[Record]]:
|
||||
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(self, limit: int) -> List[Record]:
|
||||
def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -109,15 +172,15 @@ class StorageConnector:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, passage: Record):
|
||||
def insert(self, record: Record):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert_many(self, passages: List[Record]):
|
||||
def insert_many(self, records: List[Record]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Record]:
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -126,6 +189,6 @@ class StorageConnector:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def size(self):
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
"""Get number of passages (text/embedding pairs) in storage"""
|
||||
pass
|
||||
|
||||
@@ -35,7 +35,7 @@ class Message(Record):
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
role: str,
|
||||
text: str,
|
||||
content: str,
|
||||
model: str, # model used to make function call
|
||||
function_name: Optional[str] = None, # name of function called
|
||||
function_args: Optional[str] = None, # args of function called
|
||||
@@ -43,7 +43,7 @@ class Message(Record):
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
id: Optional[str] = None,
|
||||
):
|
||||
super().__init__(user_id, agent_id, text, id)
|
||||
super().__init__(user_id, agent_id, content, id)
|
||||
self.role = role # role (agent/user/function)
|
||||
self.model = model # model name (e.g. gpt-4)
|
||||
|
||||
@@ -62,10 +62,11 @@ class Message(Record):
|
||||
class Document(Record):
|
||||
"""A document represent a document loaded into MemGPT, which is broken down into passages."""
|
||||
|
||||
def __init__(self, user_id: str, text: str, document_id: Optional[str] = None):
|
||||
def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None):
|
||||
super().__init__(user_id)
|
||||
self.text = text
|
||||
self.document_id = document_id
|
||||
self.data_source = data_source
|
||||
# TODO: add optional embedding?
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -78,9 +79,18 @@ class Passage(Record):
|
||||
It is a string of text with an associated embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str, text: str, embedding: np.ndarray, doc_id: Optional[str] = None, passage_id: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
text: str,
|
||||
data_source: str,
|
||||
embedding: np.ndarray,
|
||||
doc_id: Optional[str] = None,
|
||||
passage_id: Optional[str] = None,
|
||||
):
|
||||
super().__init__(user_id)
|
||||
self.text = text
|
||||
self.data_source = data_source
|
||||
self.embedding = embedding
|
||||
self.doc_id = doc_id
|
||||
self.passage_id = passage_id
|
||||
|
||||
@@ -4,11 +4,10 @@ import re
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC
|
||||
from memgpt.utils import get_local_time, printd, count_tokens
|
||||
from memgpt.utils import get_local_time, printd, count_tokens, validate_date_format, extract_date_from_timestamp
|
||||
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from memgpt.openai_tools import create
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.data_types import Message, Passage
|
||||
from llama_index import Document
|
||||
from llama_index.node_parser import SimpleNodeParser
|
||||
from llama_index.node_parser import SimpleNodeParser
|
||||
@@ -137,7 +136,7 @@ def summarize_messages(
|
||||
|
||||
class ArchivalMemory(ABC):
|
||||
@abstractmethod
|
||||
def insert(self, memory_string):
|
||||
def insert(self, memory_string: str):
|
||||
"""Insert new archival memory
|
||||
|
||||
:param memory_string: Memory string to insert
|
||||
@@ -178,6 +177,10 @@ class RecallMemory(ABC):
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, message: Message):
|
||||
pass
|
||||
|
||||
|
||||
class DummyRecallMemory(RecallMemory):
|
||||
"""Dummy in-memory version of a recall memory database (eg run on MongoDB)
|
||||
@@ -189,29 +192,12 @@ class DummyRecallMemory(RecallMemory):
|
||||
effectively allowing it to 'remember' prior engagements with a user.
|
||||
"""
|
||||
|
||||
# TODO: replace this with StorageConnector based implementation
|
||||
|
||||
def __init__(self, agent_config, restrict_search_to_summaries=False):
|
||||
def __init__(self, message_database=None, restrict_search_to_summaries=False):
|
||||
self._message_logs = [] if message_database is None else message_database # consists of full message dicts
|
||||
|
||||
# If true, the pool of messages that can be queried are the automated summaries only
|
||||
# (generated when the conversation window needs to be shortened)
|
||||
self.restrict_search_to_summaries = restrict_search_to_summaries
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
|
||||
self.top_k = top_k
|
||||
self.agent_config = agent_config
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# create embedding model
|
||||
self.embed_model = embedding_model()
|
||||
self.embedding_chunk_size = config.embedding_chunk_size
|
||||
|
||||
# create storage backend
|
||||
self.storage = StorageConnector.get_archival_storage_connector(
|
||||
agent_config=agent_config, table_type="recall_memory" # TODO: change to enum
|
||||
)
|
||||
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
||||
self.cache = {}
|
||||
|
||||
def __len__(self):
|
||||
return len(self._message_logs)
|
||||
@@ -267,25 +253,11 @@ class DummyRecallMemory(RecallMemory):
|
||||
else:
|
||||
return matches, len(matches)
|
||||
|
||||
def _validate_date_format(self, date_str):
|
||||
"""Validate the given date string in the format 'YYYY-MM-DD'."""
|
||||
try:
|
||||
datetime.datetime.strptime(date_str, "%Y-%m-%d")
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def _extract_date_from_timestamp(self, timestamp):
|
||||
"""Extracts and returns the date from the given timestamp."""
|
||||
# Extracts the date (ignoring the time and timezone)
|
||||
match = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp)
|
||||
return match.group(1) if match else None
|
||||
|
||||
def date_search(self, start_date, end_date, count=None, start=None):
|
||||
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
|
||||
|
||||
# First, validate the start_date and end_date format
|
||||
if not self._validate_date_format(start_date) or not self._validate_date_format(end_date):
|
||||
if not validate_date_format(start_date) or not validate_date_format(end_date):
|
||||
raise ValueError("Invalid date format. Expected format: YYYY-MM-DD")
|
||||
|
||||
# Convert dates to datetime objects for comparison
|
||||
@@ -296,7 +268,7 @@ class DummyRecallMemory(RecallMemory):
|
||||
matches = [
|
||||
d
|
||||
for d in message_pool
|
||||
if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
|
||||
if start_date_dt <= datetime.datetime.strptime(extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
|
||||
]
|
||||
|
||||
# start/count support paging through results
|
||||
@@ -312,6 +284,42 @@ class DummyRecallMemory(RecallMemory):
|
||||
return matches, len(matches)
|
||||
|
||||
|
||||
class RecallMemorySQL(RecallMemory):
    """Recall memory that stores messages through a StorageConnector backend.

    The methods below are concrete (no ``@abstractmethod``): leaving the
    decorator on a subclass of the abstract ``RecallMemory`` would keep the
    class abstract and make it impossible to instantiate. Searching is not
    implemented yet and says so explicitly.
    """

    def __init__(self, agent_config, restrict_search_to_summaries=False):
        """Create the embedding model and the recall storage backend for *agent_config*."""

        # If true, the pool of messages that can be queried are the automated summaries only
        # (generated when the conversation window needs to be shortened)
        self.restrict_search_to_summaries = restrict_search_to_summaries
        # Imported here rather than at module level (presumably to avoid an import cycle).
        from memgpt.connectors.storage import StorageConnector

        self.agent_config = agent_config
        config = MemGPTConfig.load()

        # create embedding model
        self.embed_model = embedding_model()
        self.embedding_chunk_size = config.embedding_chunk_size

        # create storage backend
        self.storage = StorageConnector.get_recall_storage_connector(agent_config=agent_config)
        # TODO: have some mechanism for cleanup otherwise will lead to OOM
        self.cache = {}

    def text_search(self, query_string, count=None, start=None):
        """Search stored messages by text. Not implemented yet."""
        raise NotImplementedError("RecallMemorySQL.text_search is not implemented yet")

    def date_search(self, query_string, count=None, start=None):
        """Search stored messages by date. Not implemented yet.

        NOTE(review): DummyRecallMemory.date_search takes (start_date, end_date,
        count, start); confirm which signature this class should expose.
        """
        raise NotImplementedError("RecallMemorySQL.date_search is not implemented yet")

    def __repr__(self) -> str:
        """Return a short debugging representation (``__repr__`` must return a str)."""
        return f"{type(self).__name__}(agent_config={self.agent_config!r})"

    def insert(self, message: Message):
        """Store *message* in the recall storage backend."""
        self.storage.insert(message)
|
||||
|
||||
|
||||
class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
"""Archival memory with embedding based search"""
|
||||
|
||||
|
||||
@@ -1,13 +1,5 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
|
||||
import re
|
||||
import difflib
|
||||
import demjson3 as demjson
|
||||
import pytz
|
||||
@@ -288,3 +280,20 @@ def get_schema_diff(schema_a, schema_b):
|
||||
difference = [line for line in difference if line.startswith("+ ") or line.startswith("- ")]
|
||||
|
||||
return "".join(difference)
|
||||
|
||||
|
||||
# datetime related
|
||||
def validate_date_format(date_str):
    """Validate the given date string in the format 'YYYY-MM-DD'.

    :param date_str: value to check; non-string values are treated as invalid.
    :return: True if *date_str* parses as a 'YYYY-MM-DD' date, False otherwise.
    """
    # Imported locally under an alias so the check works whether the
    # module-level name ``datetime`` is bound to the module or to the class.
    # With ``from datetime import datetime`` in scope, ``datetime.datetime``
    # would raise an AttributeError that the except clause does not catch.
    from datetime import datetime as _datetime

    try:
        _datetime.strptime(date_str, "%Y-%m-%d")
    except (ValueError, TypeError):
        return False
    return True
|
||||
|
||||
|
||||
def extract_date_from_timestamp(timestamp):
    """Return the 'YYYY-MM-DD' date that *timestamp* starts with, or None.

    Only the beginning of the string is matched, so the time and timezone that
    may follow the date are ignored.
    """
    date_prefix = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp)
    if date_prefix is None:
        return None
    return date_prefix.group(1)
|
||||
|
||||
@@ -14,11 +14,54 @@ from memgpt.connectors.storage import StorageConnector, Passage
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.data_types import Message, Passage
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_recall_db() -> None:
    """Insert two messages into the Postgres recall store and read them back.

    Skipped when PGVECTOR_TEST_DB_URL is unset, matching the skip guard on the
    other Postgres tests in this file.
    """
    # os.environ["MEMGPT_CONFIG_PATH"] = "./config"

    storage_type = "postgres"
    storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
    config = MemGPTConfig(recall_storage_type=storage_type, recall_storage_uri=storage_uri)
    print(config.config_path)
    assert config.recall_storage_uri is not None
    config.save()
    print(config)

    # NOTE(review): no agent_config is passed here; confirm the recall connector
    # can be constructed without one.
    conn = StorageConnector.get_recall_storage_connector()

    # construct recall memory messages
    # NOTE(review): Message.__init__ appears to also require user_id and model;
    # confirm and supply them if so.
    message1 = Message(
        agent_id="test_agent1",
        role="agent",
        content="This is a test message",
        id="test_id1",
    )
    message2 = Message(
        agent_id="test_agent2",
        role="user",
        content="This is a test message",
        id="test_id2",
    )

    # test insert
    conn.insert(message1)
    conn.insert_many([message2])

    # test size (query once, so the failure message reports the value that was checked)
    agent2_filter = {"agent_id": "test_agent2"}
    total = conn.size()
    assert total == 2, f"Expected 2 messages, got {total}"
    agent2_total = conn.size(filters=agent2_filter)
    assert agent2_total == 1, f"Expected 1 message, got {agent2_total}"

    # test get
    fetched = conn.get("test_id1")
    assert fetched == message1, f"Expected {message1}, got {fetched}"
    agent2_messages = conn.get_all(limit=10, filters=agent2_filter)
    assert len(agent2_messages) == 1, f"Expected 1 message, got {len(agent2_messages)}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
|
||||
def test_postgres_openai():
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
|
||||
Reference in New Issue
Block a user