Add in memory storage connector implementation for refactored storage

This commit is contained in:
Sarah Wooders
2023-12-04 10:04:06 -08:00
parent d041455375
commit 408df89c9c
7 changed files with 332 additions and 168 deletions

View File

@@ -7,10 +7,11 @@ from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, t
from sqlalchemy.orm import sessionmaker, mapped_column
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
from sqlalchemy import Column, BIGINT, String, DateTime
import re
from tqdm import tqdm
from typing import Optional, List, Iterator
from typing import Optional, List, Iterator, Dict
import numpy as np
from tqdm import tqdm
import pandas as pd
@@ -20,10 +21,18 @@ from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.utils import printd
from memgpt.data_types import Record, Message, Passage
from datetime import datetime
Base = declarative_base()
def parse_formatted_time(formatted_time):
    """Parse a timestamp produced by memgpt.utils.get_formatted_time().

    Expected shape: 'YYYY-MM-DD HH:MM:SS AM/PM TZNAME+HHMM'.
    """
    time_format = "%Y-%m-%d %I:%M:%S %p %Z%z"
    return datetime.strptime(formatted_time, time_format)
def get_db_model(table_name: str, table_type: TableType):
config = MemGPTConfig.load()
@@ -37,6 +46,8 @@ def get_db_model(table_name: str, table_type: TableType):
# Assuming passage_id is the primary key
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
doc_id = Column(String)
agent_id = Column(String)
data_source = Column(String) # agent_name if agent, data_source name if from data source
text = Column(String, nullable=False)
embedding = mapped_column(Vector(config.embedding_dim))
metadata_ = Column(JSON(astext_type=Text()))
@@ -48,9 +59,37 @@ def get_db_model(table_name: str, table_type: TableType):
class_name = f"{table_name.capitalize()}Model"
Model = type(class_name, (PassageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
return Model
elif table_type == TableType.RECALL_MEMORY:
class MessageModel(Base):
"""Defines data model for storing Message objects"""
__abstract__ = True # this line is necessary
# Assuming message_id is the primary key
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
role = Column(String, nullable=False)
content = Column(String, nullable=False)
model = Column(String, nullable=False)
function_name = Column(String)
function_args = Column(String)
function_response = Column(String)
embedding = mapped_column(Vector(config.embedding_dim))
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True), server_default=func.now())
def __repr__(self):
    """Debug representation of a stored message row."""
    # fixed: the original f-string opened a quote before {self.embedding}
    # and never closed it, producing unbalanced output
    return f"<Message(message_id='{self.id}', content='{self.content}', embedding='{self.embedding}')>"
