Files
letta-server/memgpt/agent_store/db.py
2024-04-18 22:39:11 -07:00

603 lines
25 KiB
Python

import os
import base64
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime
from sqlalchemy import func, or_, and_
from sqlalchemy import desc, asc
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy.sql import func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy_json import mutable_json_type, MutableJson
from sqlalchemy import TypeDecorator, CHAR
import uuid
from tqdm import tqdm
from typing import Optional, List, Iterator, Dict
import numpy as np
from tqdm import tqdm
import pandas as pd
from memgpt.settings import settings
from memgpt.config import MemGPTConfig
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.utils import printd
from memgpt.data_types import Record, Message, Passage, ToolCall, RecordType
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.metadata import MetadataStore
# Custom UUID type
class CommonUUID(TypeDecorator):
    """UUID column type usable on PostgreSQL and on other backends.

    PostgreSQL gets its native ``UUID`` type (returning ``uuid.UUID`` objects);
    every other dialect stores the value in a ``CHAR`` column as ``str(uuid)``.
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Select the concrete column type for `dialect`."""
        is_postgres = dialect.name == "postgresql"
        column_type = UUID(as_uuid=True) if is_postgres else CHAR()
        return dialect.type_descriptor(column_type)

    def process_bind_param(self, value, dialect):
        """Convert a Python value into the form sent to the database."""
        # NULLs and PostgreSQL (native UUID support) are passed through untouched
        if value is None or dialect.name == "postgresql":
            return value
        return str(value)  # Convert UUID to string for SQLite

    def process_result_value(self, value, dialect):
        """Convert a database value back into a ``uuid.UUID`` (or None)."""
        if value is None or dialect.name == "postgresql":
            return value
        return uuid.UUID(value)
class CommonVector(TypeDecorator):
    """Common type for representing vectors in SQLite

    A vector is stored in a BINARY column as the base64 encoding of its raw
    float32 bytes, and is read back as a 1-D float32 numpy array.
    """

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Use a plain BINARY column on every dialect."""
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        """Serialize a vector (list, tuple or numpy array) to base64 bytes.

        None is passed through so NULL embeddings stay NULL.
        """
        if value is None:
            return value
        # Always coerce to float32: process_result_value decodes the bytes as
        # float32, so a float64 array (or any non-list sequence) must be
        # converted here or the stored bytes would be read back incorrectly.
        vector = np.asarray(value, dtype=np.float32)
        # Serialize numpy array to bytes, then encode to base64 for universal compatibility
        return base64.b64encode(vector.tobytes())

    def process_result_value(self, value, dialect):
        """Deserialize base64 bytes back into a float32 numpy array.

        Falsy values (None, empty bytes) are returned unchanged.
        """
        if not value:
            return value
        # process_bind_param base64-encodes on every dialect, so the decode
        # must be unconditional as well to keep the round trip symmetric.
        raw = base64.b64decode(value)
        return np.frombuffer(raw, dtype=np.float32)
# Custom serialization / de-serialization for JSON columns
class ToolCallColumn(TypeDecorator):
    """Custom type for storing List[ToolCall] as JSON"""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Use the generic JSON column type on every dialect."""
        json_type = JSON()
        return dialect.type_descriptor(json_type)

    def process_bind_param(self, value, dialect):
        """Turn a list of ToolCall objects into a list of attribute dicts."""
        # None / empty list are stored as-is
        if not value:
            return value
        return [vars(tool_call) for tool_call in value]

    def process_result_value(self, value, dialect):
        """Rebuild ToolCall objects from the stored list of dicts."""
        # None / empty list are returned as-is
        if not value:
            return value
        return [ToolCall(**fields) for fields in value]
