From 4a2d7ed10083ae37417566a85711af0cbb5536ea Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Wed, 1 Apr 2026 21:22:01 -0700 Subject: [PATCH] Move to FK pragma and prep other code points in light of that --- app/database.py | 16 +++- app/migrations.py | 164 +++++++++++++++++++++++++++++++++++++ app/packet_processor.py | 20 +++-- app/repository/messages.py | 3 + tests/test_api.py | 33 ++++++-- tests/test_migrations.py | 36 ++++---- 6 files changed, 238 insertions(+), 34 deletions(-) diff --git a/app/database.py b/app/database.py index 77f8897..8fba5bf 100644 --- a/app/database.py +++ b/app/database.py @@ -66,7 +66,7 @@ CREATE TABLE IF NOT EXISTS raw_packets ( data BLOB NOT NULL, message_id INTEGER, payload_hash BLOB, - FOREIGN KEY (message_id) REFERENCES messages(id) + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL ); CREATE TABLE IF NOT EXISTS contact_advert_paths ( @@ -78,7 +78,7 @@ CREATE TABLE IF NOT EXISTS contact_advert_paths ( last_seen INTEGER NOT NULL, heard_count INTEGER NOT NULL DEFAULT 1, UNIQUE(public_key, path_hex, path_len), - FOREIGN KEY (public_key) REFERENCES contacts(public_key) + FOREIGN KEY (public_key) REFERENCES contacts(public_key) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS contact_name_history ( @@ -88,7 +88,7 @@ CREATE TABLE IF NOT EXISTS contact_name_history ( first_seen INTEGER NOT NULL, last_seen INTEGER NOT NULL, UNIQUE(public_key, name), - FOREIGN KEY (public_key) REFERENCES contacts(public_key) + FOREIGN KEY (public_key) REFERENCES contacts(public_key) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_messages_received ON messages(received_at); @@ -132,6 +132,12 @@ class Database: # migration 20 handles the one-time VACUUM to restructure the file. await self._connection.execute("PRAGMA auto_vacuum = INCREMENTAL") + # Foreign key enforcement: must be set per-connection (not persisted). 
async def _rebuild_table_with_fk_049(
    conn: aiosqlite.Connection,
    table: str,
    column_defs: list,
    copy_cols: list,
    index_statements: list,
) -> None:
    """Replace *table* with a copy created from *column_defs*, keeping its rows.

    Args:
        conn: Open database connection.
        table: Name of the table to rebuild (trusted, hard-coded by the caller).
        column_defs: Column and constraint definitions (strings) for the new table.
        copy_cols: Column names (strings) copied verbatim from the old table.
        index_statements: ``CREATE INDEX`` statements (strings) to run after the
            rename, since dropping the old table also drops its indexes.
    """
    scratch = f"{table}_fk"
    cols_sql = ", ".join(column_defs)
    copy_sql = ", ".join(copy_cols)

    # An earlier run that died part-way can leave the scratch table behind;
    # without this the retry would fail with "table already exists" forever.
    await conn.execute(f"DROP TABLE IF EXISTS {scratch}")
    await conn.execute(f"CREATE TABLE {scratch} ({cols_sql})")
    await conn.execute(f"INSERT INTO {scratch} ({copy_sql}) SELECT {copy_sql} FROM {table}")
    await conn.execute(f"DROP TABLE {table}")
    await conn.execute(f"ALTER TABLE {scratch} RENAME TO {table}")
    for statement in index_statements:
        await conn.execute(statement)
    await conn.commit()


async def _migrate_049_foreign_key_cascade(conn: aiosqlite.Connection) -> None:
    """Rebuild FK tables with CASCADE/SET NULL and clean orphaned rows.

    SQLite cannot ALTER existing FK constraints, so each child table
    (``raw_packets``, ``contact_advert_paths``, ``contact_name_history``) is
    recreated with the new ``ON DELETE`` action and its rows copied across.

    Migrations are run with ``PRAGMA foreign_keys = OFF`` (see
    ``Database.connect``), so the copy itself is never rejected. Orphaned
    child rows are removed or NULLed first so that the data is already
    consistent when enforcement is switched on afterwards.

    For file-backed databases a one-time copy of the database file is written
    next to it (``<path>.pre-fk-migration.bak``) before anything is modified.
    An existing backup is kept rather than overwritten, so retrying after a
    partially applied run does not replace the true pre-migration state.

    Args:
        conn: Open database connection; every phase commits on this connection.
    """
    import shutil
    from pathlib import Path

    # Back up the database before table rebuilds (skip for in-memory DBs).
    cursor = await conn.execute("PRAGMA database_list")
    db_row = await cursor.fetchone()
    db_path = db_row[2] if db_row else ""
    if db_path and db_path != ":memory:" and Path(db_path).exists():
        backup_path = db_path + ".pre-fk-migration.bak"
        if Path(backup_path).exists():
            logger.info("Keeping existing pre-FK-migration backup at %s", backup_path)
        else:
            # In WAL mode committed pages can still live in the -wal file, so
            # copying only the main file could yield a stale backup. Flush the
            # WAL first; a checkpoint needs no write transaction to be open.
            await conn.commit()
            checkpoint = await conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
            await checkpoint.fetchall()
            # Copy to a temp name and rename so a crash mid-copy never leaves
            # a truncated file under the final backup name.
            partial_path = backup_path + ".tmp"
            shutil.copy2(db_path, partial_path)
            Path(partial_path).replace(backup_path)
            logger.info("Database backed up to %s before FK migration", backup_path)

    tables_cursor = await conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
    existing_tables = {row[0] for row in await tables_cursor.fetchall()}

    # Very old schemas may lack message_id and/or payload_hash on raw_packets,
    # so inspect the real column list once and reuse it below.
    raw_cols: list = []
    if "raw_packets" in existing_tables:
        col_cursor = await conn.execute("PRAGMA table_info(raw_packets)")
        raw_cols = [row[1] for row in await col_cursor.fetchall()]
    raw_has_message_id = "message_id" in raw_cols
    raw_has_payload_hash = "payload_hash" in raw_cols

    # --- Phase 1: clean orphans (guard each table's existence) ---
    if "contacts" in existing_tables:
        for child in ("contact_advert_paths", "contact_name_history"):
            if child in existing_tables:
                await conn.execute(
                    f"DELETE FROM {child} "
                    "WHERE public_key NOT IN (SELECT public_key FROM contacts)"
                )
    if raw_has_message_id and "messages" in existing_tables:
        await conn.execute(
            "UPDATE raw_packets SET message_id = NULL WHERE message_id IS NOT NULL "
            "AND message_id NOT IN (SELECT id FROM messages)"
        )
    await conn.commit()
    logger.debug("Cleaned orphaned child rows before FK rebuild")

    # --- Phase 2: rebuild raw_packets with ON DELETE SET NULL ---
    # Skipped when raw_packets has no message_id column to constrain.
    if raw_has_message_id:
        raw_col_defs = [
            "id INTEGER PRIMARY KEY AUTOINCREMENT",
            "timestamp INTEGER NOT NULL",
            "data BLOB NOT NULL",
            "message_id INTEGER",
        ]
        raw_copy_cols = ["id", "timestamp", "data", "message_id"]
        raw_indexes = [
            "CREATE INDEX IF NOT EXISTS idx_raw_packets_message_id ON raw_packets(message_id)",
            "CREATE INDEX IF NOT EXISTS idx_raw_packets_timestamp ON raw_packets(timestamp)",
        ]
        if raw_has_payload_hash:
            raw_col_defs.append("payload_hash BLOB")
            raw_copy_cols.append("payload_hash")
            raw_indexes.append(
                "CREATE UNIQUE INDEX IF NOT EXISTS idx_raw_packets_payload_hash ON raw_packets(payload_hash)"
            )
        # Table constraints must follow all column definitions.
        raw_col_defs.append("FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL")

        await _rebuild_table_with_fk_049(
            conn, "raw_packets", raw_col_defs, raw_copy_cols, raw_indexes
        )
        logger.debug("Rebuilt raw_packets with ON DELETE SET NULL")

    # --- Phase 3: rebuild contact_advert_paths with ON DELETE CASCADE ---
    if "contact_advert_paths" in existing_tables:
        await _rebuild_table_with_fk_049(
            conn,
            "contact_advert_paths",
            [
                "id INTEGER PRIMARY KEY AUTOINCREMENT",
                "public_key TEXT NOT NULL",
                "path_hex TEXT NOT NULL",
                "path_len INTEGER NOT NULL",
                "first_seen INTEGER NOT NULL",
                "last_seen INTEGER NOT NULL",
                "heard_count INTEGER NOT NULL DEFAULT 1",
                "UNIQUE(public_key, path_hex, path_len)",
                "FOREIGN KEY (public_key) REFERENCES contacts(public_key) ON DELETE CASCADE",
            ],
            [
                "id",
                "public_key",
                "path_hex",
                "path_len",
                "first_seen",
                "last_seen",
                "heard_count",
            ],
            [
                "CREATE INDEX IF NOT EXISTS idx_contact_advert_paths_recent "
                "ON contact_advert_paths(public_key, last_seen DESC)"
            ],
        )
        logger.debug("Rebuilt contact_advert_paths with ON DELETE CASCADE")

    # --- Phase 4: rebuild contact_name_history with ON DELETE CASCADE ---
    if "contact_name_history" in existing_tables:
        await _rebuild_table_with_fk_049(
            conn,
            "contact_name_history",
            [
                "id INTEGER PRIMARY KEY AUTOINCREMENT",
                "public_key TEXT NOT NULL",
                "name TEXT NOT NULL",
                "first_seen INTEGER NOT NULL",
                "last_seen INTEGER NOT NULL",
                "UNIQUE(public_key, name)",
                "FOREIGN KEY (public_key) REFERENCES contacts(public_key) ON DELETE CASCADE",
            ],
            ["id", "public_key", "name", "first_seen", "last_seen"],
            [
                "CREATE INDEX IF NOT EXISTS idx_contact_name_history_key "
                "ON contact_name_history(public_key, last_seen DESC)"
            ],
        )
        logger.debug("Rebuilt contact_name_history with ON DELETE CASCADE")
_process_advertisement( ) return - # Keep recent unique advert paths for all contacts. - await ContactAdvertPathRepository.record_observation( - public_key=advert.public_key.lower(), - path_hex=new_path_hex, - timestamp=timestamp, - max_paths=10, - hop_count=new_path_len, - ) - contact_upsert = ContactUpsert( public_key=advert.public_key.lower(), name=advert.name, @@ -496,7 +487,18 @@ async def _process_advertisement( first_seen=timestamp, # COALESCE in upsert preserves existing value ) + # Upsert the contact BEFORE recording advert paths so the parent row + # exists when foreign key enforcement is enabled. await ContactRepository.upsert(contact_upsert) + + # Keep recent unique advert paths for all contacts. + await ContactAdvertPathRepository.record_observation( + public_key=advert.public_key.lower(), + path_hex=new_path_hex, + timestamp=timestamp, + max_paths=10, + hop_count=new_path_len, + ) promoted_keys = await promote_prefix_contacts_for_contact( public_key=advert.public_key, log=logger, diff --git a/app/repository/messages.py b/app/repository/messages.py index 28e7d59..7c1c1aa 100644 --- a/app/repository/messages.py +++ b/app/repository/messages.py @@ -572,6 +572,9 @@ class MessageRepository: @staticmethod async def delete_by_id(message_id: int) -> None: """Delete a message row by ID.""" + await db.conn.execute( + "UPDATE raw_packets SET message_id = NULL WHERE message_id = ?", (message_id,) + ) await db.conn.execute("DELETE FROM messages WHERE id = ?", (message_id,)) await db.conn.commit() diff --git a/tests/test_api.py b/tests/test_api.py index d35da18..264655b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1167,7 +1167,14 @@ class TestRawPacketRepository: await RawPacketRepository.create(b"\x04\x05\x06", recent_timestamp) # Insert old but decrypted packet (should NOT be deleted) old_id, _ = await RawPacketRepository.create(b"\x07\x08\x09", old_timestamp) - await RawPacketRepository.mark_decrypted(old_id, 1) + msg_id = await 
MessageRepository.create( + msg_type="PRIV", + conversation_key="test_key", + text="test", + sender_timestamp=old_timestamp, + received_at=old_timestamp, + ) + await RawPacketRepository.mark_decrypted(old_id, msg_id) # Prune packets older than 10 days deleted = await RawPacketRepository.prune_old_undecrypted(10) @@ -1191,10 +1198,18 @@ class TestRawPacketRepository: async def test_purge_linked_to_messages_deletes_only_linked_packets(self, test_db): """Purge linked raw packets removes only rows with a message_id.""" ts = int(time.time()) + msg_id_1 = await MessageRepository.create( + msg_type="PRIV", conversation_key="k1", text="t1", + sender_timestamp=ts, received_at=ts, + ) + msg_id_2 = await MessageRepository.create( + msg_type="PRIV", conversation_key="k2", text="t2", + sender_timestamp=ts, received_at=ts, + ) linked_1, _ = await RawPacketRepository.create(b"\x01\x02\x03", ts) linked_2, _ = await RawPacketRepository.create(b"\x04\x05\x06", ts) - await RawPacketRepository.mark_decrypted(linked_1, 101) - await RawPacketRepository.mark_decrypted(linked_2, 102) + await RawPacketRepository.mark_decrypted(linked_1, msg_id_1) + await RawPacketRepository.mark_decrypted(linked_2, msg_id_2) await RawPacketRepository.create(b"\x07\x08\x09", ts) # undecrypted, should remain @@ -1232,10 +1247,18 @@ class TestMaintenanceEndpoint: from app.routers.packets import MaintenanceRequest, run_maintenance ts = int(time.time()) + msg_id_1 = await MessageRepository.create( + msg_type="PRIV", conversation_key="k1", text="t1", + sender_timestamp=ts, received_at=ts, + ) + msg_id_2 = await MessageRepository.create( + msg_type="PRIV", conversation_key="k2", text="t2", + sender_timestamp=ts, received_at=ts, + ) linked_1, _ = await RawPacketRepository.create(b"\x0a\x0b\x0c", ts) linked_2, _ = await RawPacketRepository.create(b"\x0d\x0e\x0f", ts) - await RawPacketRepository.mark_decrypted(linked_1, 201) - await RawPacketRepository.mark_decrypted(linked_2, 202) + await 
RawPacketRepository.mark_decrypted(linked_1, msg_id_1) + await RawPacketRepository.mark_decrypted(linked_2, msg_id_2) request = MaintenanceRequest(purge_linked_raw_packets=True) result = await run_maintenance(request) diff --git a/tests/test_migrations.py b/tests/test_migrations.py index c12c652..342acc1 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -513,7 +513,9 @@ class TestMigration018: from hashlib import sha256 assert bytes(rows[0]["payload_hash"]) == sha256(b"hash_a").digest() - assert rows[1]["message_id"] == 42 + # message_id=42 was orphaned (no matching messages row), so + # migration 49's orphan cleanup NULLs it out. + assert rows[1]["message_id"] is None # Verify payload_hash unique index still works cursor = await conn.execute( @@ -1247,8 +1249,8 @@ class TestMigration039: applied = await run_migrations(conn) - assert applied == 10 - assert await get_version(conn) == 48 + assert applied == 11 + assert await get_version(conn) == 49 cursor = await conn.execute( """ @@ -1319,8 +1321,8 @@ class TestMigration039: applied = await run_migrations(conn) - assert applied == 10 - assert await get_version(conn) == 48 + assert applied == 11 + assert await get_version(conn) == 49 cursor = await conn.execute( """ @@ -1386,8 +1388,8 @@ class TestMigration039: applied = await run_migrations(conn) - assert applied == 4 - assert await get_version(conn) == 48 + assert applied == 5 + assert await get_version(conn) == 49 cursor = await conn.execute( """ @@ -1439,8 +1441,8 @@ class TestMigration040: applied = await run_migrations(conn) - assert applied == 9 - assert await get_version(conn) == 48 + assert applied == 10 + assert await get_version(conn) == 49 await conn.execute( """ @@ -1501,8 +1503,8 @@ class TestMigration041: applied = await run_migrations(conn) - assert applied == 8 - assert await get_version(conn) == 48 + assert applied == 9 + assert await get_version(conn) == 49 await conn.execute( """ @@ -1554,8 +1556,8 @@ class TestMigration042: 
applied = await run_migrations(conn) - assert applied == 7 - assert await get_version(conn) == 48 + assert applied == 8 + assert await get_version(conn) == 49 await conn.execute( """ @@ -1694,8 +1696,8 @@ class TestMigration046: applied = await run_migrations(conn) - assert applied == 3 - assert await get_version(conn) == 48 + assert applied == 4 + assert await get_version(conn) == 49 cursor = await conn.execute( """ @@ -1788,8 +1790,8 @@ class TestMigration047: applied = await run_migrations(conn) - assert applied == 2 - assert await get_version(conn) == 48 + assert applied == 3 + assert await get_version(conn) == 49 cursor = await conn.execute( """