diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 4e3a662d..eec7e0ac 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -351,6 +351,9 @@ class PostgresStorageConnector(SQLStorageConnector): # TODO: this should probably eventually be moved into a parent DB class def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): + + from pgvector.sqlalchemy import Vector + super().__init__(table_type=table_type, agent_config=agent_config) # get storage URI @@ -371,6 +374,10 @@ class PostgresStorageConnector(SQLStorageConnector): # create table self.db_model = get_db_model(self.table_name, table_type) self.engine = create_engine(self.uri) + for c in self.db_model.__table__.columns: + print(c.name, c.type) + if c.name == "embedding": + assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}" Base.metadata.create_all(self.engine) # Create the table if it doesn't exist self.Session = sessionmaker(bind=self.engine) self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension diff --git a/poetry.lock b/poetry.lock index 24eaddac..e455ef60 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5016,4 +5016,4 @@ server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.9" -content-hash = "bd3fce0f7b5d6ef093cee74cbe9772344b4c576313cad970d48411af5a327dc5" +content-hash = "8bdeced6b44f57a5a1bc544d1615205cf002505f173a3561e69355df863a4f98" diff --git a/pyproject.toml b/pyproject.toml index 76f3a6df..c9567926 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ pydantic = "^2.5.2" pyautogen = {version = "0.2.0", optional = true} html2text = "^2020.1.16" docx2txt = "^0.8" -sqlalchemy = "^2.0.23" +sqlalchemy = "^2.0.25" pexpect = {version = "^4.9.0", optional = true} [tool.poetry.extras]