# Declarative base shared by the table models built in get_db_model();
# the connectors create their tables through Base.metadata.
Base = declarative_base()
def get_db_model(
    config: MemGPTConfig,
    table_name: str,
    table_type: TableType,
    user_id: uuid.UUID,
    agent_id: Optional[uuid.UUID] = None,
    dialect="postgresql",
):
    """Return the SQLAlchemy model class mapped to `table_name`.

    The schema depends on `table_type` (passages for ARCHIVAL_MEMORY / PASSAGES,
    messages for RECALL_MEMORY) and on `dialect`, which selects how the
    embedding column is stored (CommonVector on "sqlite", pgvector otherwise).
    Concrete model classes are cached in this module's globals under the name
    "<Table_name>Model<dialect>", so repeated calls return the same class.

    `config`, `user_id` and `agent_id` are accepted but not read in this function.

    Raises:
        ValueError: if `table_type` is not one of the supported table types.
    """

    def _resolve_model(model_name, abstract_base, tablename):
        """Fetch the cached concrete model, or build and cache it."""
        registry = globals()
        try:
            return registry[model_name]
        except KeyError:
            pass
        attrs = {"__tablename__": tablename, "__table_args__": {"extend_existing": True}}
        registry[model_name] = type(model_name, (abstract_base,), attrs)
        return registry[model_name]

    if table_type in (TableType.ARCHIVAL_MEMORY, TableType.PASSAGES):

        # schema for archival memory
        class PassageModel(Base):
            """Defines data model for storing Passages (consisting of text, embedding)"""

            __abstract__ = True  # required: only the dynamically created subclass is mapped to a table

            # primary key
            id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
            user_id = Column(CommonUUID, nullable=False)
            text = Column(String)
            doc_id = Column(CommonUUID)
            agent_id = Column(CommonUUID)
            data_source = Column(String)  # agent_name if agent, data_source name if from data source

            # vector storage: pgvector unless running on sqlite
            if dialect != "sqlite":
                from pgvector.sqlalchemy import Vector

                embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
            else:
                embedding = Column(CommonVector)
            embedding_dim = Column(BIGINT)
            embedding_model = Column(String)

            metadata_ = Column(MutableJson)

            # timezone-aware creation timestamp (no column default is configured)
            created_at = Column(DateTime(timezone=True))

            def __repr__(self):
                return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

            def to_record(self):
                """Convert this row into a Passage record."""
                field_names = (
                    "text",
                    "embedding",
                    "embedding_dim",
                    "embedding_model",
                    "doc_id",
                    "user_id",
                    "id",
                    "data_source",
                    "agent_id",
                    "metadata_",
                    "created_at",
                )
                return Passage(**{name: getattr(self, name) for name in field_names})

        # concrete model name is unique per table and dialect
        passage_class_name = f"{table_name.capitalize()}Model" + dialect
        return _resolve_model(passage_class_name, PassageModel, table_name)

    if table_type == TableType.RECALL_MEMORY:

        class MessageModel(Base):
            """Defines data model for storing Message objects"""

            __abstract__ = True  # required: only the dynamically created subclass is mapped to a table

            # primary key
            id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
            user_id = Column(CommonUUID, nullable=False)
            agent_id = Column(CommonUUID, nullable=False)

            # openai info
            role = Column(String, nullable=False)
            text = Column(String)  # optional: can be null if function call
            model = Column(String)  # optional: can be null if LLM backend doesn't require specifying
            name = Column(String)  # optional: multi-agent only

            # tool call request info
            # if role == "assistant", this MAY be specified
            # if role != "assistant", this must be null
            # TODO align with OpenAI spec of multiple tool calls
            tool_calls = Column(ToolCallColumn)

            # tool call response info
            # if role == "tool", then this must be specified
            # if role != "tool", this must be null
            tool_call_id = Column(String)

            # vector storage: pgvector unless running on sqlite
            if dialect != "sqlite":
                from pgvector.sqlalchemy import Vector

                embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
            else:
                embedding = Column(CommonVector)
            embedding_dim = Column(BIGINT)
            embedding_model = Column(String)

            # timezone-aware creation timestamp (no column default is configured)
            created_at = Column(DateTime(timezone=True))

            def __repr__(self):
                return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

            def to_record(self):
                """Convert this row into a Message record."""
                field_names = (
                    "user_id",
                    "agent_id",
                    "role",
                    "name",
                    "text",
                    "model",
                    "tool_calls",
                    "tool_call_id",
                    "embedding",
                    "embedding_dim",
                    "embedding_model",
                    "created_at",
                    "id",
                )
                return Message(**{name: getattr(self, name) for name in field_names})

        # concrete model name is unique per table and dialect
        message_class_name = f"{table_name.capitalize()}Model" + dialect
        return _resolve_model(message_class_name, MessageModel, table_name)

    raise ValueError(f"Table type {table_type} not implemented")
