diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 0c10f4f2..9c833968 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -7,10 +7,11 @@ from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, t from sqlalchemy.orm import sessionmaker, mapped_column from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.sql import func +from sqlalchemy import Column, BIGINT, String, DateTime import re from tqdm import tqdm -from typing import Optional, List, Iterator +from typing import Optional, List, Iterator, Dict import numpy as np from tqdm import tqdm import pandas as pd @@ -20,10 +21,18 @@ from memgpt.connectors.storage import StorageConnector, TableType from memgpt.config import AgentConfig, MemGPTConfig from memgpt.constants import MEMGPT_DIR from memgpt.utils import printd +from memgpt.data_types import Record, Message, Passage + +from datetime import datetime Base = declarative_base() +def parse_formatted_time(formatted_time): + # parse times returned by memgpt.utils.get_formatted_time() + return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z") + + def get_db_model(table_name: str, table_type: TableType): config = MemGPTConfig.load() @@ -37,6 +46,8 @@ def get_db_model(table_name: str, table_type: TableType): # Assuming passage_id is the primary key id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True) doc_id = Column(String) + agent_id = Column(String) + data_source = Column(String) # agent_name if agent, data_source name if from data source text = Column(String, nullable=False) embedding = mapped_column(Vector(config.embedding_dim)) metadata_ = Column(JSON(astext_type=Text())) @@ -48,9 +59,37 @@ def get_db_model(table_name: str, table_type: TableType): class_name = f"{table_name.capitalize()}Model" Model = type(class_name, (PassageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}}) return Model + elif table_type == TableType.RECALL_MEMORY: + + class MessageModel(Base): + """Defines data model for storing Message objects""" + + __abstract__ = True # this line is necessary + + # Assuming message_id is the primary key + id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True) + user_id = Column(String, nullable=False) + agent_id = Column(String, nullable=False) + role = Column(String, nullable=False) + content = Column(String, nullable=False) + model = Column(String, nullable=False) + function_name = Column(String) + function_args = Column(String) + function_response = Column(String) + embedding = mapped_column(Vector(config.embedding_dim)) + + # Add a datetime column, with default value as the current time + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + def __repr__(self): + return f" List[Passage]: + def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]: session = self.Session() - db_passages = session.query(self.db_model).limit(limit).all() - return [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages] + filters = self.get_filters(filters) + db_passages = session.query(self.db_model).filter(*filters).limit(limit).all() + return [self.type(**p.to_dict()) for p in db_passages] - def get(self, id: str) -> Optional[Passage]: + def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Passage]: session = self.Session() - db_passage = session.query(self.db_model).get(id) + filters = self.get_filters(filters) + db_passage = session.query(self.db_model).filter(*filters).get(id) if db_passage is None: return None return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id) - def size(self) -> int: + def size(self, filters: Optional[Dict] = {}) -> int: # return size of table session = self.Session() - return session.query(self.db_model).count() + filters = self.get_filters(filters) + return session.query(self.db_model).filter(*filters).count() def insert(self, passage: Passage): session = self.Session() @@ -123,38 +155,35 @@ class PostgresStorageConnector(StorageConnector): session.add(db_passage) session.commit() - def insert_many(self, passages: List[Passage], show_progress=True): + def insert_many(self, records: List[Record], show_progress=True): session = self.Session() - iterable = tqdm(passages) if show_progress else passages + iterable = tqdm(records) if show_progress else records for passage in iterable: db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding) session.add(db_passage) session.commit() - def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]: + def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]: session = self.Session() # Assuming PassageModel.embedding has the capability of computing l2_distance - results = session.scalars(select(self.db_model).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)).all() + filters = self.get_filters(filters) + results = session.scalars( + select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k) + ).all() # Convert the results into Passage objects - passages = [ - Passage(text=result.text, embedding=np.frombuffer(result.embedding), doc_id=result.doc_id, passage_id=result.id) - for result in results - ] - return passages - - def delete(self): - """Drop the passage table from the database.""" - # Bind the engine to the metadata of the base class so that the - # declaratives can be accessed through a DBSession instance - Base.metadata.bind = self.engine - - # Drop the table specified by the PassageModel class - self.db_model.__table__.drop(self.engine) + records = [self.type(**vars(result)) for result in results] + return records def save(self): return + def list_data_sources(self): + assert self.table_type == TableType.ARCHIVAL_MEMORY, f"list_data_sources only implemented for ARCHIVAL_MEMORY" + session = self.Session() + unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all() + return unique_data_sources + @staticmethod def list_loaded_data(): config = MemGPTConfig.load() @@ -166,29 +195,6 @@ class PostgresStorageConnector(StorageConnector): tables = [table[start_chars:] for table in tables] return tables - def sanitize_table_name(self, name: str) -> str: - # Remove leading and trailing whitespace - name = name.strip() - - # Replace spaces and invalid characters with underscores - name = re.sub(r"\s+|\W+", "_", name) - - # Truncate to the maximum identifier length (e.g., 63 for PostgreSQL) - max_length = 63 - if len(name) > max_length: - name = name[:max_length].rstrip("_") - - # Convert to lowercase - name = name.lower() - - return name - - def generate_table_name_agent(self, agent_config: AgentConfig): - return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}" - - def generate_table_name(self, name: str): - return f"memgpt_{self.sanitize_table_name(name)}" - class LanceDBConnector(StorageConnector): """Storage via LanceDB""" @@ -309,26 +315,3 @@ class LanceDBConnector(StorageConnector): start_chars = len("memgpt_") tables = [table[start_chars:] for table in tables] return tables - - def sanitize_table_name(self, name: str) -> str: - # Remove leading and trailing whitespace - name = name.strip() - - # Replace spaces and invalid characters with underscores - name = re.sub(r"\s+|\W+", "_", name) - - # Truncate to the maximum identifier length - max_length = 63 - if len(name) > max_length: - name = name[:max_length].rstrip("_") - - # Convert to lowercase - name = name.lower() - - return name - - def generate_table_name_agent(self, agent_config: AgentConfig): - return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}" - - def generate_table_name(self, name: str): - return f"memgpt_{self.sanitize_table_name(name)}" diff --git a/memgpt/connectors/local.py b/memgpt/connectors/local.py index e80efac5..1566df55 100644 --- a/memgpt/connectors/local.py +++ b/memgpt/connectors/local.py @@ -14,6 +14,7 @@ from llama_index.retrievers import VectorIndexRetriever from llama_index.schema import TextNode from memgpt.constants import MEMGPT_DIR +from memgpt.data_types import Record from memgpt.config import MemGPTConfig from memgpt.connectors.storage import StorageConnector, Passage from memgpt.config import AgentConfig, MemGPTConfig @@ -137,3 +138,50 @@ class VectorIndexStorageConnector(StorageConnector): def size(self): return len(self.get_nodes()) + + +class InMemoryStorageConnector(StorageConnector): + """Really dumb class so we can have a unified storae connector interface - keeps everything in memory""" + + # TODO: maybae replace this with sqllite? + + def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None): + from memgpt.embeddings import embedding_model + + config = MemGPTConfig.load() + # TODO: figure out save location + + self.rows = [] + + @abstractmethod + def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]: + pass + + @abstractmethod + def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]: + pass + + @abstractmethod + def get(self, id: str) -> Record: + pass + + @abstractmethod + def insert(self, record: Record): + self.rows.append(record) + + @abstractmethod + def insert_many(self, records: List[Record]): + self.rows += records + + @abstractmethod + def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]: + pass + + @abstractmethod + def save(self): + """Save state of storage connector""" + pass + + @abstractmethod + def size(self, filters: Optional[Dict] = {}) -> int: + pass diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py index 5e64f39b..ed2d37eb 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/connectors/storage.py @@ -8,74 +8,137 @@ import pickle import os -from typing import List, Optional -from abc import abstractmethod -import numpy as np +from typing import List, Optional, Dict from tqdm import tqdm from memgpt.config import AgentConfig, MemGPTConfig -from memgpt.data_types import Record +from memgpt.data_types import Record, Passage, Document, Message # ENUM representing table types in MemGPT +# each table corresponds to a different table schema (specified in data_types.py) class TableType: ARCHIVAL_MEMORY = "archival_memory" # recall memory table: memgpt_agent_{agent_id} RECALL_MEMORY = "recall_memory" # archival memory table: memgpt_agent_recall_{agent_id} - DOCUMENTS = "documents" - USERS = "users" - AGENTS = "agents" + PASSAGES = "passages" # TODO + DOCUMENTS = "documents" # TODO + USERS = "users" # TODO + AGENTS = "agents" # TODO -# Defining schema objects: -# Note: user/agent can borrow from MemGPTConfig/AgentConfig classes +# table names used by MemGPT +RECALL_TABLE_NAME = "memgpt_recall_memory" +ARCHIVAL_TABLE_NAME = "memgpt_archival_memory" +PASSAGE_TABLE_NAME = "memgpt_passages" +DOCUMENT_TABLE_NAME = "memgpt_documents" class StorageConnector: + def __init__(self, table_type: TableType, agent_config: Optional[AgentConfig] = None): + + config = MemGPTConfig.load() + self.agent_config = agent_config + self.user_id = config.anon_clientid + self.table_type = table_type + + # get object type + if table_type == TableType.ARCHIVAL_MEMORY: + self.type = Passage + elif table_type == TableType.RECALL_MEMORY: + self.type = Message + else: + raise ValueError(f"Table type {table_type} not implemented") + + # determine name of database table + self.table_name = self.generate_table_name(agent_config, table_type=table_type) + printd(f"Using table name {self.table_name}") + + # setup base filters + if self.table_type == TableType.ARCHIVAL_MEMORY or self.table_type == TableType.RECALL_MEMORY: + # agent-specific table + self.filters = {"user_id": self.user_id, "agent_id": self.agent_config.name} + else: + self.filters = {"user_id": self.user_id} + + def get_filters(self, filters: Optional[Dict] = {}): + # get all filters for query + if filters is not None: + filter_conditions = {**self.filters, **filters} + return self.filters + [self.db_model[key] == value for key, value in filter_conditions.items()] + else: + return self.filters + + def generate_table_name(self, agent_config: AgentConfig, table_type: TableType): + + if agent_config is not None: + # Table names for agent-specific tables + if agent_config.memgpt_version < "0.2.6": + # if agent is prior version, use old table name + if table_type == TableType.ARCHIVAL_MEMORY: + return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}" + else: + raise ValueError(f"Table type {table_type} not implemented") + else: + if table_type == TableType.ARCHIVAL_MEMORY: + return ARCHIVAL_TABLE_NAME + elif table_type == TableType.RECALL_MEMORY: + return RECALL_TABLE_NAME + else: + raise ValueError(f"Table type {table_type} not implemented") + else: + # table names for non-agent specific tables + if table_type == TableType.PASSAGES: + return PASSAGE_TABLE_NAME + elif table_type == TableType.DOCUMENTS: + return DOCUMENT_TABLE_NAME + else: + raise ValueError(f"Table type {table_type} not implemented") + @staticmethod - def get_archival_storage_connector(name: Optional[str] = None, agent_config: Optional[AgentConfig] = None): + def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None): storage_type = MemGPTConfig.load().archival_storage_type if storage_type == "local": from memgpt.connectors.local import VectorIndexStorageConnector - return VectorIndexStorageConnector(name=name, agent_config=agent_config) + return VectorIndexStorageConnector(agent_config=agent_config) elif storage_type == "postgres": from memgpt.connectors.db import PostgresStorageConnector - return PostgresStorageConnector(name=name, agent_config=agent_config) - elif storage_type == "chroma": - from memgpt.connectors.chroma import ChromaStorageConnector + return PostgresStorageConnector(agent_config=agent_config) return ChromaStorageConnector(name=name, agent_config=agent_config) elif storage_type == "lancedb": from memgpt.connectors.db import LanceDBConnector - return LanceDBConnector(name=name, agent_config=agent_config) + return LanceDBConnector(agent_config=agent_config) + else: raise NotImplementedError(f"Storage type {storage_type} not implemented") @staticmethod - def get_recall_storage_connector(name: Optional[str] = None, agent_config: Optional[AgentConfig] = None): + def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None): storage_type = MemGPTConfig.load().recall_storage_type if storage_type == "local": - from memgpt.connectors.local import VectorIndexStorageConnector + from memgpt.connectors.local import InMemoryStorageConnector # maintains in-memory list for storage - return InMemoryStorageConnector(name=name, agent_config=agent_config) + return InMemoryStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY) elif storage_type == "postgres": from memgpt.connectors.db import PostgresStorageConnector - return PostgresStorageConnector(name=name, agent_config=agent_config) + return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.RECALL_MEMORY) else: raise NotImplementedError(f"Storage type {storage_type} not implemented") @staticmethod def list_loaded_data(): + # TODO: modify this to simply list loaded data from a given user storage_type = MemGPTConfig.load().archival_storage_type if storage_type == "local": from memgpt.connectors.local import VectorIndexStorageConnector @@ -97,11 +160,11 @@ class StorageConnector: raise NotImplementedError(f"Storage type {storage_type} not implemented") @abstractmethod - def get_all_paginated(self, page_size: int) -> Iterator[List[Record]]: + def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]: pass @abstractmethod - def get_all(self, limit: int) -> List[Record]: + def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]: pass @abstractmethod @@ -109,15 +172,15 @@ class StorageConnector: pass @abstractmethod - def insert(self, passage: Record): + def insert(self, record: Record): pass @abstractmethod - def insert_many(self, passages: List[Record]): + def insert_many(self, records: List[Record]): pass @abstractmethod - def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Record]: + def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]: pass @abstractmethod @@ -126,6 +189,6 @@ class StorageConnector: pass @abstractmethod - def size(self): + def size(self, filters: Optional[Dict] = {}) -> int: """Get number of passages (text/embedding pairs) in storage""" pass diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 0b6dcf29..e408f723 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -35,7 +35,7 @@ class Message(Record): user_id: str, agent_id: str, role: str, - text: str, + content: str, model: str, # model used to make function call function_name: Optional[str] = None, # name of function called function_args: Optional[str] = None, # args of function called @@ -43,7 +43,7 @@ class Message(Record): embedding: Optional[np.ndarray] = None, id: Optional[str] = None, ): - super().__init__(user_id, agent_id, text, id) + super().__init__(user_id, agent_id, content, id) self.role = role # role (agent/user/function) self.model = model # model name (e.g. gpt-4) @@ -62,10 +62,11 @@ class Message(Record): class Document(Record): """A document represent a document loaded into MemGPT, which is broken down into passages.""" - def __init__(self, user_id: str, text: str, document_id: Optional[str] = None): + def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None): super().__init__(user_id) self.text = text self.document_id = document_id + self.data_source = data_source # TODO: add optional embedding? def __repr__(self) -> str: @@ -78,9 +79,18 @@ class Passage(Record): It is a string of text with an associated embedding. """ - def __init__(self, user_id: str, text: str, embedding: np.ndarray, doc_id: Optional[str] = None, passage_id: Optional[str] = None): + def __init__( + self, + user_id: str, + text: str, + data_source: str, + embedding: np.ndarray, + doc_id: Optional[str] = None, + passage_id: Optional[str] = None, + ): super().__init__(user_id) self.text = text + self.data_source = data_source self.embedding = embedding self.doc_id = doc_id self.passage_id = passage_id diff --git a/memgpt/memory.py b/memgpt/memory.py index 09673849..0ac27aab 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -4,11 +4,10 @@ import re from typing import Optional, List, Tuple from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC -from memgpt.utils import get_local_time, printd, count_tokens +from memgpt.utils import get_local_time, printd, count_tokens, validate_date_format, extract_date_from_timestamp from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM from memgpt.openai_tools import create -from memgpt.config import MemGPTConfig -from memgpt.embeddings import embedding_model +from memgpt.data_types import Message, Passage from llama_index import Document from llama_index.node_parser import SimpleNodeParser from llama_index.node_parser import SimpleNodeParser @@ -137,7 +136,7 @@ def summarize_messages( class ArchivalMemory(ABC): @abstractmethod - def insert(self, memory_string): + def insert(self, memory_string: str): """Insert new archival memory :param memory_string: Memory string to insert @@ -178,6 +177,10 @@ class RecallMemory(ABC): def __repr__(self) -> str: pass + @abstractmethod + def insert(self, message: Message): + pass + class DummyRecallMemory(RecallMemory): """Dummy in-memory version of a recall memory database (eg run on MongoDB) @@ -189,29 +192,12 @@ class DummyRecallMemory(RecallMemory): effectively allowing it to 'remember' prior engagements with a user. """ - # TODO: replace this with StorageConnector based implementation - - def __init__(self, agent_config, restrict_search_to_summaries=False): + def __init__(self, message_database=None, restrict_search_to_summaries=False): + self._message_logs = [] if message_database is None else message_database # consists of full message dicts # If true, the pool of messages that can be queried are the automated summaries only # (generated when the conversation window needs to be shortened) self.restrict_search_to_summaries = restrict_search_to_summaries - from memgpt.connectors.storage import StorageConnector - - self.top_k = top_k - self.agent_config = agent_config - config = MemGPTConfig.load() - - # create embedding model - self.embed_model = embedding_model() - self.embedding_chunk_size = config.embedding_chunk_size - - # create storage backend - self.storage = StorageConnector.get_archival_storage_connector( - agent_config=agent_config, table_type="recall_memory" # TODO: change to enum - ) - # TODO: have some mechanism for cleanup otherwise will lead to OOM - self.cache = {} def __len__(self): return len(self._message_logs) @@ -267,25 +253,11 @@ class DummyRecallMemory(RecallMemory): else: return matches, len(matches) - def _validate_date_format(self, date_str): - """Validate the given date string in the format 'YYYY-MM-DD'.""" - try: - datetime.datetime.strptime(date_str, "%Y-%m-%d") - return True - except (ValueError, TypeError): - return False - - def _extract_date_from_timestamp(self, timestamp): - """Extracts and returns the date from the given timestamp.""" - # Extracts the date (ignoring the time and timezone) - match = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp) - return match.group(1) if match else None - def date_search(self, start_date, end_date, count=None, start=None): message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]] # First, validate the start_date and end_date format - if not self._validate_date_format(start_date) or not self._validate_date_format(end_date): + if not validate_date_format(start_date) or not validate_date_format(end_date): raise ValueError("Invalid date format. Expected format: YYYY-MM-DD") # Convert dates to datetime objects for comparison @@ -296,7 +268,7 @@ class DummyRecallMemory(RecallMemory): matches = [ d for d in message_pool - if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt + if start_date_dt <= datetime.datetime.strptime(extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt ] # start/count support paging through results @@ -312,6 +284,42 @@ class DummyRecallMemory(RecallMemory): return matches, len(matches) +class RecallMemorySQL(RecallMemory): + def __init__(self, agent_config, restrict_search_to_summaries=False): + + # If true, the pool of messages that can be queried are the automated summaries only + # (generated when the conversation window needs to be shortened) + self.restrict_search_to_summaries = restrict_search_to_summaries + from memgpt.connectors.storage import StorageConnector + + self.agent_config = agent_config + config = MemGPTConfig.load() + + # create embedding model + self.embed_model = embedding_model() + self.embedding_chunk_size = config.embedding_chunk_size + + # create storage backend + self.storage = StorageConnector.get_recall_storage_connector(agent_config=agent_config) + # TODO: have some mechanism for cleanup otherwise will lead to OOM + self.cache = {} + + @abstractmethod + def text_search(self, query_string, count=None, start=None): + pass + + @abstractmethod + def date_search(self, query_string, count=None, start=None): + pass + + @abstractmethod + def __repr__(self) -> str: + pass + + def insert(self, message: Message): + pass + + class EmbeddingArchivalMemory(ArchivalMemory): """Archival memory with embedding based search""" diff --git a/memgpt/utils.py b/memgpt/utils.py index 8e864a42..b0e858a4 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -1,13 +1,5 @@ from datetime import datetime -import json -import os -import pickle -import platform -import subprocess -import sys -import io -from contextlib import contextmanager - +import re import difflib import demjson3 as demjson import pytz @@ -288,3 +280,20 @@ def get_schema_diff(schema_a, schema_b): difference = [line for line in difference if line.startswith("+ ") or line.startswith("- ")] return "".join(difference) + + +# datetime related +def validate_date_format(date_str): + """Validate the given date string in the format 'YYYY-MM-DD'.""" + try: + datetime.datetime.strptime(date_str, "%Y-%m-%d") + return True + except (ValueError, TypeError): + return False + + +def extract_date_from_timestamp(timestamp): + """Extracts and returns the date from the given timestamp.""" + # Extracts the date (ignoring the time and timezone) + match = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp) + return match.group(1) if match else None diff --git a/tests/test_storage.py b/tests/test_storage.py index fc941fa1..6688a550 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -14,11 +14,54 @@ from memgpt.connectors.storage import StorageConnector, Passage from memgpt.connectors.chroma import ChromaStorageConnector from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector from memgpt.embeddings import embedding_model +from memgpt.data_types import Message, Passage from memgpt.config import MemGPTConfig, AgentConfig import argparse +def test_recall_db() -> None: + # os.environ["MEMGPT_CONFIG_PATH"] = "./config" + + storage_type = "postgres" + storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + config = MemGPTConfig(recall_storage_type=storage_type, recall_storage_uri=storage_uri) + print(config.config_path) + assert config.recall_storage_uri is not None + config.save() + print(config) + + conn = StorageConnector.get_recall_storage_connector() + + # construct recall memory messages + message1 = Message( + agent_id="test_agent1", + role="agent", + content="This is a test message", + id="test_id1", + ) + message2 = Message( + agent_id="test_agent2", + role="user", + content="This is a test message", + id="test_id2", + ) + + # test insert + conn.insert(message1) + conn.insert_many([message2]) + + # test size + assert conn.size() == 2, f"Expected 2 messages, got {conn.size()}" + assert conn.size(filters={"agent_id": "test_agent2"}) == 1, f"Expected 2 messages, got {conn.size()}" + + # test get + assert conn.get("test_id1") == message1, f"Expected {message1}, got {conn.get('test_id1')}" + assert ( + len(conn.get_all(limit=10, filters={"agent_id": "test_agent2"})) == 1 + ), f"Expected 1 message, got {len(conn.get_all(limit=10, filters={'agent_id': 'test_agent2'}))}" + + @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key") def test_postgres_openai(): if not os.getenv("PGVECTOR_TEST_DB_URL"):