# Custom UUID type
class CommonUUID(TypeDecorator):
    """UUID column type that works on both PostgreSQL and string-only backends.

    On PostgreSQL the column is the native ``UUID(as_uuid=True)`` type and
    values pass through untouched. On every other dialect (e.g. SQLite) the
    column is a ``CHAR(36)`` holding ``str(uuid)``, and values are converted
    back to ``uuid.UUID`` when rows are loaded.
    """

    impl = CHAR
    # The generated SQL does not depend on per-instance state, so SQLAlchemy
    # may safely include this type in its compiled-statement cache.
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the concrete column type to use for ``dialect``."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID(as_uuid=True))
        # 36 characters = canonical hyphenated UUID string, as produced by str(uuid.UUID)
        return dialect.type_descriptor(CHAR(36))

    def process_bind_param(self, value, dialect):
        """Convert a Python value into what is sent to the database."""
        if value is None or dialect.name == "postgresql":
            return value
        return str(value)  # Convert UUID to string for SQLite

    def process_result_value(self, value, dialect):
        """Convert a database value back into a ``uuid.UUID`` (or ``None``)."""
        if value is None or dialect.name == "postgresql":
            return value
        if isinstance(value, uuid.UUID):
            # Already converted by the driver; uuid.UUID(<UUID>) would raise.
            return value
        return uuid.UUID(value)
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
    """Connect to the Postgres database configured for ``table_type``.

    Reads the storage URI from ``self.config`` (set by the parent
    constructor), enables the ``vector`` extension, and creates the table
    for this connector if it does not exist.

    Raises:
        ValueError: if the URI for the requested table type is not set in
            the config, or if ``table_type`` is not archival/recall memory.
    """
    super().__init__(table_type=table_type, agent_config=agent_config)

    # get storage URI
    # NOTE: the messages must read the path from self.config; there is no
    # local `config` in this scope, so `config.config_path` raised NameError
    # instead of the intended ValueError.
    if table_type == TableType.ARCHIVAL_MEMORY:
        self.uri = self.config.archival_storage_uri
        if self.uri is None:
            raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
    elif table_type == TableType.RECALL_MEMORY:
        self.uri = self.config.recall_storage_uri
        if self.uri is None:
            raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
    else:
        raise ValueError(f"Table type {table_type} not implemented")

    self.db_model = get_db_model(self.table_name, table_type)
    self.engine = create_engine(self.uri)
    self.Session = sessionmaker(bind=self.engine)

    # Enable the vector extension before creating tables, and commit so the
    # statement is not rolled back when the session is closed.
    with self.Session() as session:
        session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        session.commit()

    # create table
    Base.metadata.create_all(self.engine)  # Create the table if it doesn't exist
class SQLLiteStorageConnector(SQLStorageConnector):
    """SQL storage backed by a SQLite file, one ``<table_name>.db`` per table.

    Only ``TableType.RECALL_MEMORY`` is accepted; any other table type raises
    ``ValueError``.
    """

    def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
        """Open (creating if needed) the SQLite database for this table.

        Raises:
            ValueError: if ``table_type`` is not recall memory, or if
                ``recall_storage_path`` is not set in the config.
        """
        super().__init__(table_type=table_type, agent_config=agent_config)

        # get storage path
        if table_type != TableType.RECALL_MEMORY:
            raise ValueError(f"Table type {table_type} not implemented")

        # TODO: eventually implement URI option
        storage_dir = self.config.recall_storage_path
        if storage_dir is None:
            # Point the user at the config file; the path value itself is None here.
            raise ValueError(f"Must specifiy recall_storage_path in config {self.config.config_path}")

        # SQLite cannot create the database file if its parent directory is missing.
        os.makedirs(storage_dir, exist_ok=True)
        self.path = os.path.join(storage_dir, f"{self.table_name}.db")

        # Create the SQLAlchemy engine
        self.db_model = get_db_model(self.table_name, table_type)
        self.engine = create_engine(f"sqlite:///{self.path}")
        Base.metadata.create_all(self.engine)  # Create the table if it doesn't exist
        self.Session = sessionmaker(bind=self.engine)

        import sqlite3

        # NOTE(review): these registrations are process-global and use the
        # bytes_le encoding, while the id column type binds UUIDs as strings;
        # confirm both are still needed.
        sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
        sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
NotImplementedError(f"Storage type {storage_type} not implemented") @@ -144,6 +149,7 @@ class StorageConnector: if storage_type == "local": from memgpt.connectors.local import VectorIndexStorageConnector + # TODO: remove return VectorIndexStorageConnector.list_loaded_data() elif storage_type == "postgres": from memgpt.connectors.db import PostgresStorageConnector diff --git a/memgpt/data_types.py b/memgpt/data_types.py index ddd0cdba..b14d0f80 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -23,6 +23,8 @@ class Record: self.id = uuid.uuid4() else: self.id = id + + assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type" # todo: generate unique uuid # todo: self.role = role (?) @@ -78,8 +80,8 @@ class Document(Record): self.data_source = data_source # TODO: add optional embedding? - def __repr__(self) -> str: - pass + # def __repr__(self) -> str: + # pass class Passage(Record): @@ -106,5 +108,5 @@ class Passage(Record): self.doc_id = doc_id self.metadata = metadata - def __repr__(self): - return str(vars(self)) + # def __repr__(self): + # pass diff --git a/tests/test_storage.py b/tests/test_storage.py index 3ce717d0..8c9d57c0 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -56,8 +56,8 @@ def generate_messages(): return messages -@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"]) -# @pytest.mark.parametrize("storage_connector", ["postgres"]) +@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqllite", "lancedb"]) +# @pytest.mark.parametrize("storage_connector", ["sqllite"]) @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY]) def test_storage(storage_connector, table_type): @@ -86,9 +86,9 @@ def test_storage(storage_connector, table_type): return config.archival_storage_type = "chroma" config.archival_storage_path = "./test_chroma" - if storage_connector == "local": + if storage_connector == "sqllite": if table_type == 
TableType.ARCHIVAL_MEMORY: - print("Skipping test, local only supported for recall memory") + print("Skipping test, sqllite only supported for recall memory") return config.recall_storage_type = "local"