class SQLStorageConnector(StorageConnector):
    """Storage logic shared by the SQLAlchemy-backed connectors.

    The subclasses in this file assign `self.db_model`, `self.engine` and
    `self.session_maker` in their constructors. `self.filters` (a
    {column name: value} mapping) and `self.table_type` are read here but not
    assigned in this class — presumably set by StorageConnector.__init__;
    confirm against the base class.
    """

    def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        self.config = config

    def get_filters(self, filters: Optional[Dict] = None):
        """Merge `filters` over the connector's base filters and return SQL equality expressions.

        Passing None (or an empty dict) yields only the connector's own filters.
        """
        if filters is not None:
            filter_conditions = {**self.filters, **filters}
        else:
            filter_conditions = self.filters
        return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]

    def get_all_paginated(self, filters: Optional[Dict] = None, page_size: Optional[int] = 1000, offset=0) -> Iterator[List[RecordType]]:
        """Yield matching records in pages of at most `page_size`, starting at `offset`.

        With `page_size=None` everything from `offset` onwards is yielded as a single page.
        """
        filters = self.get_filters(filters)
        while True:
            # Retrieve a chunk of records with the given page_size
            with self.session_maker() as session:
                db_record_chunk = (
                    session.query(self.db_model)
                    .filter(*filters)
                    # OFFSET/LIMIT paging is only stable with a deterministic order;
                    # without it rows can be repeated or skipped between pages
                    .order_by(asc(self.db_model.id))
                    .offset(offset)
                    .limit(page_size)
                    .all()
                )

            # If the chunk is empty, we've retrieved all records
            if not db_record_chunk:
                break

            # Yield a list of Record objects converted from the chunk
            yield [record.to_record() for record in db_record_chunk]

            # No page size means the single unbounded chunk above was everything
            if page_size is None:
                break

            # Increment the offset to get the next chunk in the next iteration
            offset += page_size

    def get_all_cursor(
        self,
        filters: Optional[Dict] = None,
        after: uuid.UUID = None,
        before: uuid.UUID = None,
        limit: Optional[int] = 1000,
        order_by: str = "created_at",
        reverse: bool = False,
    ):
        """Get all that returns a cursor (record.id) and records

        Records are sorted by `order_by` (descending when `reverse`), then by id.
        `after` / `before` are record ids used as exclusive cursors.

        Returns:
            (next_cursor, records): next_cursor is the id of the last returned
            record, or (None, []) when nothing matches.
        """
        filters = self.get_filters(filters)
        sort_column = getattr(self.db_model, order_by)

        # generate query
        with self.session_maker() as session:
            query = session.query(self.db_model).filter(*filters)

            # records are sorted by the order_by field first, and then by the ID if two fields are the same
            primary_order = desc(sort_column) if reverse else asc(sort_column)
            query = query.order_by(primary_order, asc(self.db_model.id))

            # cursor logic: filter records based on before/after ID
            # NOTE: self.get() returns None for an unknown id, in which case the
            # getattr below raises AttributeError
            if after:
                after_value = getattr(self.get(id=after), order_by)
                # if reverse, we want records that sort below after_value, otherwise above it
                sort_exp = sort_column < after_value if reverse else sort_column > after_value
                query = query.filter(
                    or_(sort_exp, and_(sort_column == after_value, self.db_model.id > after))  # tiebreaker case
                )
            if before:
                before_value = getattr(self.get(id=before), order_by)
                sort_exp = sort_column > before_value if reverse else sort_column < before_value
                query = query.filter(or_(sort_exp, and_(sort_column == before_value, self.db_model.id < before)))

            # get records
            db_record_chunk = query.limit(limit).all()

        if not db_record_chunk:
            return (None, [])
        records = [record.to_record() for record in db_record_chunk]
        next_cursor = db_record_chunk[-1].id
        assert isinstance(next_cursor, uuid.UUID)

        # return (cursor, list[records])
        return (next_cursor, records)

    def get_all(self, filters: Optional[Dict] = None, limit=None) -> List[RecordType]:
        """Return all matching records, at most `limit` of them when `limit` is truthy."""
        filters = self.get_filters(filters)
        with self.session_maker() as session:
            query = session.query(self.db_model).filter(*filters)
            if limit:
                query = query.limit(limit)
            db_records = query.all()
        return [record.to_record() for record in db_records]

    def get(self, id: uuid.UUID) -> Optional[Record]:
        """Return the record with primary key `id`, or None if it does not exist."""
        with self.session_maker() as session:
            db_record = session.get(self.db_model, id)
        if db_record is None:
            return None
        return db_record.to_record()

    def size(self, filters: Optional[Dict] = None) -> int:
        """Return the number of rows matching the filters."""
        filters = self.get_filters(filters)
        with self.session_maker() as session:
            return session.query(self.db_model).filter(*filters).count()

    def insert(self, record: Record):
        raise NotImplementedError

    def insert_many(self, records: List[RecordType], show_progress=False):
        raise NotImplementedError

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = None) -> List[RecordType]:
        raise NotImplementedError("Vector query not implemented for SQLStorageConnector")

    def save(self):
        """No-op."""
        return

    def list_data_sources(self):
        """Return the distinct `data_source` values as the raw rows produced by the query."""
        assert self.table_type == TableType.ARCHIVAL_MEMORY, f"list_data_sources only implemented for ARCHIVAL_MEMORY"
        # self.filters is a {column name: value} mapping; unpacking it directly would hand
        # bare column-name strings to filter(), so build real SQL expressions first
        filters = self.get_filters({})
        with self.session_maker() as session:
            unique_data_sources = session.query(self.db_model.data_source).filter(*filters).distinct().all()
        return unique_data_sources

    def query_date(self, start_date, end_date, offset=0, limit=None):
        """Return records with `created_at` in [start_date, end_date], after skipping `offset` rows."""
        filters = self.get_filters({})
        with self.session_maker() as session:
            stmt = (
                session.query(self.db_model)
                .filter(*filters)
                .filter(self.db_model.created_at >= start_date)
                .filter(self.db_model.created_at <= end_date)
                .offset(offset)
            )
            if limit:
                stmt = stmt.limit(limit)
            results = stmt.all()
        return [result.to_record() for result in results]

    def query_text(self, query, offset=0, limit=None):
        """Return records whose text contains `query`, compared case-insensitively."""
        # todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
        filters = self.get_filters({})
        with self.session_maker() as session:
            # a separate local name keeps the `query` argument (the search text) intact
            stmt = (
                session.query(self.db_model)
                .filter(*filters)
                .filter(func.lower(self.db_model.text).contains(func.lower(query)))
                .offset(offset)
            )
            if limit:
                stmt = stmt.limit(limit)
            results = stmt.all()
        return [result.to_record() for result in results]

    # Should be used only in tests!
    def delete_table(self):
        """Close all sessions and drop this connector's table."""
        close_all_sessions()
        with self.session_maker() as session:
            self.db_model.__table__.drop(session.bind)
            session.commit()

    def delete(self, filters: Optional[Dict] = None):
        """Delete every row matching the filters."""
        filters = self.get_filters(filters)
        with self.session_maker() as session:
            session.query(self.db_model).filter(*filters).delete()
            session.commit()
