184 lines
7.2 KiB
Python
184 lines
7.2 KiB
Python
from pgvector.psycopg import register_vector
|
|
from pgvector.sqlalchemy import Vector
|
|
import psycopg
|
|
|
|
|
|
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text
|
|
from sqlalchemy.orm import sessionmaker, mapped_column
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.sql import func
|
|
|
|
import re
|
|
from tqdm import tqdm
|
|
from typing import Optional, List, Iterator
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
from memgpt.config import MemGPTConfig
|
|
from memgpt.connectors.storage import StorageConnector, Passage
|
|
from memgpt.config import AgentConfig, MemGPTConfig
|
|
from memgpt.constants import MEMGPT_DIR
|
|
from memgpt.utils import printd
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
def get_db_model(table_name: str):
|
|
config = MemGPTConfig.load()
|
|
|
|
class PassageModel(Base):
|
|
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
|
|
|
__abstract__ = True # this line is necessary
|
|
|
|
# Assuming passage_id is the primary key
|
|
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
|
|
doc_id = Column(String)
|
|
text = Column(String, nullable=False)
|
|
embedding = mapped_column(Vector(config.embedding_dim))
|
|
# metadata_ = Column(JSON(astext_type=Text()))
|
|
|
|
def __repr__(self):
|
|
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
|
|
|
"""Create database model for table_name"""
|
|
class_name = f"{table_name.capitalize()}Model"
|
|
Model = type(class_name, (PassageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
|
|
return Model
|
|
|
|
|
|
class PostgresStorageConnector(StorageConnector):
|
|
"""Storage via Postgres"""
|
|
|
|
# TODO: this should probably eventually be moved into a parent DB class
|
|
|
|
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
|
config = MemGPTConfig.load()
|
|
|
|
# determine table name
|
|
if agent_config:
|
|
assert name is None, f"Cannot specify both agent config and name {name}"
|
|
self.table_name = self.generate_table_name_agent(agent_config)
|
|
elif name:
|
|
assert agent_config is None, f"Cannot specify both agent config and name {name}"
|
|
self.table_name = self.generate_table_name(name)
|
|
else:
|
|
raise ValueError("Must specify either agent config or name")
|
|
|
|
printd(f"Using table name {self.table_name}")
|
|
|
|
# create table
|
|
self.uri = config.archival_storage_uri
|
|
if config.archival_storage_uri is None:
|
|
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
|
|
self.db_model = get_db_model(self.table_name)
|
|
self.engine = create_engine(self.uri)
|
|
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
|
|
self.Session = sessionmaker(bind=self.engine)
|
|
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
|
|
|
def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
|
|
session = self.Session()
|
|
offset = 0
|
|
while True:
|
|
# Retrieve a chunk of records with the given page_size
|
|
db_passages_chunk = session.query(self.db_model).offset(offset).limit(page_size).all()
|
|
|
|
# If the chunk is empty, we've retrieved all records
|
|
if not db_passages_chunk:
|
|
break
|
|
|
|
# Yield a list of Passage objects converted from the chunk
|
|
yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages_chunk]
|
|
|
|
# Increment the offset to get the next chunk in the next iteration
|
|
offset += page_size
|
|
|
|
def get_all(self, limit=10) -> List[Passage]:
|
|
session = self.Session()
|
|
db_passages = session.query(self.db_model).limit(limit).all()
|
|
return [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages]
|
|
|
|
def get(self, id: str) -> Optional[Passage]:
|
|
session = self.Session()
|
|
db_passage = session.query(self.db_model).get(id)
|
|
if db_passage is None:
|
|
return None
|
|
return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id)
|
|
|
|
def size(self) -> int:
|
|
# return size of table
|
|
session = self.Session()
|
|
return session.query(self.db_model).count()
|
|
|
|
def insert(self, passage: Passage):
|
|
session = self.Session()
|
|
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
|
|
session.add(db_passage)
|
|
session.commit()
|
|
|
|
def insert_many(self, passages: List[Passage], show_progress=True):
|
|
session = self.Session()
|
|
iterable = tqdm(passages) if show_progress else passages
|
|
for passage in iterable:
|
|
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
|
|
session.add(db_passage)
|
|
session.commit()
|
|
|
|
def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
|
|
session = self.Session()
|
|
# Assuming PassageModel.embedding has the capability of computing l2_distance
|
|
results = session.scalars(select(self.db_model).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)).all()
|
|
|
|
# Convert the results into Passage objects
|
|
passages = [
|
|
Passage(text=result.text, embedding=np.frombuffer(result.embedding), doc_id=result.doc_id, passage_id=result.id)
|
|
for result in results
|
|
]
|
|
return passages
|
|
|
|
def delete(self):
|
|
"""Drop the passage table from the database."""
|
|
# Bind the engine to the metadata of the base class so that the
|
|
# declaratives can be accessed through a DBSession instance
|
|
Base.metadata.bind = self.engine
|
|
|
|
# Drop the table specified by the PassageModel class
|
|
self.db_model.__table__.drop(self.engine)
|
|
|
|
def save(self):
|
|
return
|
|
|
|
@staticmethod
|
|
def list_loaded_data():
|
|
config = MemGPTConfig.load()
|
|
engine = create_engine(config.archival_storage_uri)
|
|
inspector = inspect(engine)
|
|
tables = inspector.get_table_names()
|
|
tables = [table for table in tables if table.startswith("memgpt_") and not table.startswith("memgpt_agent_")]
|
|
tables = [table.replace("memgpt_", "") for table in tables]
|
|
return tables
|
|
|
|
def sanitize_table_name(self, name: str) -> str:
|
|
# Remove leading and trailing whitespace
|
|
name = name.strip()
|
|
|
|
# Replace spaces and invalid characters with underscores
|
|
name = re.sub(r"\s+|\W+", "_", name)
|
|
|
|
# Truncate to the maximum identifier length (e.g., 63 for PostgreSQL)
|
|
max_length = 63
|
|
if len(name) > max_length:
|
|
name = name[:max_length].rstrip("_")
|
|
|
|
# Convert to lowercase
|
|
name = name.lower()
|
|
|
|
return name
|
|
|
|
def generate_table_name_agent(self, agent_config: AgentConfig):
|
|
return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
|
|
|
|
def generate_table_name(self, name: str):
|
|
return f"memgpt_{self.sanitize_table_name(name)}"
|