# sqlite_crypto_store.py
"""SQLite-backed CryptoStore for mautrix-python"""
from __future__ import annotations

import pickle
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from pathlib import Path

import aiosqlite
from mautrix.client.state_store import SyncStore
from mautrix.crypto.account import OlmAccount
from mautrix.crypto.sessions import InboundGroupSession, OutboundGroupSession, Session
from mautrix.crypto.store.abstract import CryptoStore
from mautrix.types import (
    CrossSigner,
    CrossSigningUsage,
    DeviceID,
    DeviceIdentity,
    EventID,
    IdentityKey,
    RoomID,
    SessionID,
    SigningKey,
    SyncToken,
    TOFUSigningKey,
    UserID,
)


class SQLiteCryptoStore(CryptoStore, SyncStore):
    """SQLite-backed crypto store for mautrix-python.

    Every table carries an ``account_id`` column and every query filters on
    ``self.account_id``. Olm/Megolm objects are serialised through their own
    ``pickle(pickle_key)`` / ``from_pickle(...)`` methods; device identities
    are serialised with Python's :mod:`pickle`.

    Call :meth:`open` before any other method and :meth:`close` when done.
    """

    # Every table that stores rows keyed by account_id; used by delete().
    _ACCOUNT_TABLES = (
        "crypto_account",
        "crypto_olm_session",
        "crypto_megolm_inbound",
        "crypto_megolm_outbound",
        "crypto_device",
        "crypto_message_index",
        "crypto_cross_signing",
        "crypto_signature",
    )

    def __init__(self, account_id: str, pickle_key: str, db_path: Path | str) -> None:
        """Remember the account scope, pickle passphrase and database path.

        No I/O happens here; the connection is created by :meth:`open`.
        """
        self.account_id = account_id
        self.pickle_key = pickle_key
        self.db_path = Path(db_path)
        self.db: aiosqlite.Connection | None = None

    async def open(self) -> None:
        """Open the database and create any missing tables."""
        self.db = await aiosqlite.connect(self.db_path)
        self.db.row_factory = aiosqlite.Row
        try:
            await self._create_tables()
        except Exception:
            # Do not leak the connection when schema creation fails.
            await self.close()
            raise

    async def close(self) -> None:
        """Close the database connection if it is open."""
        if self.db:
            await self.db.close()
            self.db = None

    async def _create_tables(self) -> None:
        """Create all tables that do not exist yet and commit."""
        # crypto_megolm_outbound is keyed on (account_id, room_id) so that two
        # accounts in the same room do not overwrite each other's session.
        # CREATE TABLE IF NOT EXISTS does not alter existing tables, so
        # databases created with the older room_id-only key keep that key.
        await self.db.executescript("""
            CREATE TABLE IF NOT EXISTS crypto_account (
                account_id TEXT PRIMARY KEY,
                device_id TEXT,
                sync_token TEXT,
                shared INTEGER DEFAULT 0,
                account BLOB
            );
            CREATE TABLE IF NOT EXISTS crypto_olm_session (
                account_id TEXT,
                sender_key TEXT,
                session_id TEXT,
                session BLOB,
                creation_time REAL,
                PRIMARY KEY (account_id, sender_key, session_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_megolm_inbound (
                account_id TEXT,
                room_id TEXT,
                session_id TEXT,
                sender_key TEXT,
                signing_key TEXT,
                session BLOB,
                PRIMARY KEY (account_id, room_id, session_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_megolm_outbound (
                account_id TEXT,
                room_id TEXT,
                session BLOB,
                max_age INTEGER,
                max_messages INTEGER,
                creation_time REAL,
                use_time REAL,
                message_count INTEGER,
                shared INTEGER,
                PRIMARY KEY (account_id, room_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_device (
                account_id TEXT,
                user_id TEXT,
                device_id TEXT,
                device BLOB,
                PRIMARY KEY (account_id, user_id, device_id)
            );
            CREATE TABLE IF NOT EXISTS crypto_message_index (
                account_id TEXT,
                sender_key TEXT,
                session_id TEXT,
                idx INTEGER,
                event_id TEXT,
                timestamp INTEGER,
                PRIMARY KEY (account_id, sender_key, session_id, idx)
            );
            CREATE TABLE IF NOT EXISTS crypto_cross_signing (
                account_id TEXT,
                user_id TEXT,
                usage TEXT,
                key TEXT,
                first_key TEXT,
                PRIMARY KEY (account_id, user_id, usage)
            );
            CREATE TABLE IF NOT EXISTS crypto_signature (
                account_id TEXT,
                signer TEXT,
                target TEXT,
                signature TEXT,
                PRIMARY KEY (account_id, signer, target)
            );
        """)
        await self.db.commit()

    def _pickle(self, obj) -> bytes:
        """Serialise ``obj`` with Python's pickle module."""
        return pickle.dumps(obj)

    def _unpickle(self, data: bytes):
        """Deserialise a blob written by :meth:`_pickle`; empty data gives ``None``."""
        # SECURITY: pickle.loads executes arbitrary code embedded in the blob.
        # Only safe while the database file is writable by trusted parties only.
        return pickle.loads(data) if data else None

    @staticmethod
    def _encode_time(value):
        """Turn a ``datetime`` into an ISO-8601 string; pass anything else through."""
        return value.isoformat() if isinstance(value, datetime) else value

    @staticmethod
    def _decode_time(value):
        """Turn an ISO-8601 string back into a ``datetime``; pass anything else through."""
        return datetime.fromisoformat(value) if isinstance(value, str) else value

    @asynccontextmanager
    async def transaction(self):
        """Commit when the ``async with`` body exits cleanly, roll back on error.

        NOTE(review): the write methods of this class each call ``commit()``
        themselves, so using them inside this context still commits their
        changes immediately; the rollback only undoes statements that were
        not yet committed.
        """
        try:
            yield
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise

    # Device ID
    async def get_device_id(self) -> DeviceID | None:
        """Return the stored device ID, or ``None`` if none is stored."""
        async with self.db.execute(
            "SELECT device_id FROM crypto_account WHERE account_id = ?",
            (self.account_id,),
        ) as cur:
            row = await cur.fetchone()
        return DeviceID(row["device_id"]) if row and row["device_id"] else None

    async def put_device_id(self, device_id: DeviceID) -> None:
        """Store the device ID without touching the other account columns."""
        # An upsert (SQLite >= 3.24) is used instead of INSERT OR REPLACE:
        # REPLACE rewrites the whole row and would reset sync_token, shared
        # and the pickled account to their defaults.
        await self.db.execute(
            "INSERT INTO crypto_account (account_id, device_id) VALUES (?, ?) "
            "ON CONFLICT (account_id) DO UPDATE SET device_id = excluded.device_id",
            (self.account_id, device_id),
        )
        await self.db.commit()

    # Sync token
    async def put_next_batch(self, next_batch: SyncToken) -> None:
        """Store the sync token, creating the account row if it is missing."""
        # A plain UPDATE would silently drop the token when no row exists yet.
        await self.db.execute(
            "INSERT INTO crypto_account (account_id, sync_token) VALUES (?, ?) "
            "ON CONFLICT (account_id) DO UPDATE SET sync_token = excluded.sync_token",
            (self.account_id, next_batch),
        )
        await self.db.commit()

    async def get_next_batch(self) -> SyncToken | None:
        """Return the stored sync token, or ``None`` if none is stored."""
        async with self.db.execute(
            "SELECT sync_token FROM crypto_account WHERE account_id = ?",
            (self.account_id,),
        ) as cur:
            row = await cur.fetchone()
        return SyncToken(row["sync_token"]) if row and row["sync_token"] else None

    # Account - use olm's built-in pickle, not Python's
    async def put_account(self, account: OlmAccount) -> None:
        """Store the Olm account and its ``shared`` flag.

        ``device_id`` and ``sync_token`` of an existing row are left as they are.
        """
        await self.db.execute(
            "INSERT INTO crypto_account (account_id, shared, account) VALUES (?, ?, ?) "
            "ON CONFLICT (account_id) DO UPDATE SET "
            "shared = excluded.shared, account = excluded.account",
            (self.account_id, account.shared, account.pickle(self.pickle_key)),
        )
        await self.db.commit()

    async def get_account(self) -> OlmAccount | None:
        """Load the Olm account, or return ``None`` if none is stored."""
        async with self.db.execute(
            "SELECT account, shared FROM crypto_account WHERE account_id = ?",
            (self.account_id,),
        ) as cur:
            row = await cur.fetchone()
        if row and row["account"]:
            return OlmAccount.from_pickle(row["account"], self.pickle_key, bool(row["shared"]))
        return None

    async def delete(self) -> None:
        """Delete every row belonging to this account from all tables."""
        for table in self._ACCOUNT_TABLES:
            # Table names come from the class constant above, never from input.
            await self.db.execute(
                f"DELETE FROM {table} WHERE account_id = ?", (self.account_id,)
            )
        await self.db.commit()

    # Olm sessions
    async def has_session(self, key: IdentityKey) -> bool:
        """Return whether at least one Olm session exists for ``key``."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_olm_session WHERE account_id = ? AND sender_key = ? LIMIT 1",
            (self.account_id, key),
        ) as cur:
            return await cur.fetchone() is not None

    def _session_from_row(self, row) -> Session:
        """Rebuild an Olm session from a ``session, creation_time`` row."""
        return Session.from_pickle(
            row["session"], self.pickle_key, self._decode_time(row["creation_time"])
        )

    async def get_latest_session(self, key: IdentityKey) -> Session | None:
        """Return the Olm session for ``key`` with the highest rowid, if any."""
        async with self.db.execute(
            "SELECT session, creation_time FROM crypto_olm_session "
            "WHERE account_id = ? AND sender_key = ? ORDER BY rowid DESC LIMIT 1",
            (self.account_id, key),
        ) as cur:
            row = await cur.fetchone()
        if row and row["session"]:
            return self._session_from_row(row)
        return None

    async def get_sessions(self, key: IdentityKey) -> list[Session]:
        """Return all Olm sessions stored for ``key`` (possibly an empty list)."""
        async with self.db.execute(
            "SELECT session, creation_time FROM crypto_olm_session "
            "WHERE account_id = ? AND sender_key = ?",
            (self.account_id, key),
        ) as cur:
            rows = await cur.fetchall()
        return [self._session_from_row(row) for row in rows]

    async def add_session(self, key: IdentityKey, session: Session) -> None:
        """Insert an Olm session, replacing any row with the same session ID."""
        # creation_time is written as an explicit ISO-8601 string, matching the
        # outbound Megolm code, instead of relying on sqlite3's default
        # datetime adapter. Presumably creation_time is a datetime - verify
        # against the mautrix Session class.
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_olm_session "
            "(account_id, sender_key, session_id, session, creation_time) "
            "VALUES (?, ?, ?, ?, ?)",
            (
                self.account_id,
                key,
                session.id,
                session.pickle(self.pickle_key),
                self._encode_time(session.creation_time),
            ),
        )
        await self.db.commit()

    async def update_session(self, key: IdentityKey, session: Session) -> None:
        """Overwrite the pickled data of an existing Olm session."""
        await self.db.execute(
            "UPDATE crypto_olm_session SET session = ? "
            "WHERE account_id = ? AND sender_key = ? AND session_id = ?",
            (session.pickle(self.pickle_key), self.account_id, key, session.id),
        )
        await self.db.commit()

    # Megolm inbound sessions
    async def put_group_session(
        self, room_id: RoomID, sender_key: IdentityKey, session_id: SessionID, session: InboundGroupSession
    ) -> None:
        """Insert or replace an inbound Megolm session."""
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_megolm_inbound "
            "(account_id, room_id, session_id, sender_key, signing_key, session) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (
                self.account_id,
                room_id,
                session_id,
                sender_key,
                session.signing_key,
                session.pickle(self.pickle_key),
            ),
        )
        await self.db.commit()

    async def get_group_session(self, room_id: RoomID, session_id: SessionID) -> InboundGroupSession | None:
        """Load an inbound Megolm session, or return ``None`` if it is not stored."""
        async with self.db.execute(
            "SELECT session, sender_key, signing_key FROM crypto_megolm_inbound "
            "WHERE account_id = ? AND room_id = ? AND session_id = ?",
            (self.account_id, room_id, session_id),
        ) as cur:
            row = await cur.fetchone()
        if row and row["session"]:
            return InboundGroupSession.from_pickle(
                row["session"],
                self.pickle_key,
                row["signing_key"],
                row["sender_key"],
                room_id,
            )
        return None

    async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
        """Return whether the given inbound Megolm session is stored."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_megolm_inbound "
            "WHERE account_id = ? AND room_id = ? AND session_id = ? LIMIT 1",
            (self.account_id, room_id, session_id),
        ) as cur:
            return await cur.fetchone() is not None

    async def redact_group_session(self, room_id: RoomID, session_id: SessionID, reason: str) -> None:
        """Delete one inbound Megolm session. ``reason`` is not stored."""
        await self.db.execute(
            "DELETE FROM crypto_megolm_inbound "
            "WHERE account_id = ? AND room_id = ? AND session_id = ?",
            (self.account_id, room_id, session_id),
        )
        await self.db.commit()

    async def redact_group_sessions(self, room_id: RoomID, sender_key: IdentityKey, reason: str) -> list[SessionID]:
        """Delete inbound Megolm sessions matching a room and/or a sender.

        A falsy ``room_id`` or ``sender_key`` means "do not filter on this
        column"; when both are given, a session must match both. With neither
        given nothing is deleted. ``reason`` is not stored.

        Returns the IDs of the deleted sessions.
        """
        # The filters are AND-ed: OR-ing them would, when both are supplied,
        # delete every session in the room plus every session from the sender
        # in any other room.
        if not room_id and not sender_key:
            return []
        conditions = ["account_id = ?"]
        params: list[str] = [self.account_id]
        if room_id:
            conditions.append("room_id = ?")
            params.append(room_id)
        if sender_key:
            conditions.append("sender_key = ?")
            params.append(sender_key)
        where = " AND ".join(conditions)

        async with self.db.execute(
            f"SELECT session_id FROM crypto_megolm_inbound WHERE {where}", params
        ) as cur:
            rows = await cur.fetchall()
        deleted = [SessionID(row["session_id"]) for row in rows]

        await self.db.execute(f"DELETE FROM crypto_megolm_inbound WHERE {where}", params)
        await self.db.commit()
        return deleted

    async def redact_expired_group_sessions(self) -> list[SessionID]:
        """Not implemented: always returns an empty list."""
        return []  # Not implemented for simplicity

    async def redact_outdated_group_sessions(self) -> list[SessionID]:
        """Not implemented: always returns an empty list."""
        return []  # Not implemented for simplicity

    # Megolm outbound sessions
    async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Insert or replace the outbound Megolm session for ``session.room_id``."""
        # max_age is stored as integer milliseconds. "is not None" keeps a
        # zero-length max_age from being stored as NULL.
        max_age_ms = (
            int(session.max_age.total_seconds() * 1000)
            if session.max_age is not None
            else None
        )
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_megolm_outbound "
            "(account_id, room_id, session, max_age, max_messages, creation_time, "
            "use_time, message_count, shared) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                self.account_id,
                session.room_id,
                session.pickle(self.pickle_key),
                max_age_ms,
                session.max_messages,
                self._encode_time(session.creation_time),
                self._encode_time(session.use_time),
                session.message_count,
                session.shared,
            ),
        )
        await self.db.commit()

    async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
        """Load the outbound Megolm session for a room, or ``None`` if absent."""
        async with self.db.execute(
            "SELECT session, max_age, max_messages, creation_time, use_time, "
            "message_count, shared FROM crypto_megolm_outbound "
            "WHERE account_id = ? AND room_id = ?",
            (self.account_id, room_id),
        ) as cur:
            row = await cur.fetchone()
        if not row or not row["session"]:
            return None

        max_age = (
            timedelta(milliseconds=row["max_age"]) if row["max_age"] is not None else None
        )
        return OutboundGroupSession.from_pickle(
            row["session"],
            self.pickle_key,
            max_age,
            row["max_messages"],
            self._decode_time(row["creation_time"]),
            self._decode_time(row["use_time"]),
            row["message_count"],
            room_id,
            bool(row["shared"]),
        )

    async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Persist the current state of an outbound session (same as adding it)."""
        await self.add_outbound_group_session(session)

    async def remove_outbound_group_session(self, room_id: RoomID) -> None:
        """Delete the outbound Megolm session of one room."""
        await self.db.execute(
            "DELETE FROM crypto_megolm_outbound WHERE account_id = ? AND room_id = ?",
            (self.account_id, room_id),
        )
        await self.db.commit()

    async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
        """Delete the outbound Megolm sessions of several rooms with one commit."""
        await self.db.executemany(
            "DELETE FROM crypto_megolm_outbound WHERE account_id = ? AND room_id = ?",
            [(self.account_id, room_id) for room_id in rooms],
        )
        await self.db.commit()

    # Message index validation
    async def validate_message_index(
        self, sender_key: IdentityKey, session_id: SessionID, event_id: EventID, index: int, timestamp: int
    ) -> bool:
        """Check a Megolm message index against the stored record.

        If the index has been recorded before, return whether the recorded
        event ID and timestamp both equal the given ones. Otherwise record
        the index and return ``True``.
        """
        async with self.db.execute(
            "SELECT event_id, timestamp FROM crypto_message_index "
            "WHERE account_id = ? AND sender_key = ? AND session_id = ? AND idx = ?",
            (self.account_id, sender_key, session_id, index),
        ) as cur:
            row = await cur.fetchone()
        if row:
            return row["event_id"] == event_id and row["timestamp"] == timestamp

        await self.db.execute(
            "INSERT INTO crypto_message_index "
            "(account_id, sender_key, session_id, idx, event_id, timestamp) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (self.account_id, sender_key, session_id, index, event_id, timestamp),
        )
        await self.db.commit()
        return True

    # Devices
    async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
        """Return the stored devices of a user keyed by device ID.

        Returns ``None`` when no device rows exist for the user.
        NOTE(review): a user whose device list was stored as empty is
        therefore indistinguishable from a user that was never stored.
        """
        async with self.db.execute(
            "SELECT device_id, device FROM crypto_device WHERE account_id = ? AND user_id = ?",
            (self.account_id, user_id),
        ) as cur:
            rows = await cur.fetchall()
        if not rows:
            return None
        return {DeviceID(row["device_id"]): self._unpickle(row["device"]) for row in rows}

    async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
        """Return one stored device, or ``None`` if it is not stored."""
        async with self.db.execute(
            "SELECT device FROM crypto_device "
            "WHERE account_id = ? AND user_id = ? AND device_id = ?",
            (self.account_id, user_id, device_id),
        ) as cur:
            row = await cur.fetchone()
        return self._unpickle(row["device"]) if row else None

    async def find_device_by_key(self, user_id: UserID, identity_key: IdentityKey) -> DeviceIdentity | None:
        """Return the user's device whose ``identity_key`` matches, if any."""
        # The identity key lives inside the pickled blob, so the user's
        # devices have to be loaded and scanned in Python.
        devices = await self.get_devices(user_id) or {}
        return next(
            (device for device in devices.values() if device.identity_key == identity_key),
            None,
        )

    async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
        """Replace the stored device list of a user with ``devices``."""
        await self.db.execute(
            "DELETE FROM crypto_device WHERE account_id = ? AND user_id = ?",
            (self.account_id, user_id),
        )
        await self.db.executemany(
            "INSERT INTO crypto_device (account_id, user_id, device_id, device) VALUES (?, ?, ?, ?)",
            [
                (self.account_id, user_id, device_id, self._pickle(device))
                for device_id, device in devices.items()
            ],
        )
        await self.db.commit()

    async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
        """Return the users from ``users`` that have at least one stored device.

        Order and duplicates of the input list are preserved.
        """
        result = []
        for user_id in users:
            async with self.db.execute(
                "SELECT 1 FROM crypto_device WHERE account_id = ? AND user_id = ? LIMIT 1",
                (self.account_id, user_id),
            ) as cur:
                if await cur.fetchone():
                    result.append(user_id)
        return result

    # Cross-signing
    async def put_cross_signing_key(self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey) -> None:
        """Store a cross-signing key, keeping the first key ever stored for it.

        ``first_key`` is taken from the existing row when there is one and
        set to ``key`` otherwise.
        """
        async with self.db.execute(
            "SELECT first_key FROM crypto_cross_signing "
            "WHERE account_id = ? AND user_id = ? AND usage = ?",
            (self.account_id, user_id, usage.value),
        ) as cur:
            row = await cur.fetchone()
        first_key = row["first_key"] if row else key

        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_cross_signing "
            "(account_id, user_id, usage, key, first_key) VALUES (?, ?, ?, ?, ?)",
            (self.account_id, user_id, usage.value, key, first_key),
        )
        await self.db.commit()

    async def get_cross_signing_keys(self, user_id: UserID) -> dict[CrossSigningUsage, TOFUSigningKey]:
        """Return the stored cross-signing keys of a user keyed by usage."""
        async with self.db.execute(
            "SELECT usage, key, first_key FROM crypto_cross_signing "
            "WHERE account_id = ? AND user_id = ?",
            (self.account_id, user_id),
        ) as cur:
            rows = await cur.fetchall()
        return {
            CrossSigningUsage(row["usage"]): TOFUSigningKey(key=row["key"], first=row["first_key"])
            for row in rows
        }

    # Signatures
    async def put_signature(self, target: CrossSigner, signer: CrossSigner, signature: str) -> None:
        """Store (or replace) the signature made by ``signer`` on ``target``.

        Both signers are stored as their ``str()`` representation.
        """
        await self.db.execute(
            "INSERT OR REPLACE INTO crypto_signature "
            "(account_id, signer, target, signature) VALUES (?, ?, ?, ?)",
            (self.account_id, str(signer), str(target), signature),
        )
        await self.db.commit()

    async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
        """Return whether a signature by ``signer`` on ``target`` is stored."""
        async with self.db.execute(
            "SELECT 1 FROM crypto_signature "
            "WHERE account_id = ? AND signer = ? AND target = ? LIMIT 1",
            (self.account_id, str(signer), str(target)),
        ) as cur:
            return await cur.fetchone() is not None

    async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
        """Delete all signatures made by ``signer`` and return how many were deleted."""
        # rowcount of the DELETE itself replaces a separate COUNT(*) query.
        async with self.db.execute(
            "DELETE FROM crypto_signature WHERE account_id = ? AND signer = ?",
            (self.account_id, str(signer)),
        ) as cur:
            count = cur.rowcount
        await self.db.commit()
        return count