"""Create database model for table_name"""
class_name = f"{table_name.capitalize()}Model"
Model = type(class_name, (MessageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
return Model
else:
# TODO: implement recall memory, document store
raise NotImplementedError(f"Table type {table_type} not implemented")
raise ValueError(f"Table type {table_type} not implemented")
class PostgresStorageConnector(StorageConnector):
@@ -58,21 +97,10 @@ class PostgresStorageConnector(StorageConnector):
# TODO: this should probably eventually be moved into a parent DB class
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
super().__init__(table_type=table_type, agent_config=agent_config)
config = MemGPTConfig.load()
# determine table name
if agent_config:
assert name is None, f"Cannot specify both agent config and name {name}"
self.table_name = self.generate_table_name_agent(agent_config)
elif name:
assert agent_config is None, f"Cannot specify both agent config and name {name}"
self.table_name = self.generate_table_name(name)
else:
raise ValueError("Must specify either agent config or name")
printd(f"Using table name {self.table_name}")
# create table
self.uri = config.archival_storage_uri
if config.archival_storage_uri is None:
@@ -83,39 +111,43 @@ class PostgresStorageConnector(StorageConnector):
self.Session = sessionmaker(bind=self.engine)
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
def get_all_paginated(self, page_size: int, filters: Optional[Dict]) -> Iterator[List[Record]]:
session = self.Session()
offset = 0
filters = self.get_filters(filters)
while True:
# Retrieve a chunk of records with the given page_size
db_passages_chunk = session.query(self.db_model).offset(offset).limit(page_size).all()
db_passages_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
# If the chunk is empty, we've retrieved all records
if not db_passages_chunk:
break
# Yield a list of Passage objects converted from the chunk
yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages_chunk]
# Yield a list of Record objects converted from the chunk
yield [self.type(**p.to_dict()) for p in db_passages_chunk]
# Increment the offset to get the next chunk in the next iteration
offset += page_size
def get_all(self, limit=10) -> List[Passage]:
def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]:
session = self.Session()
db_passages = session.query(self.db_model).limit(limit).all()
return [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages]
filters = self.get_filters(filters)
db_passages = session.query(self.db_model).filter(*filters).limit(limit).all()
return [self.type(**p.to_dict()) for p in db_passages]
def get(self, id: str) -> Optional[Passage]:
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Passage]:
session = self.Session()
db_passage = session.query(self.db_model).get(id)
filters = self.get_filters(filters)
db_passage = session.query(self.db_model).filter(*filters).get(id)
if db_passage is None:
return None
return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id)
def size(self) -> int:
def size(self, filters: Optional[Dict] = {}) -> int:
# return size of table
session = self.Session()
return session.query(self.db_model).count()
filters = self.get_filters(filters)
return session.query(self.db_model).filter(*filters).count()
def insert(self, passage: Passage):
session = self.Session()
@@ -123,38 +155,35 @@ class PostgresStorageConnector(StorageConnector):
session.add(db_passage)
session.commit()
def insert_many(self, passages: List[Passage], show_progress=True):
def insert_many(self, records: List[Record], show_progress=True):
session = self.Session()
iterable = tqdm(passages) if show_progress else passages
iterable = tqdm(records) if show_progress else records
for passage in iterable:
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
session.add(db_passage)
session.commit()
def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
session = self.Session()
# Assuming PassageModel.embedding has the capability of computing l2_distance
results = session.scalars(select(self.db_model).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)).all()
filters = self.get_filters(filters)
results = session.scalars(
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
).all()
# Convert the results into Passage objects
passages = [
Passage(text=result.text, embedding=np.frombuffer(result.embedding), doc_id=result.doc_id, passage_id=result.id)
for result in results
]
return passages
def delete(self):
"""Drop the passage table from the database."""
# Bind the engine to the metadata of the base class so that the
# declaratives can be accessed through a DBSession instance
Base.metadata.bind = self.engine
# Drop the table specified by the PassageModel class
self.db_model.__table__.drop(self.engine)
records = [self.type(**vars(result)) for result in results]
return records
def save(self):
    """No-op: insert()/insert_many() already call session.commit(), so there is no separate save step."""
    pass
def list_data_sources(self):
    """Return the distinct data_source values present in this archival table."""
    assert self.table_type == TableType.ARCHIVAL_MEMORY, f"list_data_sources only implemented for ARCHIVAL_MEMORY"
    session = self.Session()
    # fixed: self.filters is a plain dict of key/value pairs, so `filter(*self.filters)`
    # unpacked the KEYS (strings); convert to SQLAlchemy conditions via get_filters()
    unique_data_sources = session.query(self.db_model.data_source).filter(*self.get_filters({})).distinct().all()
    return unique_data_sources
@staticmethod
def list_loaded_data():
config = MemGPTConfig.load()
@@ -166,29 +195,6 @@ class PostgresStorageConnector(StorageConnector):
tables = [table[start_chars:] for table in tables]
return tables
def sanitize_table_name(self, name: str) -> str:
    """Normalize an arbitrary string into a safe SQL identifier.

    Trims surrounding whitespace, collapses runs of whitespace/non-word
    characters into single underscores, truncates to PostgreSQL's 63-char
    identifier limit, and lowercases the result.
    """
    max_identifier_length = 63  # PostgreSQL identifier limit
    cleaned = re.sub(r"\s+|\W+", "_", name.strip())
    if len(cleaned) > max_identifier_length:
        cleaned = cleaned[:max_identifier_length].rstrip("_")
    return cleaned.lower()
def generate_table_name_agent(self, agent_config: AgentConfig):
    # Table name for an agent-specific table, derived from the agent's name.
    return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
def generate_table_name(self, name: str):
    # Table name for a named (non-agent) data source.
    return f"memgpt_{self.sanitize_table_name(name)}"
