# File: matrix-bridge-legacy/sqlite_crypto_store.py
# Snapshot timestamp: 2026-03-28 23:50:54 -04:00
# Snapshot size: 486 lines, 20 KiB (Python)

# sqlite_crypto_store.py
"""SQLite-backed CryptoStore for mautrix-python"""
from __future__ import annotations

import pickle
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from pathlib import Path

import aiosqlite
from mautrix.types import (
    CrossSigner,
    CrossSigningUsage,
    DeviceID,
    DeviceIdentity,
    EventID,
    IdentityKey,
    RoomID,
    SessionID,
    SigningKey,
    SyncToken,
    TOFUSigningKey,
    UserID,
)
from mautrix.crypto.account import OlmAccount
from mautrix.crypto.sessions import InboundGroupSession, OutboundGroupSession, Session
from mautrix.crypto.store.abstract import CryptoStore
from mautrix.client.state_store import SyncStore


class SQLiteCryptoStore(CryptoStore, SyncStore):
    """SQLite-backed crypto store (and sync-token store) for mautrix-python.

    State lives in a single aiosqlite connection opened by :meth:`open` and
    released by :meth:`close`. Every row is scoped by ``account_id``.

    The Olm account and Olm/Megolm sessions are stored in olm's own pickle
    format keyed by ``pickle_key``; device identities use Python's
    :mod:`pickle`. Each public method commits its own writes, and every
    multi-statement write runs inside :meth:`transaction` so it is applied
    all-or-nothing.
    """

    def __init__(self, account_id: str, pickle_key: str, db_path: Path | str) -> None:
        """Configure the store; no I/O happens until :meth:`open` is awaited.

        :param account_id: Value written to every ``account_id`` column.
        :param pickle_key: Passphrase passed to the olm ``pickle``/``from_pickle`` calls.
        :param db_path: Path of the SQLite database file.
        """
        self.account_id = account_id
        self.pickle_key = pickle_key
        self.db_path = Path(db_path)
        # Live connection while open; None before open() and after close().
        self.db: aiosqlite.Connection | None = None

    # ------------------------------------------------------------------
    # Lifecycle and schema
    # ------------------------------------------------------------------

    async def open(self) -> None:
        """Connect to the database file and create any missing tables."""
        self.db = await aiosqlite.connect(self.db_path)
        # Lets every query read columns by name, e.g. row["session"].
        self.db.row_factory = aiosqlite.Row
        try:
            await self._create_tables()
        except BaseException:
            # Don't leak the connection if schema creation fails.
            await self.close()
            raise

    async def close(self) -> None:
        """Close the database connection; a no-op if it is already closed."""
        if self.db is not None:
            await self.db.close()
            self.db = None

    async def _create_tables(self) -> None:
        """Create every table this store uses, if it does not exist yet.

        ``CREATE TABLE IF NOT EXISTS`` never alters an existing table, so a
        schema change here only affects newly created databases. All queries
        below work against both the old and the new layouts.
        """
        await self.db.executescript("""
            CREATE TABLE IF NOT EXISTS crypto_account (
                account_id TEXT PRIMARY KEY,
                device_id TEXT,
                sync_token TEXT,
                shared INTEGER DEFAULT 0,
                account BLOB
            );
            -- creation_time is declared REAL but holds ISO-8601 text;
            -- SQLite keeps non-numeric text in a REAL column as TEXT.
            CREATE TABLE IF NOT EXISTS crypto_olm_session (
                account_id TEXT,
                sender_key TEXT,
                session_id TEXT,
                session BLOB,
                creation_time REAL,
                PRIMARY KEY (account_id, sender_key, session_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_megolm_inbound (
                account_id TEXT,
                room_id TEXT,
                session_id TEXT,
                sender_key TEXT,
                signing_key TEXT,
                session BLOB,
                PRIMARY KEY (account_id, room_id, session_id)
            );
            -- Keyed per account as well as per room so accounts sharing a
            -- file cannot overwrite each other's outbound sessions. Databases
            -- created earlier keep room_id as their only key.
            CREATE TABLE IF NOT EXISTS crypto_megolm_outbound (
                account_id TEXT,
                room_id TEXT,
                session BLOB,
                max_age INTEGER,
                max_messages INTEGER,
                creation_time REAL,
                use_time REAL,
                message_count INTEGER,
                shared INTEGER,
                PRIMARY KEY (account_id, room_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_device (
                account_id TEXT,
                user_id TEXT,
                device_id TEXT,
                device BLOB,
                PRIMARY KEY (account_id, user_id, device_id)
            );
            -- Users whose device list was stored via put_devices(), even when
            -- that list was empty (so "tracked, no devices" survives).
            CREATE TABLE IF NOT EXISTS crypto_tracked_user (
                account_id TEXT,
                user_id TEXT,
                PRIMARY KEY (account_id, user_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_message_index (
                account_id TEXT,
                sender_key TEXT,
                session_id TEXT,
                idx INTEGER,
                event_id TEXT,
                timestamp INTEGER,
                PRIMARY KEY (account_id, sender_key, session_id, idx)
            );
            CREATE TABLE IF NOT EXISTS crypto_cross_signing (
                account_id TEXT,
                user_id TEXT,
                usage TEXT,
                key TEXT,
                first_key TEXT,
                PRIMARY KEY (account_id, user_id, usage)
            );
            CREATE TABLE IF NOT EXISTS crypto_signature (
                account_id TEXT,
                signer TEXT,
                target TEXT,
                signature TEXT,
                PRIMARY KEY (account_id, signer, target)
            );
        """)
        await self.db.commit()

    # ------------------------------------------------------------------
    # Serialization helpers
    # ------------------------------------------------------------------

    def _pickle(self, obj) -> bytes:
        """Serialize a device identity with Python's pickle."""
        return pickle.dumps(obj)

    def _unpickle(self, data: bytes | None):
        """Inverse of :meth:`_pickle`; returns None for NULL/empty data.

        NOTE(security): ``pickle.loads`` can execute arbitrary code embedded in
        the blob, so the database file must be trusted (not writable by others).
        """
        return pickle.loads(data) if data else None

    @staticmethod
    def _time_to_db(value):
        """Convert a datetime to ISO-8601 text for storage; other values pass through.

        Binding text explicitly avoids sqlite3's default datetime adapter,
        which is deprecated since Python 3.12.
        """
        return value.isoformat() if isinstance(value, datetime) else value

    @staticmethod
    def _time_from_db(value):
        """Turn a stored timestamp back into a ``datetime``.

        Accepts ISO-8601 text (as written by :meth:`_time_to_db`, or by sqlite3's
        default adapter when a datetime was bound directly) and numeric POSIX
        timestamps. Any other value, e.g. ``None``, is returned unchanged.
        """
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        if isinstance(value, (int, float)):
            return datetime.fromtimestamp(value)
        return value

    @asynccontextmanager
    async def transaction(self):
        """Commit the statements run inside the block, or roll them all back.

        Rolls back on *any* exception, including ``asyncio.CancelledError``
        (a ``BaseException``). This ensures a cancelled task cannot leave half
        a write pending for the next ``commit()`` on the shared connection.
        Not re-entrant: a nested block commits when it exits.
        """
        try:
            yield
            await self.db.commit()
        except BaseException:
            await self.db.rollback()
            raise

    # ------------------------------------------------------------------
    # Account row: device ID, sync token, Olm account
    # ------------------------------------------------------------------

    async def _ensure_account_row(self) -> None:
        """Create an empty ``crypto_account`` row for this account if missing.

        Writers then UPDATE only their own columns. ``INSERT OR REPLACE`` would
        delete and re-create the row, losing every column it did not list.
        """
        await self.db.execute(
            "INSERT OR IGNORE INTO crypto_account (account_id) VALUES (?)",
            (self.account_id,),
        )

    async def get_device_id(self) -> DeviceID | None:
        """Return the stored device ID, or None if none has been saved."""
        async with self.db.execute(
            "SELECT device_id FROM crypto_account WHERE account_id = ?",
            (self.account_id,),
        ) as cur:
            row = await cur.fetchone()
        return DeviceID(row["device_id"]) if row and row["device_id"] else None

    async def put_device_id(self, device_id: DeviceID) -> None:
        """Store the device ID without touching the account pickle or sync token."""
        async with self.transaction():
            await self._ensure_account_row()
            await self.db.execute(
                "UPDATE crypto_account SET device_id = ? WHERE account_id = ?",
                (device_id, self.account_id),
            )

    async def put_next_batch(self, next_batch: SyncToken) -> None:
        """Store the sync ``next_batch`` token, creating the account row if needed."""
        async with self.transaction():
            await self._ensure_account_row()
            await self.db.execute(
                "UPDATE crypto_account SET sync_token = ? WHERE account_id = ?",
                (next_batch, self.account_id),
            )

    async def get_next_batch(self) -> SyncToken | None:
        """Return the stored sync token, or None if none has been saved."""
        async with self.db.execute(
            "SELECT sync_token FROM crypto_account WHERE account_id = ?",
            (self.account_id,),
        ) as cur:
            row = await cur.fetchone()
        return SyncToken(row["sync_token"]) if row and row["sync_token"] else None

    async def put_account(self, account: OlmAccount) -> None:
        """Store the Olm account (olm's own pickle format, not Python's).

        Only the ``shared`` and ``account`` columns are written; the device ID
        and sync token on the same row are preserved.
        """
        pickled = account.pickle(self.pickle_key)
        async with self.transaction():
            await self._ensure_account_row()
            await self.db.execute(
                "UPDATE crypto_account SET shared = ?, account = ? WHERE account_id = ?",
                (account.shared, pickled, self.account_id),
            )

    async def get_account(self) -> OlmAccount | None:
        """Load the stored Olm account, or None if none has been saved."""
        async with self.db.execute(
            "SELECT account, shared FROM crypto_account WHERE account_id = ?",
            (self.account_id,),
        ) as cur:
            row = await cur.fetchone()
        if row and row["account"]:
            return OlmAccount.from_pickle(row["account"], self.pickle_key, bool(row["shared"]))
        return None

    async def delete(self) -> None:
        """Delete this account's row, Olm/Megolm sessions and device lists atomically.

        The message-index, cross-signing and signature tables are not touched.
        """
        tables = (
            "crypto_account",
            "crypto_olm_session",
            "crypto_megolm_inbound",
            "crypto_megolm_outbound",
            "crypto_device",
            # Must go with crypto_device, otherwise users would look tracked
            # with zero devices and never be re-fetched.
            "crypto_tracked_user",
        )
        async with self.transaction():
            for table in tables:
                # Table names are the constants above, never caller input.
                await self.db.execute(
                    f"DELETE FROM {table} WHERE account_id = ?", (self.account_id,)
                )

    # ------------------------------------------------------------------
    # Olm sessions
    # ------------------------------------------------------------------

    async def has_session(self, key: IdentityKey) -> bool:
        """Return True if any Olm session exists for the sender identity key."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_olm_session WHERE account_id = ? AND sender_key = ? LIMIT 1",
            (self.account_id, key),
        ) as cur:
            return await cur.fetchone() is not None

    async def get_latest_session(self, key: IdentityKey) -> Session | None:
        """Return the most recently inserted (highest rowid) Olm session for ``key``."""
        async with self.db.execute(
            "SELECT session, creation_time FROM crypto_olm_session "
            "WHERE account_id = ? AND sender_key = ? ORDER BY rowid DESC LIMIT 1",
            (self.account_id, key),
        ) as cur:
            row = await cur.fetchone()
        if row and row["session"]:
            return Session.from_pickle(
                row["session"], self.pickle_key, self._time_from_db(row["creation_time"])
            )
        return None

    async def get_sessions(self, key: IdentityKey) -> list[Session]:
        """Return every stored Olm session for the sender identity key."""
        async with self.db.execute(
            "SELECT session, creation_time FROM crypto_olm_session "
            "WHERE account_id = ? AND sender_key = ?",
            (self.account_id, key),
        ) as cur:
            rows = await cur.fetchall()
        return [
            Session.from_pickle(
                row["session"], self.pickle_key, self._time_from_db(row["creation_time"])
            )
            for row in rows
        ]

    async def add_session(self, key: IdentityKey, session: Session) -> None:
        """Insert (or replace) an Olm session for the sender identity key."""
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_olm_session "
            "(account_id, sender_key, session_id, session, creation_time) VALUES (?, ?, ?, ?, ?)",
            (
                self.account_id,
                key,
                session.id,
                session.pickle(self.pickle_key),
                self._time_to_db(session.creation_time),
            ),
        )
        await self.db.commit()

    async def update_session(self, key: IdentityKey, session: Session) -> None:
        """Overwrite the pickle of an existing Olm session (no-op if absent)."""
        await self.db.execute(
            "UPDATE crypto_olm_session SET session = ? "
            "WHERE account_id = ? AND sender_key = ? AND session_id = ?",
            (session.pickle(self.pickle_key), self.account_id, key, session.id),
        )
        await self.db.commit()

    # ------------------------------------------------------------------
    # Megolm inbound sessions
    # ------------------------------------------------------------------

    async def put_group_session(
        self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID, session: InboundGroupSession
    ) -> None:
        """Insert (or replace) an inbound Megolm session."""
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_megolm_inbound "
            "(account_id, room_id, session_id, sender_key, signing_key, session) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (
                self.account_id,
                room_id,
                session_id,
                sender_key,
                session.signing_key,
                session.pickle(self.pickle_key),
            ),
        )
        await self.db.commit()

    async def get_group_session(self, room_id: RoomID, session_id: SessionID) -> InboundGroupSession | None:
        """Load an inbound Megolm session, or None if it is not stored."""
        async with self.db.execute(
            "SELECT session, sender_key, signing_key FROM crypto_megolm_inbound "
            "WHERE account_id = ? AND room_id = ? AND session_id = ?",
            (self.account_id, room_id, session_id),
        ) as cur:
            row = await cur.fetchone()
        if row and row["session"]:
            return InboundGroupSession.from_pickle(
                row["session"],
                self.pickle_key,
                row["signing_key"],
                row["sender_key"],
                room_id,
            )
        return None

    async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
        """Return True if the inbound Megolm session is stored."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_megolm_inbound "
            "WHERE account_id = ? AND room_id = ? AND session_id = ? LIMIT 1",
            (self.account_id, room_id, session_id),
        ) as cur:
            return await cur.fetchone() is not None

    async def redact_group_session(self, room_id: RoomID, session_id: SessionID, reason: str) -> None:
        """Delete one inbound Megolm session; ``reason`` is not stored."""
        await self.db.execute(
            "DELETE FROM crypto_megolm_inbound WHERE account_id = ? AND room_id = ? AND session_id = ?",
            (self.account_id, room_id, session_id),
        )
        await self.db.commit()

    async def redact_group_sessions(self, room_id: RoomID, sender_key: IdentityKey, reason: str) -> list[SessionID]:
        """Delete inbound sessions matching ``room_id`` OR ``sender_key``.

        Returns the IDs of the deleted sessions; ``reason`` is not stored.
        """
        async with self.transaction():
            async with self.db.execute(
                "SELECT room_id, session_id FROM crypto_megolm_inbound "
                "WHERE account_id = ? AND (room_id = ? OR sender_key = ?)",
                (self.account_id, room_id, sender_key),
            ) as cur:
                rows = await cur.fetchall()
            # Delete exactly the rows being reported, so a session stored by
            # another coroutine between these two awaits is neither removed
            # nor listed.
            await self.db.executemany(
                "DELETE FROM crypto_megolm_inbound "
                "WHERE account_id = ? AND room_id = ? AND session_id = ?",
                [(self.account_id, row["room_id"], row["session_id"]) for row in rows],
            )
        return [SessionID(row["session_id"]) for row in rows]

    async def redact_expired_group_sessions(self) -> list[SessionID]:
        """Not implemented: expiry metadata is not stored, so nothing is redacted."""
        return []

    async def redact_outdated_group_sessions(self) -> list[SessionID]:
        """Not implemented: always returns an empty list."""
        return []

    # ------------------------------------------------------------------
    # Megolm outbound sessions
    # ------------------------------------------------------------------

    async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Insert or replace the outbound Megolm session for ``session.room_id``.

        ``max_age`` is stored as integer milliseconds and the timestamps as
        ISO-8601 text.
        """
        # `is not None` rather than truthiness: timedelta(0) is falsy but real.
        max_age_ms = (
            int(session.max_age.total_seconds() * 1000) if session.max_age is not None else None
        )
        await self.db.execute(
            """INSERT OR REPLACE INTO crypto_megolm_outbound
               (account_id, room_id, session, max_age, max_messages,
                creation_time, use_time, message_count, shared)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                self.account_id,
                session.room_id,
                session.pickle(self.pickle_key),
                max_age_ms,
                session.max_messages,
                self._time_to_db(session.creation_time),
                self._time_to_db(session.use_time),
                session.message_count,
                session.shared,
            ),
        )
        await self.db.commit()

    async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
        """Load the outbound Megolm session for ``room_id``, or None if absent."""
        async with self.db.execute(
            "SELECT session, max_age, max_messages, creation_time, use_time, message_count, shared "
            "FROM crypto_megolm_outbound WHERE account_id = ? AND room_id = ?",
            (self.account_id, room_id),
        ) as cur:
            row = await cur.fetchone()
        if not row or not row["session"]:
            return None
        max_age = timedelta(milliseconds=row["max_age"]) if row["max_age"] is not None else None
        return OutboundGroupSession.from_pickle(
            row["session"],
            self.pickle_key,
            max_age,
            row["max_messages"],
            self._time_from_db(row["creation_time"]),
            self._time_from_db(row["use_time"]),
            row["message_count"],
            room_id,
            bool(row["shared"]),
        )

    async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Persist changes to an outbound session (same as adding it again)."""
        await self.add_outbound_group_session(session)

    async def remove_outbound_group_session(self, room_id: RoomID) -> None:
        """Delete the outbound Megolm session for ``room_id`` if present."""
        await self.db.execute(
            "DELETE FROM crypto_megolm_outbound WHERE account_id = ? AND room_id = ?",
            (self.account_id, room_id),
        )
        await self.db.commit()

    async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
        """Delete the outbound sessions for all ``rooms`` in one transaction."""
        async with self.transaction():
            await self.db.executemany(
                "DELETE FROM crypto_megolm_outbound WHERE account_id = ? AND room_id = ?",
                [(self.account_id, room_id) for room_id in rooms],
            )

    # ------------------------------------------------------------------
    # Message index validation
    # ------------------------------------------------------------------

    async def validate_message_index(
        self, sender_key: IdentityKey, session_id: SessionID, event_id: EventID, index: int, timestamp: int
    ) -> bool:
        """Record or check which event used a given Megolm message index.

        The first call for (sender_key, session_id, index) records ``event_id``
        and ``timestamp`` and returns True. Later calls return True only if
        they carry the same event_id and timestamp.

        ``INSERT OR IGNORE`` (instead of SELECT-then-INSERT across awaits)
        means two concurrent calls for the same index cannot fail with a
        PRIMARY KEY violation; the second one falls through to the comparison.
        """
        async with self.transaction():
            async with self.db.execute(
                "INSERT OR IGNORE INTO crypto_message_index "
                "(account_id, sender_key, session_id, idx, event_id, timestamp) "
                "VALUES (?, ?, ?, ?, ?, ?)",
                (self.account_id, sender_key, session_id, index, event_id, timestamp),
            ) as cur:
                inserted = cur.rowcount == 1
            if inserted:
                return True
            async with self.db.execute(
                "SELECT event_id, timestamp FROM crypto_message_index "
                "WHERE account_id = ? AND sender_key = ? AND session_id = ? AND idx = ?",
                (self.account_id, sender_key, session_id, index),
            ) as cur:
                row = await cur.fetchone()
        return row is not None and row["event_id"] == event_id and row["timestamp"] == timestamp

    # ------------------------------------------------------------------
    # Devices
    # ------------------------------------------------------------------

    async def _is_tracked(self, user_id: UserID) -> bool:
        """Return True if the user's device list has ever been stored.

        Device rows also count, so databases created before the
        ``crypto_tracked_user`` table existed keep working.
        """
        async with self.db.execute(
            "SELECT EXISTS(SELECT 1 FROM crypto_tracked_user WHERE account_id = ? AND user_id = ?) "
            "OR EXISTS(SELECT 1 FROM crypto_device WHERE account_id = ? AND user_id = ?)",
            (self.account_id, user_id, self.account_id, user_id),
        ) as cur:
            row = await cur.fetchone()
        return bool(row[0])

    async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
        """Return the user's devices.

        Returns ``{}`` if the user is tracked but has no devices, and ``None``
        if the user's device list has never been stored.
        """
        async with self.db.execute(
            "SELECT device_id, device FROM crypto_device WHERE account_id = ? AND user_id = ?",
            (self.account_id, user_id),
        ) as cur:
            rows = await cur.fetchall()
        if rows:
            return {DeviceID(row["device_id"]): self._unpickle(row["device"]) for row in rows}
        return {} if await self._is_tracked(user_id) else None

    async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
        """Return one stored device identity, or None if it is not stored."""
        async with self.db.execute(
            "SELECT device FROM crypto_device WHERE account_id = ? AND user_id = ? AND device_id = ?",
            (self.account_id, user_id, device_id),
        ) as cur:
            row = await cur.fetchone()
        return self._unpickle(row["device"]) if row else None

    async def find_device_by_key(self, user_id: UserID, identity_key: IdentityKey) -> DeviceIdentity | None:
        """Return the user's device whose ``identity_key`` matches, if any."""
        devices = await self.get_devices(user_id)
        if devices:
            for device in devices.values():
                if device.identity_key == identity_key:
                    return device
        return None

    async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
        """Replace the user's stored device list and mark the user as tracked.

        The user is marked tracked even when ``devices`` is empty, so a later
        :meth:`get_devices` returns ``{}`` rather than ``None``.
        """
        async with self.transaction():
            await self.db.execute(
                "DELETE FROM crypto_device WHERE account_id = ? AND user_id = ?",
                (self.account_id, user_id),
            )
            await self.db.executemany(
                "INSERT INTO crypto_device (account_id, user_id, device_id, device) VALUES (?, ?, ?, ?)",
                [
                    (self.account_id, user_id, device_id, self._pickle(device))
                    for device_id, device in devices.items()
                ],
            )
            await self.db.execute(
                "INSERT OR IGNORE INTO crypto_tracked_user (account_id, user_id) VALUES (?, ?)",
                (self.account_id, user_id),
            )

    async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
        """Return the users (in input order) whose device list has been stored."""
        tracked = []
        for user_id in users:
            if await self._is_tracked(user_id):
                tracked.append(user_id)
        return tracked

    # ------------------------------------------------------------------
    # Cross-signing keys
    # ------------------------------------------------------------------

    async def put_cross_signing_key(self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey) -> None:
        """Store ``key`` as the user's current cross-signing key for ``usage``.

        The first key ever stored for (user_id, usage) stays in ``first_key``
        (trust on first use); later calls only update ``key``.
        """
        async with self.transaction():
            # Creates the row with first_key = key only if it does not exist yet.
            await self.db.execute(
                "INSERT OR IGNORE INTO crypto_cross_signing "
                "(account_id, user_id, usage, key, first_key) VALUES (?, ?, ?, ?, ?)",
                (self.account_id, user_id, usage.value, key, key),
            )
            await self.db.execute(
                "UPDATE crypto_cross_signing SET key = ? "
                "WHERE account_id = ? AND user_id = ? AND usage = ?",
                (key, self.account_id, user_id, usage.value),
            )

    async def get_cross_signing_keys(self, user_id: UserID) -> dict[CrossSigningUsage, TOFUSigningKey]:
        """Return the user's stored cross-signing keys keyed by usage."""
        async with self.db.execute(
            "SELECT usage, key, first_key FROM crypto_cross_signing WHERE account_id = ? AND user_id = ?",
            (self.account_id, user_id),
        ) as cur:
            rows = await cur.fetchall()
        return {
            CrossSigningUsage(row["usage"]): TOFUSigningKey(key=row["key"], first=row["first_key"])
            for row in rows
        }

    # ------------------------------------------------------------------
    # Signatures
    # ------------------------------------------------------------------

    async def put_signature(self, target: CrossSigner, signer: CrossSigner, signature: str) -> None:
        """Store (or replace) ``signer``'s signature over ``target``."""
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_signature (account_id, signer, target, signature) VALUES (?, ?, ?, ?)",
            (self.account_id, str(signer), str(target), signature),
        )
        await self.db.commit()

    async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
        """Return True if a signature by ``signer`` over ``target`` is stored."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_signature WHERE account_id = ? AND signer = ? AND target = ? LIMIT 1",
            (self.account_id, str(signer), str(target)),
        ) as cur:
            return await cur.fetchone() is not None

    async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
        """Delete every signature made by ``signer``; return how many were removed."""
        async with self.transaction():
            async with self.db.execute(
                "DELETE FROM crypto_signature WHERE account_id = ? AND signer = ?",
                (self.account_id, str(signer)),
            ) as cur:
                count = cur.rowcount
        return count