class PostgresStorageConnector(SQLStorageConnector):
    """Storage via Postgres"""

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
        """Build the table model, connect to Postgres and make sure the table exists.

        The connection URI comes from `settings.pg_uri` when set, otherwise from
        the archival/recall storage URI in `config`.

        Raises:
            ValueError: if the needed URI is missing from the config, or
                `table_type` is not supported.
        """
        from pgvector.sqlalchemy import Vector

        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        # create table
        self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id)

        # construct URI from environment variables
        if settings.pg_uri:
            self.uri = settings.pg_uri
        else:
            # use config URI
            # TODO: remove this eventually (config should NOT contain URI)
            if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
                self.uri = self.config.archival_storage_uri
                if self.config.archival_storage_uri is None:
                    raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
            elif table_type == TableType.RECALL_MEMORY:
                self.uri = self.config.recall_storage_uri
                if self.config.recall_storage_uri is None:
                    raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
            else:
                raise ValueError(f"Table type {table_type} not implemented")

        # create engine
        self.engine = create_engine(self.uri)

        for c in self.db_model.__table__.columns:
            if c.name == "embedding":
                assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"

        self.session_maker = sessionmaker(bind=self.engine)
        with self.session_maker() as session:
            session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))  # Enables the vector extension
            # Sessions do not autocommit: without an explicit commit the DDL above
            # is rolled back when the session closes and the extension is never enabled
            session.commit()

        # create table
        Base.metadata.create_all(self.engine, tables=[self.db_model.__table__])  # Create the table if it doesn't exist

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = None) -> List[RecordType]:
        """Return the `top_k` records closest to `query_vec` by L2 distance.

        `query` (the text) is not used by this implementation.
        """
        filters = self.get_filters(filters)
        with self.session_maker() as session:
            results = session.scalars(
                select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
            ).all()

        # Convert the results into record objects
        records = [result.to_record() for result in results]
        return records

    def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
        """Insert `records`; an empty list is a no-op.

        When the first record is a Passage, the whole batch is written with one
        INSERT, which becomes an upsert on `id` when `exists_ok` is true.
        Otherwise each record is added through the ORM session.
        NOTE(review): `exists_ok` and is not consulted on the non-Passage path,
        and `show_progress` is not consulted on the Passage path.
        """
        from sqlalchemy.dialects.postgresql import insert

        # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
        if len(records) == 0:
            return
        if isinstance(records[0], Passage):
            with self.engine.connect() as conn:
                db_records = [vars(record) for record in records]
                stmt = insert(self.db_model.__table__).values(db_records)
                if exists_ok:
                    # on id conflict, overwrite every column with the incoming values
                    upsert_stmt = stmt.on_conflict_do_update(
                        index_elements=["id"], set_={c.name: c for c in stmt.excluded}  # Replace with your primary key column
                    )
                    conn.execute(upsert_stmt)
                else:
                    conn.execute(stmt)
                conn.commit()
        else:
            with self.session_maker() as session:
                iterable = tqdm(records) if show_progress else records
                for record in iterable:
                    db_record = self.db_model(**vars(record))
                    session.add(db_record)
                session.commit()

    def insert(self, record: Record, exists_ok=True):
        """Insert a single record (delegates to insert_many)."""
        self.insert_many([record], exists_ok=exists_ok)

    def update(self, record: RecordType):
        """
        Updates a record in the database based on the provided Record object.

        Raises:
            ValueError: if no row with `record.id` exists.
        """
        with self.session_maker() as session:
            # Find the record by its ID
            db_record = session.query(self.db_model).filter_by(id=record.id).first()
            if not db_record:
                raise ValueError(f"Record with id {record.id} does not exist.")

            # Update the record with new values from the provided Record object
            for attr, value in vars(record).items():
                setattr(db_record, attr, value)

            # Commit the changes to the database
            session.commit()