class LanceDBConnector(StorageConnector):
"""Storage via LanceDB"""
@@ -309,26 +315,3 @@ class LanceDBConnector(StorageConnector):
start_chars = len("memgpt_")
tables = [table[start_chars:] for table in tables]
return tables
def sanitize_table_name(self, name: str) -> str:
    """Turn *name* into a valid lowercase table identifier (max 63 chars)."""
    max_length = 63  # maximum identifier length
    # strip outer whitespace, then replace whitespace / non-word runs with "_"
    candidate = re.sub(r"\s+|\W+", "_", name.strip())
    # truncate over-long names, dropping any trailing underscores left by the cut
    candidate = candidate[:max_length].rstrip("_") if len(candidate) > max_length else candidate
    return candidate.lower()
def generate_table_name_agent(self, agent_config: AgentConfig):
    # LanceDB table name for an agent-specific table.
    return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
def generate_table_name(self, name: str):
    # LanceDB table name for a named (non-agent) data source.
    return f"memgpt_{self.sanitize_table_name(name)}"

View File

@@ -14,6 +14,7 @@ from llama_index.retrievers import VectorIndexRetriever
from llama_index.schema import TextNode
from memgpt.constants import MEMGPT_DIR
from memgpt.data_types import Record
from memgpt.config import MemGPTConfig
from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.config import AgentConfig, MemGPTConfig
@@ -137,3 +138,50 @@ class VectorIndexStorageConnector(StorageConnector):
def size(self):
return len(self.get_nodes())
class InMemoryStorageConnector(StorageConnector):
    """Really dumb class so we can have a unified storage connector interface - keeps everything in memory.

    Intended as the local backend for recall memory; contents are lost when the
    process exits (save() is a no-op for now).
    """

    # TODO: maybe replace this with sqlite?

    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None, table_type: Optional[str] = None):
        # table_type is accepted because get_recall_storage_connector() passes
        # table_type=TableType.RECALL_MEMORY; a single flat list backs every
        # table type for now, so it is stored but otherwise unused
        self.table_type = table_type
        config = MemGPTConfig.load()
        # TODO: figure out save location
        self.rows = []

    @staticmethod
    def _matches(record, filters: Optional[Dict]) -> bool:
        # a record matches when every filter key is an attribute with an equal value
        if not filters:
            return True
        return all(getattr(record, key, None) == value for key, value in filters.items())

    def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
        matches = [row for row in self.rows if self._matches(row, filters)]
        for offset in range(0, len(matches), page_size):
            yield matches[offset : offset + page_size]

    def get_all(self, limit: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
        matches = [row for row in self.rows if self._matches(row, filters)]
        return matches[:limit]

    def get(self, id: str) -> Optional[Record]:
        # linear scan by record id; returns None when not found
        for row in self.rows:
            if getattr(row, "id", None) == id:
                return row
        return None

    def insert(self, record: Record):
        self.rows.append(record)

    def insert_many(self, records: List[Record]):
        self.rows += records

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
        # TODO: embedding-based search over the in-memory rows
        raise NotImplementedError("InMemoryStorageConnector does not support vector queries yet")

    def save(self):
        """Save state of storage connector"""
        # nothing to persist: storage is purely in-memory
        return

    def size(self, filters: Optional[Dict] = {}) -> int:
        return len([row for row in self.rows if self._matches(row, filters)])

View File

@@ -8,74 +8,137 @@ import pickle
import os
from typing import List, Optional
from abc import abstractmethod
import numpy as np
from typing import List, Optional, Dict
from tqdm import tqdm
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.data_types import Record
from memgpt.data_types import Record, Passage, Document, Message
# ENUM representing table types in MemGPT
# each table corresponds to a different table schema (specified in data_types.py)
class TableType:
    """ENUM-style constants naming the table schemas MemGPT uses (schemas in data_types.py)."""

    ARCHIVAL_MEMORY = "archival_memory"  # archival memory table (Passage records)
    RECALL_MEMORY = "recall_memory"  # recall memory table (Message records)
    PASSAGES = "passages"  # TODO
    DOCUMENTS = "documents"  # TODO
    USERS = "users"  # TODO
    AGENTS = "agents"  # TODO
# Defining schema objects:
# Note: user/agent can borrow from MemGPTConfig/AgentConfig classes
# table names used by MemGPT (resolved in StorageConnector.generate_table_name)
RECALL_TABLE_NAME = "memgpt_recall_memory"  # message history (TableType.RECALL_MEMORY)
ARCHIVAL_TABLE_NAME = "memgpt_archival_memory"  # archival passages (TableType.ARCHIVAL_MEMORY)
PASSAGE_TABLE_NAME = "memgpt_passages"  # data-source passages (TableType.PASSAGES)
DOCUMENT_TABLE_NAME = "memgpt_documents"  # loaded documents (TableType.DOCUMENTS)
class StorageConnector:
def __init__(self, table_type: TableType, agent_config: Optional[AgentConfig] = None):
    """Base initializer: resolve the record type, table name, and base query filters.

    :param table_type: one of the TableType constants (only ARCHIVAL_MEMORY and
        RECALL_MEMORY are mapped to a record type so far)
    :param agent_config: agent whose table to target; None for non-agent tables
    :raises ValueError: if table_type has no record-type mapping yet
    """
    config = MemGPTConfig.load()
    self.agent_config = agent_config
    # all rows are scoped to the configured anonymous client id
    self.user_id = config.anon_clientid
    self.table_type = table_type
    # get object type
    if table_type == TableType.ARCHIVAL_MEMORY:
        self.type = Passage
    elif table_type == TableType.RECALL_MEMORY:
        self.type = Message
    else:
        raise ValueError(f"Table type {table_type} not implemented")
    # determine name of database table
    self.table_name = self.generate_table_name(agent_config, table_type=table_type)
    printd(f"Using table name {self.table_name}")
    # setup base filters (merged into every query by get_filters())
    if self.table_type == TableType.ARCHIVAL_MEMORY or self.table_type == TableType.RECALL_MEMORY:
        # agent-specific table
        # NOTE(review): assumes agent_config is not None for these table types — confirm callers
        self.filters = {"user_id": self.user_id, "agent_id": self.agent_config.name}
    else:
        self.filters = {"user_id": self.user_id}
def get_filters(self, filters: Optional[Dict] = {}):
    """Combine the connector's base filters with per-call filters.

    Returns a list of SQLAlchemy binary conditions suitable for unpacking into
    ``query.filter(*conditions)`` (per-call keys override base keys).
    """
    # fixed: the original returned `self.filters + [...]` (dict + list -> TypeError)
    # and subscripted the model class (`self.db_model[key]`); use getattr instead.
    # Also always return a conditions list — callers unpack it with filter(*...).
    filter_conditions = {**self.filters, **(filters or {})}
    return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
def generate_table_name(self, agent_config: AgentConfig, table_type: TableType):
    """Resolve the database table name for (agent_config, table_type).

    Agents created before v0.2.6 used per-agent archival tables; newer agents
    share the unified archival/recall tables. Non-agent tables map directly to
    the module-level *_TABLE_NAME constants.
    """

    def _version_tuple(version) -> tuple:
        # fixed: plain string comparison ordered "0.2.10" before "0.2.6";
        # compare numerically instead. Non-numeric pieces count as 0.
        parts = []
        for piece in str(version).split("."):
            digits = "".join(ch for ch in piece if ch.isdigit())
            parts.append(int(digits) if digits else 0)
        return tuple(parts)

    if agent_config is not None:
        # Table names for agent-specific tables
        if _version_tuple(agent_config.memgpt_version) < _version_tuple("0.2.6"):
            # if agent is prior version, use old (per-agent) table name
            if table_type == TableType.ARCHIVAL_MEMORY:
                return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
            else:
                raise ValueError(f"Table type {table_type} not implemented")
        else:
            if table_type == TableType.ARCHIVAL_MEMORY:
                return ARCHIVAL_TABLE_NAME
            elif table_type == TableType.RECALL_MEMORY:
                return RECALL_TABLE_NAME
            else:
                raise ValueError(f"Table type {table_type} not implemented")
    else:
        # table names for non-agent specific tables
        if table_type == TableType.PASSAGES:
            return PASSAGE_TABLE_NAME
        elif table_type == TableType.DOCUMENTS:
            return DOCUMENT_TABLE_NAME
        else:
            raise ValueError(f"Table type {table_type} not implemented")
@staticmethod
def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None):
    """Factory: build the archival-memory connector selected in MemGPTConfig."""
    storage_type = MemGPTConfig.load().archival_storage_type
    if storage_type == "local":
        from memgpt.connectors.local import VectorIndexStorageConnector

        return VectorIndexStorageConnector(agent_config=agent_config)
    elif storage_type == "postgres":
        from memgpt.connectors.db import PostgresStorageConnector

        # PostgresStorageConnector requires an explicit table_type
        return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
    elif storage_type == "chroma":
        from memgpt.connectors.chroma import ChromaStorageConnector

        # fixed: was `name=name`, but `name` is no longer a parameter (NameError)
        return ChromaStorageConnector(agent_config=agent_config)
    elif storage_type == "lancedb":
        from memgpt.connectors.db import LanceDBConnector

        return LanceDBConnector(agent_config=agent_config)
    else:
        raise NotImplementedError(f"Storage type {storage_type} not implemented")
