Enable the SQLite foreign_keys PRAGMA and update dependent code paths (deletion order, migration, tests) accordingly

This commit is contained in:
Jack Kingsman
2026-04-01 21:22:01 -07:00
parent 47c4f038fe
commit 4a2d7ed100
6 changed files with 238 additions and 34 deletions

View File

@@ -66,7 +66,7 @@ CREATE TABLE IF NOT EXISTS raw_packets (
data BLOB NOT NULL,
message_id INTEGER,
payload_hash BLOB,
FOREIGN KEY (message_id) REFERENCES messages(id)
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL
);
CREATE TABLE IF NOT EXISTS contact_advert_paths (
@@ -78,7 +78,7 @@ CREATE TABLE IF NOT EXISTS contact_advert_paths (
last_seen INTEGER NOT NULL,
heard_count INTEGER NOT NULL DEFAULT 1,
UNIQUE(public_key, path_hex, path_len),
FOREIGN KEY (public_key) REFERENCES contacts(public_key)
FOREIGN KEY (public_key) REFERENCES contacts(public_key) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS contact_name_history (
@@ -88,7 +88,7 @@ CREATE TABLE IF NOT EXISTS contact_name_history (
first_seen INTEGER NOT NULL,
last_seen INTEGER NOT NULL,
UNIQUE(public_key, name),
FOREIGN KEY (public_key) REFERENCES contacts(public_key)
FOREIGN KEY (public_key) REFERENCES contacts(public_key) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_messages_received ON messages(received_at);
@@ -132,6 +132,12 @@ class Database:
# migration 20 handles the one-time VACUUM to restructure the file.
await self._connection.execute("PRAGMA auto_vacuum = INCREMENTAL")
# Foreign key enforcement: must be set per-connection (not persisted).
# Disabled during schema init and migrations to avoid issues with
# historical table-rebuild migrations that may temporarily violate
# constraints, then re-enabled for all subsequent application queries.
await self._connection.execute("PRAGMA foreign_keys = OFF")
await self._connection.executescript(SCHEMA)
await self._connection.commit()
logger.debug("Database schema initialized")
@@ -141,6 +147,10 @@ class Database:
await run_migrations(self._connection)
# Enable FK enforcement for all application queries from this point on.
await self._connection.execute("PRAGMA foreign_keys = ON")
logger.debug("Foreign key enforcement enabled")
async def disconnect(self) -> None:
if self._connection:
await self._connection.close()

View File

@@ -374,6 +374,14 @@ async def run_migrations(conn: aiosqlite.Connection) -> int:
await set_version(conn, 48)
applied += 1
# Migration 49: Enable foreign key enforcement — rebuild tables with
# CASCADE / SET NULL and clean up any orphaned rows first.
if version < 49:
logger.info("Applying migration 49: add foreign key cascade/set-null and clean orphans")
await _migrate_049_foreign_key_cascade(conn)
await set_version(conn, 49)
applied += 1
if applied > 0:
logger.info(
"Applied %d migration(s), schema now at version %d", applied, await get_version(conn)
@@ -2938,3 +2946,159 @@ async def _migrate_048_discovery_blocked_types(conn: aiosqlite.Connection) -> No
else:
raise
await conn.commit()
async def _migrate_049_foreign_key_cascade(conn: aiosqlite.Connection) -> None:
    """Rebuild FK tables with CASCADE/SET NULL and clean orphaned rows.

    SQLite cannot ALTER existing FK constraints, so each table is rebuilt.
    Orphaned child rows are cleaned up before the rebuild to ensure the
    INSERT...SELECT into the new table (which has enforced FKs) succeeds.

    Phases:
      1. Back up the on-disk database file (best effort, skipped in-memory).
      2. Delete/NULL orphaned child rows that would violate the new FKs.
      3. Rebuild ``raw_packets`` with ``ON DELETE SET NULL`` on ``message_id``.
      4. Rebuild ``contact_advert_paths`` with ``ON DELETE CASCADE``.
      5. Rebuild ``contact_name_history`` with ``ON DELETE CASCADE``.

    Each phase guards on table/column existence so the migration also works
    on very old schemas that predate some tables/columns. Each rebuild ends
    with its own ``commit`` so a later failure does not roll back earlier,
    already-consistent rebuilds.

    Args:
        conn: Open aiosqlite connection. Caller is expected to run this with
            ``PRAGMA foreign_keys = OFF`` so the DROP/RENAME sequence cannot
            trip FK enforcement mid-rebuild.
    """
    # Local imports keep these helpers out of module scope; they are only
    # needed by this one-shot migration.
    import shutil
    from pathlib import Path

    # Back up the database before table rebuilds (skip for in-memory DBs).
    # PRAGMA database_list rows are (seq, name, file); index 2 is the file
    # path, empty for in-memory databases.
    cursor = await conn.execute("PRAGMA database_list")
    db_row = await cursor.fetchone()
    db_path = db_row[2] if db_row else ""
    if db_path and db_path != ":memory:" and Path(db_path).exists():
        backup_path = db_path + ".pre-fk-migration.bak"
        # copy2 preserves file metadata (mtime etc.) along with contents.
        # NOTE(review): this copies while the connection is open; any
        # uncommitted WAL content may not be in the backup — confirm this is
        # acceptable for a pre-migration safety copy.
        shutil.copy2(db_path, backup_path)
        logger.info("Database backed up to %s before FK migration", backup_path)

    # --- Phase 1: clean orphans (guard each table's existence) ---
    tables_cursor = await conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'"
    )
    existing_tables = {row[0] for row in await tables_cursor.fetchall()}

    if "contact_advert_paths" in existing_tables and "contacts" in existing_tables:
        # Child rows whose parent contact no longer exists would fail the
        # CASCADE FK on insert into the rebuilt table; delete them.
        await conn.execute(
            "DELETE FROM contact_advert_paths "
            "WHERE public_key NOT IN (SELECT public_key FROM contacts)"
        )
    if "contact_name_history" in existing_tables and "contacts" in existing_tables:
        await conn.execute(
            "DELETE FROM contact_name_history "
            "WHERE public_key NOT IN (SELECT public_key FROM contacts)"
        )
    if "raw_packets" in existing_tables and "messages" in existing_tables:
        # Guard: message_id column may not exist on very old schemas
        col_cursor = await conn.execute("PRAGMA table_info(raw_packets)")
        # PRAGMA table_info rows are (cid, name, type, ...); index 1 is name.
        raw_cols = {row[1] for row in await col_cursor.fetchall()}
        if "message_id" in raw_cols:
            # raw_packets keeps the packet but drops the dangling link,
            # matching the new ON DELETE SET NULL semantics.
            await conn.execute(
                "UPDATE raw_packets SET message_id = NULL WHERE message_id IS NOT NULL "
                "AND message_id NOT IN (SELECT id FROM messages)"
            )
    await conn.commit()
    logger.debug("Cleaned orphaned child rows before FK rebuild")

    # --- Phase 2: rebuild raw_packets with ON DELETE SET NULL ---
    # Skip if raw_packets doesn't have message_id (pre-migration-18 schema)
    raw_has_message_id = False
    if "raw_packets" in existing_tables:
        col_cursor2 = await conn.execute("PRAGMA table_info(raw_packets)")
        raw_has_message_id = "message_id" in {row[1] for row in await col_cursor2.fetchall()}
    if raw_has_message_id:
        # Dynamically build column list based on what the old table actually has,
        # since very old schemas may lack payload_hash (added in migration 28).
        col_cursor3 = await conn.execute("PRAGMA table_info(raw_packets)")
        old_cols = [row[1] for row in await col_cursor3.fetchall()]
        new_col_defs = [
            "id INTEGER PRIMARY KEY AUTOINCREMENT",
            "timestamp INTEGER NOT NULL",
            "data BLOB NOT NULL",
            "message_id INTEGER",
        ]
        copy_cols = ["id", "timestamp", "data", "message_id"]
        if "payload_hash" in old_cols:
            new_col_defs.append("payload_hash BLOB")
            copy_cols.append("payload_hash")
        new_col_defs.append(
            "FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL"
        )
        cols_sql = ", ".join(new_col_defs)
        copy_sql = ", ".join(copy_cols)
        await conn.execute(f"CREATE TABLE raw_packets_fk ({cols_sql})")
        # Copying explicit ids preserves existing row identity across the
        # rebuild (message links were already NULLed/validated in Phase 1).
        await conn.execute(
            f"INSERT INTO raw_packets_fk ({copy_sql}) SELECT {copy_sql} FROM raw_packets"
        )
        await conn.execute("DROP TABLE raw_packets")
        await conn.execute("ALTER TABLE raw_packets_fk RENAME TO raw_packets")
        # Indexes do not survive the DROP; recreate them on the new table.
        await conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_raw_packets_message_id ON raw_packets(message_id)"
        )
        await conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_raw_packets_timestamp ON raw_packets(timestamp)"
        )
        if "payload_hash" in old_cols:
            await conn.execute(
                "CREATE UNIQUE INDEX IF NOT EXISTS idx_raw_packets_payload_hash ON raw_packets(payload_hash)"
            )
        await conn.commit()
        logger.debug("Rebuilt raw_packets with ON DELETE SET NULL")

    # --- Phase 3: rebuild contact_advert_paths with ON DELETE CASCADE ---
    if "contact_advert_paths" in existing_tables:
        await conn.execute(
            """
            CREATE TABLE contact_advert_paths_fk (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                public_key TEXT NOT NULL,
                path_hex TEXT NOT NULL,
                path_len INTEGER NOT NULL,
                first_seen INTEGER NOT NULL,
                last_seen INTEGER NOT NULL,
                heard_count INTEGER NOT NULL DEFAULT 1,
                UNIQUE(public_key, path_hex, path_len),
                FOREIGN KEY (public_key) REFERENCES contacts(public_key) ON DELETE CASCADE
            )
            """
        )
        await conn.execute(
            "INSERT INTO contact_advert_paths_fk (id, public_key, path_hex, path_len, first_seen, last_seen, heard_count) "
            "SELECT id, public_key, path_hex, path_len, first_seen, last_seen, heard_count FROM contact_advert_paths"
        )
        await conn.execute("DROP TABLE contact_advert_paths")
        await conn.execute("ALTER TABLE contact_advert_paths_fk RENAME TO contact_advert_paths")
        await conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_contact_advert_paths_recent "
            "ON contact_advert_paths(public_key, last_seen DESC)"
        )
        await conn.commit()
        logger.debug("Rebuilt contact_advert_paths with ON DELETE CASCADE")

    # --- Phase 4: rebuild contact_name_history with ON DELETE CASCADE ---
    if "contact_name_history" in existing_tables:
        await conn.execute(
            """
            CREATE TABLE contact_name_history_fk (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                public_key TEXT NOT NULL,
                name TEXT NOT NULL,
                first_seen INTEGER NOT NULL,
                last_seen INTEGER NOT NULL,
                UNIQUE(public_key, name),
                FOREIGN KEY (public_key) REFERENCES contacts(public_key) ON DELETE CASCADE
            )
            """
        )
        await conn.execute(
            "INSERT INTO contact_name_history_fk (id, public_key, name, first_seen, last_seen) "
            "SELECT id, public_key, name, first_seen, last_seen FROM contact_name_history"
        )
        await conn.execute("DROP TABLE contact_name_history")
        await conn.execute("ALTER TABLE contact_name_history_fk RENAME TO contact_name_history")
        await conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_contact_name_history_key "
            "ON contact_name_history(public_key, last_seen DESC)"
        )
        await conn.commit()
        logger.debug("Rebuilt contact_name_history with ON DELETE CASCADE")