class SQLLiteStorageConnector(SQLStorageConnector):
    """Storage via a local SQLite database file (recall memory only)."""

    def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
        """Build the table model and open (creating if needed) `<recall_storage_path>/sqlite.db`.

        Raises:
            ValueError: for any table type other than RECALL_MEMORY, or when
                `recall_storage_path` is not set in the config.
        """
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        # get storage URI
        if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
            raise ValueError(f"Table type {table_type} not implemented")
        elif table_type == TableType.RECALL_MEMORY:
            # TODO: eventually implement URI option
            self.path = self.config.recall_storage_path
            if self.path is None:
                # report the config file location (as the Postgres connector does);
                # the path itself is None here, so echoing it would be useless
                raise ValueError(f"Must specifiy recall_storage_path in config {self.config.config_path}")
        else:
            raise ValueError(f"Table type {table_type} not implemented")

        self.path = os.path.join(self.path, "sqlite.db")

        # Create the SQLAlchemy engine
        self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id, dialect="sqlite")
        self.engine = create_engine(f"sqlite:///{self.path}")
        Base.metadata.create_all(self.engine, tables=[self.db_model.__table__])  # Create the table if it doesn't exist
        self.session_maker = sessionmaker(bind=self.engine)

        # process-wide sqlite3 hooks for passing uuid.UUID values as little-endian bytes
        import sqlite3

        sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
        sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

    def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
        """Insert `records`; an empty list is a no-op.

        When the first record is a Passage, the whole batch is written with one
        INSERT, which becomes an upsert on `id` when `exists_ok` is true.
        Otherwise each record is added through the ORM session.
        NOTE(review): `exists_ok` and is not consulted on the non-Passage path,
        and `show_progress` is not consulted on the Passage path.
        """
        from sqlalchemy.dialects.sqlite import insert

        # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
        if len(records) == 0:
            return
        if isinstance(records[0], Passage):
            with self.engine.connect() as conn:
                db_records = [vars(record) for record in records]
                stmt = insert(self.db_model.__table__).values(db_records)
                if exists_ok:
                    # on id conflict, overwrite every column with the incoming values
                    upsert_stmt = stmt.on_conflict_do_update(
                        index_elements=["id"], set_={c.name: c for c in stmt.excluded}  # Replace with your primary key column
                    )
                    conn.execute(upsert_stmt)
                else:
                    conn.execute(stmt)
                conn.commit()
        else:
            with self.session_maker() as session:
                iterable = tqdm(records) if show_progress else records
                for record in iterable:
                    db_record = self.db_model(**vars(record))
                    session.add(db_record)
                session.commit()

    def insert(self, record: Record, exists_ok=True):
        """Insert a single record (delegates to insert_many)."""
        self.insert_many([record], exists_ok=exists_ok)

    def update(self, record: Record):
        """
        Updates an existing record in the database with values from the provided record object.

        Only attributes that correspond to table columns are copied.

        Raises:
            ValueError: if `record.id` is falsy or no row with that id exists.
        """
        if not record.id:
            raise ValueError("Record must have an id.")

        with self.session_maker() as session:
            # Fetch the existing record from the database
            db_record = session.query(self.db_model).filter_by(id=record.id).first()
            if not db_record:
                raise ValueError(f"Record with id {record.id} does not exist.")

            # Update the database record with values from the provided record object
            for column in self.db_model.__table__.columns:
                column_name = column.name
                if hasattr(record, column_name):
                    setattr(db_record, column_name, getattr(record, column_name))

            # Commit the changes to the database
            session.commit()