@staticmethod
def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None):
    """Factory: build the recall-memory (message history) connector selected in MemGPTConfig."""
    storage_type = MemGPTConfig.load().recall_storage_type
    if storage_type == "local":
        from memgpt.connectors.local import InMemoryStorageConnector

        # maintains in-memory list for storage
        return InMemoryStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY)
    elif storage_type == "postgres":
        from memgpt.connectors.db import PostgresStorageConnector

        return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY)
    else:
        raise NotImplementedError(f"Storage type {storage_type} not implemented")
@staticmethod
def list_loaded_data():
# TODO: modify this to simply list loaded data from a given user
storage_type = MemGPTConfig.load().archival_storage_type
if storage_type == "local":
from memgpt.connectors.local import VectorIndexStorageConnector
@@ -97,11 +160,11 @@ class StorageConnector:
raise NotImplementedError(f"Storage type {storage_type} not implemented")
@abstractmethod
def get_all_paginated(self, page_size: int) -> Iterator[List[Record]]:
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
pass
@abstractmethod
def get_all(self, limit: int) -> List[Record]:
def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]:
pass
@abstractmethod
@@ -109,15 +172,15 @@ class StorageConnector:
pass
@abstractmethod
def insert(self, passage: Record):
def insert(self, record: Record):
pass
@abstractmethod
def insert_many(self, passages: List[Record]):
def insert_many(self, records: List[Record]):
pass
@abstractmethod
def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Record]:
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
pass
@abstractmethod
@@ -126,6 +189,6 @@ class StorageConnector:
pass
@abstractmethod
def size(self):
def size(self, filters: Optional[Dict] = {}) -> int:
"""Get number of passages (text/embedding pairs) in storage"""
pass