View File

@@ -476,15 +476,6 @@ async def _process_advertisement(
)
return
# Keep recent unique advert paths for all contacts.
await ContactAdvertPathRepository.record_observation(
public_key=advert.public_key.lower(),
path_hex=new_path_hex,
timestamp=timestamp,
max_paths=10,
hop_count=new_path_len,
)
contact_upsert = ContactUpsert(
public_key=advert.public_key.lower(),
name=advert.name,
@@ -496,7 +487,18 @@ async def _process_advertisement(
first_seen=timestamp, # COALESCE in upsert preserves existing value
)
# Upsert the contact BEFORE recording advert paths so the parent row
# exists when foreign key enforcement is enabled.
await ContactRepository.upsert(contact_upsert)
# Keep recent unique advert paths for all contacts.
await ContactAdvertPathRepository.record_observation(
public_key=advert.public_key.lower(),
path_hex=new_path_hex,
timestamp=timestamp,
max_paths=10,
hop_count=new_path_len,
)
promoted_keys = await promote_prefix_contacts_for_contact(
public_key=advert.public_key,
log=logger,

View File

@@ -572,6 +572,9 @@ class MessageRepository:
@staticmethod
async def delete_by_id(message_id: int) -> None:
    """Delete a message row by ID.

    Any raw_packets rows referencing the message have their ``message_id``
    explicitly set to NULL first. NOTE(review): this looks redundant with
    the schema's ``ON DELETE SET NULL`` FK, presumably kept so deletes stay
    safe on connections where ``PRAGMA foreign_keys`` is off — confirm.

    Args:
        message_id: Primary key of the messages row to delete.
    """
    # Detach packet links before deleting the parent row.
    await db.conn.execute(
        "UPDATE raw_packets SET message_id = NULL WHERE message_id = ?", (message_id,)
    )
    await db.conn.execute("DELETE FROM messages WHERE id = ?", (message_id,))
    await db.conn.commit()

View File

@@ -1167,7 +1167,14 @@ class TestRawPacketRepository:
await RawPacketRepository.create(b"\x04\x05\x06", recent_timestamp)
# Insert old but decrypted packet (should NOT be deleted)
old_id, _ = await RawPacketRepository.create(b"\x07\x08\x09", old_timestamp)
await RawPacketRepository.mark_decrypted(old_id, 1)
msg_id = await MessageRepository.create(
msg_type="PRIV",
conversation_key="test_key",
text="test",
sender_timestamp=old_timestamp,
received_at=old_timestamp,
)
await RawPacketRepository.mark_decrypted(old_id, msg_id)
# Prune packets older than 10 days
deleted = await RawPacketRepository.prune_old_undecrypted(10)
@@ -1191,10 +1198,18 @@ class TestRawPacketRepository:
async def test_purge_linked_to_messages_deletes_only_linked_packets(self, test_db):
"""Purge linked raw packets removes only rows with a message_id."""
ts = int(time.time())
msg_id_1 = await MessageRepository.create(
msg_type="PRIV", conversation_key="k1", text="t1",
sender_timestamp=ts, received_at=ts,
)
msg_id_2 = await MessageRepository.create(
msg_type="PRIV", conversation_key="k2", text="t2",
sender_timestamp=ts, received_at=ts,
)
linked_1, _ = await RawPacketRepository.create(b"\x01\x02\x03", ts)
linked_2, _ = await RawPacketRepository.create(b"\x04\x05\x06", ts)
await RawPacketRepository.mark_decrypted(linked_1, 101)
await RawPacketRepository.mark_decrypted(linked_2, 102)
await RawPacketRepository.mark_decrypted(linked_1, msg_id_1)
await RawPacketRepository.mark_decrypted(linked_2, msg_id_2)
await RawPacketRepository.create(b"\x07\x08\x09", ts) # undecrypted, should remain
@@ -1232,10 +1247,18 @@ class TestMaintenanceEndpoint:
from app.routers.packets import MaintenanceRequest, run_maintenance
ts = int(time.time())
msg_id_1 = await MessageRepository.create(
msg_type="PRIV", conversation_key="k1", text="t1",
sender_timestamp=ts, received_at=ts,
)
msg_id_2 = await MessageRepository.create(
msg_type="PRIV", conversation_key="k2", text="t2",
sender_timestamp=ts, received_at=ts,
)
linked_1, _ = await RawPacketRepository.create(b"\x0a\x0b\x0c", ts)
linked_2, _ = await RawPacketRepository.create(b"\x0d\x0e\x0f", ts)
await RawPacketRepository.mark_decrypted(linked_1, 201)
await RawPacketRepository.mark_decrypted(linked_2, 202)
await RawPacketRepository.mark_decrypted(linked_1, msg_id_1)
await RawPacketRepository.mark_decrypted(linked_2, msg_id_2)
request = MaintenanceRequest(purge_linked_raw_packets=True)
result = await run_maintenance(request)

View File

@@ -513,7 +513,9 @@ class TestMigration018:
from hashlib import sha256
assert bytes(rows[0]["payload_hash"]) == sha256(b"hash_a").digest()
assert rows[1]["message_id"] == 42
# message_id=42 was orphaned (no matching messages row), so
# migration 49's orphan cleanup NULLs it out.
assert rows[1]["message_id"] is None
# Verify payload_hash unique index still works
cursor = await conn.execute(
@@ -1247,8 +1249,8 @@ class TestMigration039:
applied = await run_migrations(conn)
assert applied == 10
assert await get_version(conn) == 48
assert applied == 11
assert await get_version(conn) == 49
cursor = await conn.execute(
"""
@@ -1319,8 +1321,8 @@ class TestMigration039:
applied = await run_migrations(conn)
assert applied == 10
assert await get_version(conn) == 48
assert applied == 11
assert await get_version(conn) == 49
cursor = await conn.execute(
"""
@@ -1386,8 +1388,8 @@ class TestMigration039:
applied = await run_migrations(conn)
assert applied == 4
assert await get_version(conn) == 48
assert applied == 5
assert await get_version(conn) == 49
cursor = await conn.execute(
"""
@@ -1439,8 +1441,8 @@ class TestMigration040:
applied = await run_migrations(conn)
assert applied == 9
assert await get_version(conn) == 48
assert applied == 10
assert await get_version(conn) == 49
await conn.execute(
"""
@@ -1501,8 +1503,8 @@ class TestMigration041:
applied = await run_migrations(conn)
assert applied == 8
assert await get_version(conn) == 48
assert applied == 9
assert await get_version(conn) == 49
await conn.execute(
"""
@@ -1554,8 +1556,8 @@ class TestMigration042:
applied = await run_migrations(conn)
assert applied == 7
assert await get_version(conn) == 48
assert applied == 8
assert await get_version(conn) == 49
await conn.execute(
"""
@@ -1694,8 +1696,8 @@ class TestMigration046:
applied = await run_migrations(conn)
assert applied == 3
assert await get_version(conn) == 48
assert applied == 4
assert await get_version(conn) == 49
cursor = await conn.execute(
"""
@@ -1788,8 +1790,8 @@ class TestMigration047:
applied = await run_migrations(conn)
assert applied == 2
assert await get_version(conn) == 48
assert applied == 3
assert await get_version(conn) == 49
cursor = await conn.execute(
"""