# sqlite_crypto_store.py
"""SQLite-backed CryptoStore for mautrix-python"""

from __future__ import annotations

import pickle
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from pathlib import Path

import aiosqlite

from mautrix.client.state_store import SyncStore
from mautrix.crypto.account import OlmAccount
from mautrix.crypto.sessions import InboundGroupSession, OutboundGroupSession, Session
from mautrix.crypto.store.abstract import CryptoStore
from mautrix.types import (
    CrossSigner,
    CrossSigningUsage,
    DeviceID,
    DeviceIdentity,
    EventID,
    IdentityKey,
    RoomID,
    SessionID,
    SigningKey,
    SyncToken,
    TOFUSigningKey,
    UserID,
)


class SQLiteCryptoStore(CryptoStore, SyncStore):
    """SQLite-backed crypto store for mautrix-python.

    Every table carries an ``account_id`` column and all queries filter on it,
    so several accounts can share one database file. The Olm account and
    Olm/Megolm sessions are stored via their own ``pickle(pickle_key)``
    serialization; device identities are stored with Python's :mod:`pickle`.

    Call :meth:`open` before using the store and :meth:`close` when done.
    """

    def __init__(self, account_id: str, pickle_key: str, db_path: Path | str) -> None:
        self.account_id = account_id
        self.pickle_key = pickle_key
        self.db_path = Path(db_path)
        # Set by open(); all query methods assume it is not None.
        self.db: aiosqlite.Connection | None = None

    async def open(self) -> None:
        """Open the database connection and create any missing tables."""
        self.db = await aiosqlite.connect(self.db_path)
        # Lets rows be indexed by column name, e.g. row["device_id"].
        self.db.row_factory = aiosqlite.Row
        await self._create_tables()

    async def close(self) -> None:
        """Close the database connection if it is open."""
        if self.db:
            await self.db.close()
            self.db = None

    async def _create_tables(self) -> None:
        """Create the schema if it does not exist and run idempotent backfills.

        ``CREATE TABLE IF NOT EXISTS`` never alters an existing table, so schema
        tweaks here only take effect for newly created databases.
        """
        await self.db.executescript(
            """
            CREATE TABLE IF NOT EXISTS crypto_account (
                account_id TEXT PRIMARY KEY,
                device_id TEXT,
                sync_token TEXT,
                shared INTEGER DEFAULT 0,
                account BLOB
            );
            CREATE TABLE IF NOT EXISTS crypto_olm_session (
                account_id TEXT,
                sender_key TEXT,
                session_id TEXT,
                session BLOB,
                creation_time REAL,
                PRIMARY KEY (account_id, sender_key, session_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_megolm_inbound (
                account_id TEXT,
                room_id TEXT,
                session_id TEXT,
                sender_key TEXT,
                signing_key TEXT,
                session BLOB,
                PRIMARY KEY (account_id, room_id, session_id)
            );
            -- Keyed per account like every other table. Databases created by
            -- older versions keep their room_id-only key (not migrated).
            CREATE TABLE IF NOT EXISTS crypto_megolm_outbound (
                account_id TEXT,
                room_id TEXT,
                session BLOB,
                max_age INTEGER,
                max_messages INTEGER,
                creation_time REAL,
                use_time REAL,
                message_count INTEGER,
                shared INTEGER,
                PRIMARY KEY (account_id, room_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_device (
                account_id TEXT,
                user_id TEXT,
                device_id TEXT,
                device BLOB,
                PRIMARY KEY (account_id, user_id, device_id)
            );
            -- Records users whose device list has been stored at least once,
            -- even if that list was empty (which leaves no crypto_device rows).
            CREATE TABLE IF NOT EXISTS crypto_tracked_user (
                account_id TEXT,
                user_id TEXT,
                PRIMARY KEY (account_id, user_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_message_index (
                account_id TEXT,
                sender_key TEXT,
                session_id TEXT,
                idx INTEGER,
                event_id TEXT,
                timestamp INTEGER,
                PRIMARY KEY (account_id, sender_key, session_id, idx)
            );
            CREATE TABLE IF NOT EXISTS crypto_cross_signing (
                account_id TEXT,
                user_id TEXT,
                usage TEXT,
                key TEXT,
                first_key TEXT,
                PRIMARY KEY (account_id, user_id, usage)
            );
            CREATE TABLE IF NOT EXISTS crypto_signature (
                account_id TEXT,
                signer TEXT,
                target TEXT,
                signature TEXT,
                PRIMARY KEY (account_id, signer, target)
            );
            -- Backfill for databases created before crypto_tracked_user existed:
            -- anyone with stored devices is already tracked. Safe to rerun.
            INSERT OR IGNORE INTO crypto_tracked_user (account_id, user_id)
                SELECT DISTINCT account_id, user_id FROM crypto_device;
            """
        )
        await self.db.commit()

    def _pickle(self, obj) -> bytes:
        """Serialize *obj* (used for DeviceIdentity values) with Python's pickle."""
        return pickle.dumps(obj)

    def _unpickle(self, data: bytes | None):
        """Deserialize a value written by :meth:`_pickle`; falsy input gives None.

        SECURITY NOTE(review): ``pickle.loads`` can execute arbitrary code, so this
        is only safe while the database file is trusted. Switching to a JSON
        encoding would need a migration of existing rows.
        """
        return pickle.loads(data) if data else None

    @staticmethod
    def _to_db_time(value, sep: str = " "):
        """Convert a datetime to an ISO-8601 string for storage; other values pass through.

        The default ``" "`` separator matches the output of sqlite3's implicit
        datetime adapter, which earlier versions of this store relied on, so old
        and new rows share one format.
        """
        if isinstance(value, datetime):
            return value.isoformat(sep)
        return value

    @staticmethod
    def _from_db_time(value):
        """Parse an ISO-8601 string read from the database into a datetime; other values pass through."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        return value

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator[None]:
        """Group several statements: commit on success, roll back on any error.

        ``BaseException`` is caught so that task cancellation (``CancelledError``)
        also rolls back instead of leaving a half-applied transaction open.
        """
        try:
            yield
            await self.db.commit()
        except BaseException:
            await self.db.rollback()
            raise

    async def _ensure_account_row(self) -> None:
        """Insert an empty ``crypto_account`` row for this account if none exists.

        Existing rows are left untouched, so callers can then ``UPDATE`` a single
        column without clobbering the others (which ``INSERT OR REPLACE`` would).
        """
        await self.db.execute(
            "INSERT OR IGNORE INTO crypto_account (account_id) VALUES (?)",
            (self.account_id,),
        )

    # Device ID

    async def get_device_id(self) -> DeviceID | None:
        """Return the stored device ID, or None if unset."""
        async with self.db.execute(
            "SELECT device_id FROM crypto_account WHERE account_id = ?", (self.account_id,)
        ) as cur:
            row = await cur.fetchone()
        return DeviceID(row["device_id"]) if row and row["device_id"] else None

    async def put_device_id(self, device_id: DeviceID) -> None:
        """Store the device ID, preserving the account pickle and sync token."""
        async with self.transaction():
            await self._ensure_account_row()
            await self.db.execute(
                "UPDATE crypto_account SET device_id = ? WHERE account_id = ?",
                (device_id, self.account_id),
            )

    # Sync token

    async def put_next_batch(self, next_batch: SyncToken) -> None:
        """Store the sync token, creating the account row first if needed."""
        async with self.transaction():
            await self._ensure_account_row()
            await self.db.execute(
                "UPDATE crypto_account SET sync_token = ? WHERE account_id = ?",
                (next_batch, self.account_id),
            )

    async def get_next_batch(self) -> SyncToken | None:
        """Return the stored sync token, or None if unset."""
        async with self.db.execute(
            "SELECT sync_token FROM crypto_account WHERE account_id = ?", (self.account_id,)
        ) as cur:
            row = await cur.fetchone()
        return SyncToken(row["sync_token"]) if row and row["sync_token"] else None

    # Account (serialized with the account's own pickle method, not Python's pickle)

    async def put_account(self, account: OlmAccount) -> None:
        """Store the Olm account and its ``shared`` flag, preserving device ID and sync token."""
        pickled = account.pickle(self.pickle_key)
        async with self.transaction():
            await self._ensure_account_row()
            await self.db.execute(
                "UPDATE crypto_account SET shared = ?, account = ? WHERE account_id = ?",
                (account.shared, pickled, self.account_id),
            )

    async def get_account(self) -> OlmAccount | None:
        """Load the Olm account, or return None if it was never stored."""
        async with self.db.execute(
            "SELECT account, shared FROM crypto_account WHERE account_id = ?", (self.account_id,)
        ) as cur:
            row = await cur.fetchone()
        if row and row["account"]:
            return OlmAccount.from_pickle(row["account"], self.pickle_key, bool(row["shared"]))
        return None

    async def delete(self) -> None:
        """Delete this account's row, Olm/Megolm sessions, devices and tracked users in one transaction."""
        # Table names are fixed constants, never user input.
        tables = (
            "crypto_account",
            "crypto_olm_session",
            "crypto_megolm_inbound",
            "crypto_megolm_outbound",
            "crypto_device",
            "crypto_tracked_user",
        )
        async with self.transaction():
            for table in tables:
                await self.db.execute(
                    f"DELETE FROM {table} WHERE account_id = ?", (self.account_id,)
                )

    # Olm sessions

    async def has_session(self, key: IdentityKey) -> bool:
        """Return True if any Olm session exists for the given identity key."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_olm_session WHERE account_id = ? AND sender_key = ? LIMIT 1",
            (self.account_id, key),
        ) as cur:
            return await cur.fetchone() is not None

    async def get_latest_session(self, key: IdentityKey) -> Session | None:
        """Return the most recently inserted Olm session (highest rowid) for *key*, if any."""
        async with self.db.execute(
            "SELECT session, creation_time FROM crypto_olm_session "
            "WHERE account_id = ? AND sender_key = ? ORDER BY rowid DESC LIMIT 1",
            (self.account_id, key),
        ) as cur:
            row = await cur.fetchone()
        if row and row["session"]:
            return Session.from_pickle(
                row["session"], self.pickle_key, self._from_db_time(row["creation_time"])
            )
        return None

    async def get_sessions(self, key: IdentityKey) -> list[Session]:
        """Return all Olm sessions stored for the given identity key."""
        async with self.db.execute(
            "SELECT session, creation_time FROM crypto_olm_session "
            "WHERE account_id = ? AND sender_key = ?",
            (self.account_id, key),
        ) as cur:
            rows = await cur.fetchall()
        return [
            Session.from_pickle(
                row["session"], self.pickle_key, self._from_db_time(row["creation_time"])
            )
            for row in rows
        ]

    async def add_session(self, key: IdentityKey, session: Session) -> None:
        """Insert (or replace) an Olm session for the given identity key."""
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_olm_session "
            "(account_id, sender_key, session_id, session, creation_time) VALUES (?, ?, ?, ?, ?)",
            (
                self.account_id,
                key,
                session.id,
                session.pickle(self.pickle_key),
                # Stored explicitly as ISO text and parsed back on read, instead of
                # relying on sqlite3's implicit datetime adapter.
                self._to_db_time(session.creation_time),
            ),
        )
        await self.db.commit()

    async def update_session(self, key: IdentityKey, session: Session) -> None:
        """Overwrite the pickled state of an existing Olm session; no-op if it is not stored."""
        await self.db.execute(
            "UPDATE crypto_olm_session SET session = ? "
            "WHERE account_id = ? AND sender_key = ? AND session_id = ?",
            (session.pickle(self.pickle_key), self.account_id, key, session.id),
        )
        await self.db.commit()

    # Megolm inbound sessions

    async def put_group_session(
        self,
        room_id: RoomID,
        sender_key: IdentityKey,
        session_id: SessionID,
        session: InboundGroupSession,
    ) -> None:
        """Insert (or replace) an inbound Megolm session."""
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_megolm_inbound "
            "(account_id, room_id, session_id, sender_key, signing_key, session) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (
                self.account_id,
                room_id,
                session_id,
                sender_key,
                session.signing_key,
                session.pickle(self.pickle_key),
            ),
        )
        await self.db.commit()

    async def get_group_session(
        self, room_id: RoomID, session_id: SessionID
    ) -> InboundGroupSession | None:
        """Load an inbound Megolm session, or return None if it is not stored."""
        async with self.db.execute(
            "SELECT session, sender_key, signing_key FROM crypto_megolm_inbound "
            "WHERE account_id = ? AND room_id = ? AND session_id = ?",
            (self.account_id, room_id, session_id),
        ) as cur:
            row = await cur.fetchone()
        if row and row["session"]:
            return InboundGroupSession.from_pickle(
                row["session"],
                self.pickle_key,
                row["signing_key"],
                row["sender_key"],
                room_id,
            )
        return None

    async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
        """Return True if the inbound Megolm session is stored."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_megolm_inbound "
            "WHERE account_id = ? AND room_id = ? AND session_id = ? LIMIT 1",
            (self.account_id, room_id, session_id),
        ) as cur:
            return await cur.fetchone() is not None

    async def redact_group_session(
        self, room_id: RoomID, session_id: SessionID, reason: str
    ) -> None:
        """Delete one inbound Megolm session. *reason* is accepted but not stored."""
        await self.db.execute(
            "DELETE FROM crypto_megolm_inbound "
            "WHERE account_id = ? AND room_id = ? AND session_id = ?",
            (self.account_id, room_id, session_id),
        )
        await self.db.commit()

    async def redact_group_sessions(
        self, room_id: RoomID, sender_key: IdentityKey, reason: str
    ) -> list[SessionID]:
        """Delete inbound Megolm sessions matching *room_id* OR *sender_key*.

        Returns the IDs of the deleted sessions. *reason* is accepted but not stored.
        The SELECT and DELETE run in one transaction so the returned list and the
        deleted rows agree, and a failure rolls both back.
        """
        condition = "account_id = ? AND (room_id = ? OR sender_key = ?)"
        params = (self.account_id, room_id, sender_key)
        async with self.transaction():
            async with self.db.execute(
                f"SELECT session_id FROM crypto_megolm_inbound WHERE {condition}", params
            ) as cur:
                rows = await cur.fetchall()
            await self.db.execute(f"DELETE FROM crypto_megolm_inbound WHERE {condition}", params)
        return [SessionID(row["session_id"]) for row in rows]

    async def redact_expired_group_sessions(self) -> list[SessionID]:
        """Not implemented: always returns an empty list."""
        return []

    async def redact_outdated_group_sessions(self) -> list[SessionID]:
        """Not implemented: always returns an empty list."""
        return []

    # Megolm outbound sessions

    async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Insert (or replace) the outbound Megolm session for ``session.room_id``.

        ``max_age`` is stored as integer milliseconds; the timestamps are stored
        as ISO-8601 strings.
        """
        max_age_ms = (
            int(session.max_age.total_seconds() * 1000) if session.max_age is not None else None
        )
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_megolm_outbound "
            "(account_id, room_id, session, max_age, max_messages, creation_time, "
            "use_time, message_count, shared) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                self.account_id,
                session.room_id,
                session.pickle(self.pickle_key),
                max_age_ms,
                session.max_messages,
                # "T" separator keeps the format this table has always used.
                self._to_db_time(session.creation_time, "T"),
                self._to_db_time(session.use_time, "T"),
                session.message_count,
                session.shared,
            ),
        )
        await self.db.commit()

    async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
        """Load the outbound Megolm session for *room_id*, or return None."""
        async with self.db.execute(
            "SELECT session, max_age, max_messages, creation_time, use_time, message_count, "
            "shared FROM crypto_megolm_outbound WHERE account_id = ? AND room_id = ?",
            (self.account_id, room_id),
        ) as cur:
            row = await cur.fetchone()
        if not row or not row["session"]:
            return None
        max_age = (
            timedelta(milliseconds=row["max_age"]) if row["max_age"] is not None else None
        )
        return OutboundGroupSession.from_pickle(
            row["session"],
            self.pickle_key,
            max_age,
            row["max_messages"],
            self._from_db_time(row["creation_time"]),
            self._from_db_time(row["use_time"]),
            row["message_count"],
            room_id,
            bool(row["shared"]),
        )

    async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Persist changes to an outbound session (same as re-adding it)."""
        await self.add_outbound_group_session(session)

    async def remove_outbound_group_session(self, room_id: RoomID) -> None:
        """Delete the outbound Megolm session for *room_id*, if any."""
        await self.db.execute(
            "DELETE FROM crypto_megolm_outbound WHERE account_id = ? AND room_id = ?",
            (self.account_id, room_id),
        )
        await self.db.commit()

    async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
        """Delete the outbound Megolm sessions for all *rooms* with a single commit."""
        await self.db.executemany(
            "DELETE FROM crypto_megolm_outbound WHERE account_id = ? AND room_id = ?",
            [(self.account_id, room_id) for room_id in rooms],
        )
        await self.db.commit()

    # Message index validation

    async def validate_message_index(
        self,
        sender_key: IdentityKey,
        session_id: SessionID,
        event_id: EventID,
        index: int,
        timestamp: int,
    ) -> bool:
        """Check that a Megolm message index was not reused by a different event.

        On first sight of ``(sender_key, session_id, index)`` the event ID and
        timestamp are recorded and True is returned. On later calls, returns True
        only if both match the recorded values.
        """
        async with self.db.execute(
            "SELECT event_id, timestamp FROM crypto_message_index "
            "WHERE account_id = ? AND sender_key = ? AND session_id = ? AND idx = ?",
            (self.account_id, sender_key, session_id, index),
        ) as cur:
            row = await cur.fetchone()
        if row:
            return row["event_id"] == event_id and row["timestamp"] == timestamp
        await self.db.execute(
            "INSERT INTO crypto_message_index "
            "(account_id, sender_key, session_id, idx, event_id, timestamp) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (self.account_id, sender_key, session_id, index, event_id, timestamp),
        )
        await self.db.commit()
        return True

    # Devices

    async def _is_tracked(self, user_id: UserID) -> bool:
        """Return True if a device list (possibly empty) has been stored for *user_id*."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_tracked_user WHERE account_id = ? AND user_id = ? LIMIT 1",
            (self.account_id, user_id),
        ) as cur:
            return await cur.fetchone() is not None

    async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
        """Return the stored devices of *user_id*.

        Returns ``{}`` if the user's device list was stored but empty, and None
        if no device list has ever been stored for the user.
        """
        async with self.db.execute(
            "SELECT device_id, device FROM crypto_device WHERE account_id = ? AND user_id = ?",
            (self.account_id, user_id),
        ) as cur:
            rows = await cur.fetchall()
        if rows:
            return {DeviceID(row["device_id"]): self._unpickle(row["device"]) for row in rows}
        # No device rows: distinguish "tracked with no devices" from "never fetched".
        return {} if await self._is_tracked(user_id) else None

    async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
        """Return one stored device, or None if it is not stored."""
        async with self.db.execute(
            "SELECT device FROM crypto_device "
            "WHERE account_id = ? AND user_id = ? AND device_id = ?",
            (self.account_id, user_id, device_id),
        ) as cur:
            row = await cur.fetchone()
        return self._unpickle(row["device"]) if row else None

    async def find_device_by_key(
        self, user_id: UserID, identity_key: IdentityKey
    ) -> DeviceIdentity | None:
        """Return the stored device of *user_id* whose ``identity_key`` matches, if any."""
        devices = await self.get_devices(user_id)
        if devices:
            for device in devices.values():
                if device.identity_key == identity_key:
                    return device
        return None

    async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
        """Replace the stored device list of *user_id* and mark the user as tracked.

        The user is recorded as tracked even when *devices* is empty. All
        statements run in one transaction, so a failure cannot leave the old
        list deleted without the new one written.
        """
        async with self.transaction():
            await self.db.execute(
                "INSERT OR IGNORE INTO crypto_tracked_user (account_id, user_id) VALUES (?, ?)",
                (self.account_id, user_id),
            )
            await self.db.execute(
                "DELETE FROM crypto_device WHERE account_id = ? AND user_id = ?",
                (self.account_id, user_id),
            )
            await self.db.executemany(
                "INSERT INTO crypto_device (account_id, user_id, device_id, device) "
                "VALUES (?, ?, ?, ?)",
                [
                    (self.account_id, user_id, device_id, self._pickle(device))
                    for device_id, device in devices.items()
                ],
            )

    async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
        """Return the subset of *users* (in input order) whose device lists have been stored."""
        return [user_id for user_id in users if await self._is_tracked(user_id)]

    # Cross-signing

    async def put_cross_signing_key(
        self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey
    ) -> None:
        """Store the current cross-signing key for (user, usage).

        The first key ever seen for that pair is kept in ``first_key``
        (trust-on-first-use); only ``key`` changes on later calls.
        """
        async with self.db.execute(
            "SELECT first_key FROM crypto_cross_signing "
            "WHERE account_id = ? AND user_id = ? AND usage = ?",
            (self.account_id, user_id, usage.value),
        ) as cur:
            row = await cur.fetchone()
        first_key = row["first_key"] if row else key
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_cross_signing "
            "(account_id, user_id, usage, key, first_key) VALUES (?, ?, ?, ?, ?)",
            (self.account_id, user_id, usage.value, key, first_key),
        )
        await self.db.commit()

    async def get_cross_signing_keys(
        self, user_id: UserID
    ) -> dict[CrossSigningUsage, TOFUSigningKey]:
        """Return the stored cross-signing keys of *user_id*, keyed by usage."""
        async with self.db.execute(
            "SELECT usage, key, first_key FROM crypto_cross_signing "
            "WHERE account_id = ? AND user_id = ?",
            (self.account_id, user_id),
        ) as cur:
            rows = await cur.fetchall()
        return {
            CrossSigningUsage(row["usage"]): TOFUSigningKey(key=row["key"], first=row["first_key"])
            for row in rows
        }

    # Signatures (signer/target are stored via str(), so lookups must use str() too)

    async def put_signature(self, target: CrossSigner, signer: CrossSigner, signature: str) -> None:
        """Record that *signer* signed *target*."""
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_signature (account_id, signer, target, signature) "
            "VALUES (?, ?, ?, ?)",
            (self.account_id, str(signer), str(target), signature),
        )
        await self.db.commit()

    async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
        """Return True if a signature of *target* by *signer* is stored."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_signature "
            "WHERE account_id = ? AND signer = ? AND target = ? LIMIT 1",
            (self.account_id, str(signer), str(target)),
        ) as cur:
            return await cur.fetchone() is not None

    async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
        """Delete every signature made by *signer* and return how many were deleted."""
        async with self.db.execute(
            "DELETE FROM crypto_signature WHERE account_id = ? AND signer = ?",
            (self.account_id, str(signer)),
        ) as cur:
            # rowcount of a DELETE is the number of rows it removed.
            count = cur.rowcount
        await self.db.commit()
        return count