View File

@@ -35,7 +35,7 @@ class Message(Record):
user_id: str,
agent_id: str,
role: str,
text: str,
content: str,
model: str, # model used to make function call
function_name: Optional[str] = None, # name of function called
function_args: Optional[str] = None, # args of function called
@@ -43,7 +43,7 @@ class Message(Record):
embedding: Optional[np.ndarray] = None,
id: Optional[str] = None,
):
super().__init__(user_id, agent_id, text, id)
super().__init__(user_id, agent_id, content, id)
self.role = role # role (agent/user/function)
self.model = model # model name (e.g. gpt-4)
@@ -62,10 +62,11 @@ class Message(Record):
class Document(Record):
"""A document represent a document loaded into MemGPT, which is broken down into passages."""
def __init__(self, user_id: str, text: str, document_id: Optional[str] = None):
def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None):
super().__init__(user_id)
self.text = text
self.document_id = document_id
self.data_source = data_source
# TODO: add optional embedding?
def __repr__(self) -> str:
@@ -78,9 +79,18 @@ class Passage(Record):
It is a string of text with an associated embedding.
"""
def __init__(
    self,
    user_id: str,
    text: str,
    embedding: np.ndarray,
    data_source: Optional[str] = None,
    doc_id: Optional[str] = None,
    passage_id: Optional[str] = None,
):
    """A single chunk of embedded text.

    data_source defaults to None (and sits after embedding) so pre-existing
    call sites that construct Passages without a source — e.g. the DB query
    path — keep working, positionally and by keyword.
    """
    super().__init__(user_id)
    self.text = text
    self.data_source = data_source  # agent name or data-source name this passage came from
    self.embedding = embedding
    self.doc_id = doc_id
    self.passage_id = passage_id

View File

@@ -4,11 +4,10 @@ import re
from typing import Optional, List, Tuple
from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC
from memgpt.utils import get_local_time, printd, count_tokens
from memgpt.utils import get_local_time, printd, count_tokens, validate_date_format, extract_date_from_timestamp
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from memgpt.openai_tools import create
from memgpt.config import MemGPTConfig
from memgpt.embeddings import embedding_model
from memgpt.data_types import Message, Passage
from llama_index import Document
from llama_index.node_parser import SimpleNodeParser
from llama_index.node_parser import SimpleNodeParser
@@ -137,7 +136,7 @@ def summarize_messages(
class ArchivalMemory(ABC):
@abstractmethod
def insert(self, memory_string):
def insert(self, memory_string: str):
"""Insert new archival memory
:param memory_string: Memory string to insert
@@ -178,6 +177,10 @@ class RecallMemory(ABC):
def __repr__(self) -> str:
pass
@abstractmethod
def insert(self, message: Message):
pass
class DummyRecallMemory(RecallMemory):
"""Dummy in-memory version of a recall memory database (eg run on MongoDB)
@@ -189,29 +192,12 @@ class DummyRecallMemory(RecallMemory):
effectively allowing it to 'remember' prior engagements with a user.
"""
# TODO: replace this with StorageConnector based implementation
def __init__(self, agent_config, restrict_search_to_summaries=False):
def __init__(self, message_database=None, restrict_search_to_summaries=False):
self._message_logs = [] if message_database is None else message_database # consists of full message dicts
# If true, the pool of messages that can be queried are the automated summaries only
# (generated when the conversation window needs to be shortened)
self.restrict_search_to_summaries = restrict_search_to_summaries
from memgpt.connectors.storage import StorageConnector
self.top_k = top_k
self.agent_config = agent_config
config = MemGPTConfig.load()
# create embedding model
self.embed_model = embedding_model()
self.embedding_chunk_size = config.embedding_chunk_size
# create storage backend
self.storage = StorageConnector.get_archival_storage_connector(
agent_config=agent_config, table_type="recall_memory" # TODO: change to enum
)
# TODO: have some mechanism for cleanup otherwise will lead to OOM
self.cache = {}
def __len__(self):
return len(self._message_logs)
@@ -267,25 +253,11 @@ class DummyRecallMemory(RecallMemory):
else:
return matches, len(matches)
def _validate_date_format(self, date_str):
    """Validate the given date string in the format 'YYYY-MM-DD'."""
    try:
        datetime.datetime.strptime(date_str, "%Y-%m-%d")
    except (ValueError, TypeError):
        return False
    return True
def _extract_date_from_timestamp(self, timestamp):
    """Extracts and returns the date from the given timestamp."""
    # Extracts the date (ignoring the time and timezone)
    found = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp)
    return None if found is None else found.group(1)
def date_search(self, start_date, end_date, count=None, start=None):
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
# First, validate the start_date and end_date format
if not self._validate_date_format(start_date) or not self._validate_date_format(end_date):
if not validate_date_format(start_date) or not validate_date_format(end_date):
raise ValueError("Invalid date format. Expected format: YYYY-MM-DD")
# Convert dates to datetime objects for comparison
@@ -296,7 +268,7 @@ class DummyRecallMemory(RecallMemory):
matches = [
d
for d in message_pool
if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
if start_date_dt <= datetime.datetime.strptime(extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
]
# start/count support paging through results
@@ -312,6 +284,42 @@ class DummyRecallMemory(RecallMemory):
return matches, len(matches)
class RecallMemorySQL(RecallMemory):
    """StorageConnector-backed recall memory (message history) — work in progress.

    NOTE(review): this looks like a concrete implementation, yet it declares
    additional @abstractmethod stubs and insert() is a no-op — presumably a
    subclass or follow-up commit fills these in; confirm intent.
    """

    def __init__(self, agent_config, restrict_search_to_summaries=False):
        # If true, the pool of messages that can be queried are the automated summaries only
        # (generated when the conversation window needs to be shortened)
        self.restrict_search_to_summaries = restrict_search_to_summaries

        # NOTE(review): imported locally, presumably to avoid a circular import — confirm
        from memgpt.connectors.storage import StorageConnector

        self.agent_config = agent_config
        config = MemGPTConfig.load()

        # create embedding model
        self.embed_model = embedding_model()
        self.embedding_chunk_size = config.embedding_chunk_size

        # create storage backend
        self.storage = StorageConnector.get_recall_storage_connector(agent_config=agent_config)
        # TODO: have some mechanism for cleanup otherwise will lead to OOM
        self.cache = {}

    @abstractmethod
    def text_search(self, query_string, count=None, start=None):
        pass

    @abstractmethod
    def date_search(self, query_string, count=None, start=None):
        pass

    @abstractmethod
    def __repr__(self) -> str:
        pass

    def insert(self, message: Message):
        # TODO: persist via self.storage.insert(message) — currently a no-op
        pass
class EmbeddingArchivalMemory(ArchivalMemory):
"""Archival memory with embedding based search"""

View File

@@ -1,13 +1,5 @@
from datetime import datetime
import json
import os
import pickle
import platform
import subprocess
import sys
import io
from contextlib import contextmanager
import re
import difflib
import demjson3 as demjson
import pytz
@@ -288,3 +280,20 @@ def get_schema_diff(schema_a, schema_b):
difference = [line for line in difference if line.startswith("+ ") or line.startswith("- ")]
return "".join(difference)
# datetime related
def validate_date_format(date_str):
    """Return True iff *date_str* is a valid 'YYYY-MM-DD' date string (False for None/invalid)."""
    try:
        # fixed: this module does `from datetime import datetime`, so the class is
        # `datetime` — `datetime.datetime.strptime` raised AttributeError, which
        # escaped the (ValueError, TypeError) handler below
        datetime.strptime(date_str, "%Y-%m-%d")
        return True
    except (ValueError, TypeError):
        return False
def extract_date_from_timestamp(timestamp):
    """Return the leading 'YYYY-MM-DD' date from *timestamp*, or None if absent."""
    # Extracts the date (ignoring the time and timezone)
    found = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp)
    if found is None:
        return None
    return found.group(1)

View File

@@ -14,11 +14,54 @@ from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.connectors.chroma import ChromaStorageConnector
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
from memgpt.embeddings import embedding_model
from memgpt.data_types import Message, Passage
from memgpt.config import MemGPTConfig, AgentConfig
import argparse
def test_recall_db() -> None:
    """Exercise the configured recall storage connector end to end (insert/size/get/get_all).

    NOTE(review): unlike test_postgres_openai below, this test has no
    @pytest.mark.skipif guard for PGVECTOR_TEST_DB_URL — confirm it should run unconditionally.
    """
    # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
    storage_type = "postgres"
    storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
    config = MemGPTConfig(recall_storage_type=storage_type, recall_storage_uri=storage_uri)
    print(config.config_path)
    assert config.recall_storage_uri is not None
    config.save()
    print(config)

    conn = StorageConnector.get_recall_storage_connector()

    # construct recall memory messages
    # NOTE(review): Message.__init__ appears to require user_id and model
    # (see data_types.py) — confirm these constructor calls
    message1 = Message(
        agent_id="test_agent1",
        role="agent",
        content="This is a test message",
        id="test_id1",
    )
    message2 = Message(
        agent_id="test_agent2",
        role="user",
        content="This is a test message",
        id="test_id2",
    )

    # test insert
    conn.insert(message1)
    conn.insert_many([message2])

    # test size
    assert conn.size() == 2, f"Expected 2 messages, got {conn.size()}"
    # NOTE(review): the failure message below says "2" but the expected count is 1
    assert conn.size(filters={"agent_id": "test_agent2"}) == 1, f"Expected 2 messages, got {conn.size()}"

    # test get
    # NOTE(review): `==` on Message relies on object equality — confirm Message defines __eq__
    assert conn.get("test_id1") == message1, f"Expected {message1}, got {conn.get('test_id1')}"
    assert (
        len(conn.get_all(limit=10, filters={"agent_id": "test_agent2"})) == 1
    ), f"Expected 1 message, got {len(conn.get_all(limit=10, filters={'agent_id': 'test_agent2'}))}"
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
def test_postgres_openai():
if not os.getenv("PGVECTOR_TEST_DB_URL"):