From 43abcd07b20bf48dde7d5131a63a2560ea31db4d Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Mon, 30 Mar 2026 21:31:59 -0700 Subject: [PATCH] Improve DB streaming perf for cracking and statistics --- app/packet_processor.py | 20 ++++++---- app/repository/raw_packets.py | 52 +++++++++++++++++++------- app/repository/settings.py | 70 ++++++++++++++++++++--------------- app/routers/debug.py | 8 ++-- app/routers/packets.py | 3 +- tests/test_api.py | 23 ++++++++++++ tests/test_packets_router.py | 39 ++++++++++++++++++- tests/test_statistics.py | 47 +++++++++++++++++++++++ 8 files changed, 203 insertions(+), 59 deletions(-) diff --git a/app/packet_processor.py b/app/packet_processor.py index 32ddd9f..8a43618 100644 --- a/app/packet_processor.py +++ b/app/packet_processor.py @@ -122,20 +122,20 @@ async def run_historical_dm_decryption( """Background task to decrypt historical DM packets with contact's key.""" from app.websocket import broadcast_success - packets = await RawPacketRepository.get_undecrypted_text_messages() - total = len(packets) + total = 0 decrypted_count = 0 - if total == 0: - logger.info("No undecrypted TEXT_MESSAGE packets to process") - return - - logger.info("Starting historical DM decryption of %d TEXT_MESSAGE packets", total) + logger.info("Starting historical DM decryption scan for undecrypted TEXT_MESSAGE packets") # Derive our public key from the private key our_public_key_bytes = derive_public_key(private_key_bytes) - for packet_id, packet_data, packet_timestamp in packets: + async for ( + packet_id, + packet_data, + packet_timestamp, + ) in RawPacketRepository.stream_undecrypted_text_messages(): + total += 1 # Note: passing our_public_key=None disables the outbound hash check in # try_decrypt_dm (only the inbound check src_hash == their_first_byte runs). # For the 255/256 case where our first byte differs from the contact's, @@ -187,6 +187,10 @@ async def run_historical_dm_decryption( if msg_id is not None: decrypted_count += 1 + if total == 0: + logger.info("No undecrypted TEXT_MESSAGE packets to process") + return + logger.info( "Historical DM decryption complete: %d/%d packets decrypted", decrypted_count, diff --git a/app/repository/raw_packets.py b/app/repository/raw_packets.py index c773a67..29ead2c 100644 --- a/app/repository/raw_packets.py +++ b/app/repository/raw_packets.py @@ -1,6 +1,7 @@ import logging import sqlite3 import time +from collections.abc import AsyncIterator from hashlib import sha256 from app.database import db @@ -8,6 +9,8 @@ from app.decoder import PayloadType, extract_payload, get_packet_payload_type logger = logging.getLogger(__name__) +UNDECRYPTED_PACKET_BATCH_SIZE = 500 + class RawPacketRepository: @staticmethod @@ -100,6 +103,40 @@ class RawPacketRepository: rows = await cursor.fetchall() return [(row["id"], bytes(row["data"]), row["timestamp"]) for row in rows] + @staticmethod + async def stream_undecrypted_text_messages( + batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE, + ) -> AsyncIterator[tuple[int, bytes, int]]: + """Yield undecrypted TEXT_MESSAGE packets in bounded-size batches.""" + cursor = await db.conn.execute( + "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC" + ) + try: + while True: + rows = await cursor.fetchmany(batch_size) + if not rows: + break + + for row in rows: + data = bytes(row["data"]) + payload_type = get_packet_payload_type(data) + if payload_type == PayloadType.TEXT_MESSAGE: + yield (row["id"], data, row["timestamp"]) + finally: + await cursor.close() + + @staticmethod + async def count_undecrypted_text_messages( + batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE, + ) -> int: + """Count undecrypted TEXT_MESSAGE packets without materializing them all.""" + count = 0 + async for _packet in RawPacketRepository.stream_undecrypted_text_messages( + batch_size=batch_size + ): + count += 1 + return count + @staticmethod async def mark_decrypted(packet_id: int, message_id: int) -> None: """Link a raw packet to its decrypted message.""" @@ -158,17 +195,4 @@ class RawPacketRepository: Filters raw packets to only include those with PayloadType.TEXT_MESSAGE (0x02). These are direct messages that can be decrypted with contact ECDH keys. """ - cursor = await db.conn.execute( - "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC" - ) - rows = await cursor.fetchall() - - # Filter for TEXT_MESSAGE packets - result = [] - for row in rows: - data = bytes(row["data"]) - payload_type = get_packet_payload_type(data) - if payload_type == PayloadType.TEXT_MESSAGE: - result.append((row["id"], data, row["timestamp"])) - - return result + return [packet async for packet in RawPacketRepository.stream_undecrypted_text_messages()] diff --git a/app/repository/settings.py b/app/repository/settings.py index 23c41ce..2428feb 100644 --- a/app/repository/settings.py +++ b/app/repository/settings.py @@ -12,6 +12,7 @@ logger = logging.getLogger(__name__) SECONDS_1H = 3600 SECONDS_24H = 86400 SECONDS_7D = 604800 +RAW_PACKET_STATS_BATCH_SIZE = 500 class AppSettingsRepository: @@ -246,6 +247,26 @@ class AppSettingsRepository: class StatisticsRepository: + @staticmethod + async def get_database_message_totals() -> dict[str, int]: + """Return message totals needed by lightweight debug surfaces.""" + cursor = await db.conn.execute( + """ + SELECT + SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms, + SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages, + SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing + FROM messages + """ + ) + row = await cursor.fetchone() + assert row is not None + return { + "total_dms": row["total_dms"] or 0, + "total_channel_messages": row["total_channel_messages"] or 0, + "total_outgoing": row["total_outgoing"] or 0, + } + @staticmethod async def _activity_counts(*, contact_type: int, exclude: bool = False) -> dict[str, int]: """Get time-windowed counts for contacts/repeaters heard.""" @@ -311,22 +332,26 @@ class StatisticsRepository: "SELECT data FROM raw_packets WHERE timestamp >= ?", (now - SECONDS_24H,), ) - rows = await cursor.fetchall() single_byte = 0 double_byte = 0 triple_byte = 0 - for row in rows: - envelope = parse_packet_envelope(bytes(row["data"])) - if envelope is None: - continue - if envelope.hash_size == 1: - single_byte += 1 - elif envelope.hash_size == 2: - double_byte += 1 - elif envelope.hash_size == 3: - triple_byte += 1 + while True: + rows = await cursor.fetchmany(RAW_PACKET_STATS_BATCH_SIZE) + if not rows: + break + + for row in rows: + envelope = parse_packet_envelope(bytes(row["data"])) + if envelope is None: + continue + if envelope.hash_size == 1: + single_byte += 1 + elif envelope.hash_size == 2: + double_byte += 1 + elif envelope.hash_size == 3: + triple_byte += 1 total_packets = single_byte + double_byte + triple_byte if total_packets == 0: @@ -409,22 +434,7 @@ class StatisticsRepository: decrypted_packets = pkt_row["decrypted"] or 0 undecrypted_packets = total_packets - decrypted_packets - # Message type counts - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'PRIV'") - row = await cursor.fetchone() - assert row is not None - total_dms: int = row["cnt"] - - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'CHAN'") - row = await cursor.fetchone() - assert row is not None - total_channel_messages: int = row["cnt"] - - # Outgoing count - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE outgoing = 1") - row = await cursor.fetchone() - assert row is not None - total_outgoing: int = row["cnt"] + message_totals = await StatisticsRepository.get_database_message_totals() # Activity windows contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True) @@ -440,9 +450,9 @@ class StatisticsRepository: "total_packets": total_packets, "decrypted_packets": decrypted_packets, "undecrypted_packets": undecrypted_packets, - "total_dms": total_dms, - "total_channel_messages": total_channel_messages, - "total_outgoing": total_outgoing, + "total_dms": message_totals["total_dms"], + "total_channel_messages": message_totals["total_channel_messages"], + "total_outgoing": message_totals["total_outgoing"], "contacts_heard": contacts_heard, "repeaters_heard": repeaters_heard, "known_channels_active": known_channels_active, diff --git a/app/routers/debug.py b/app/routers/debug.py index 55c1f65..f7573e9 100644 --- a/app/routers/debug.py +++ b/app/routers/debug.py @@ -265,7 +265,7 @@ async def _probe_radio() -> DebugRadioProbe: async def debug_support_snapshot() -> DebugSnapshotResponse: """Return a support/debug snapshot with recent logs and live radio state.""" health_data = await build_health_data(radio_runtime.is_connected, radio_runtime.connection_info) - statistics = await StatisticsRepository.get_all() + message_totals = await StatisticsRepository.get_database_message_totals() radio_probe = await _probe_radio() channels_with_incoming_messages = ( await MessageRepository.count_channels_with_incoming_messages() @@ -291,9 +291,9 @@ async def debug_support_snapshot() -> DebugSnapshotResponse: }, ), database=DebugDatabaseInfo( - total_dms=statistics["total_dms"], - total_channel_messages=statistics["total_channel_messages"], - total_outgoing=statistics["total_outgoing"], + total_dms=message_totals["total_dms"], + total_channel_messages=message_totals["total_channel_messages"], + total_outgoing=message_totals["total_outgoing"], ), radio_probe=radio_probe, logs=[*LOG_COPY_BOUNDARY_PREFIX, *get_recent_log_lines(limit=1000)], diff --git a/app/routers/packets.py b/app/routers/packets.py index 4c6374c..40475e4 100644 --- a/app/routers/packets.py +++ b/app/routers/packets.py @@ -210,8 +210,7 @@ async def decrypt_historical_packets( except ValueError: raise _bad_request("Invalid hex string for contact public key") from None - packets = await RawPacketRepository.get_undecrypted_text_messages() - count = len(packets) + count = await RawPacketRepository.count_undecrypted_text_messages() if count == 0: return DecryptResult( started=False, diff --git a/tests/test_api.py b/tests/test_api.py index 574b968..4cf90f8 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -190,6 +190,29 @@ class TestDebugEndpoint: assert payload["database"]["total_channel_messages"] == 1 assert payload["database"]["total_outgoing"] == 1 + @pytest.mark.asyncio + async def test_support_snapshot_uses_lightweight_message_totals(self, test_db, client): + """Debug snapshot should not call the full statistics aggregation.""" + with ( + patch( + "app.routers.debug.StatisticsRepository.get_all", + new=AsyncMock(side_effect=AssertionError("get_all should not be called")), + ), + patch( + "app.routers.debug.StatisticsRepository.get_database_message_totals", + new=AsyncMock( + return_value={ + "total_dms": 0, + "total_channel_messages": 0, + "total_outgoing": 0, + } + ), + ), + ): + response = await client.get("/api/debug") + + assert response.status_code == 200 + class TestRadioDisconnectedHandler: """Test that RadioDisconnectedError maps to 503.""" diff --git a/tests/test_packets_router.py b/tests/test_packets_router.py index 339b5a5..d9d1243 100644 --- a/tests/test_packets_router.py +++ b/tests/test_packets_router.py @@ -5,7 +5,7 @@ undecrypted count endpoint, and the maintenance endpoint. """ import time -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest @@ -305,6 +305,43 @@ class TestDecryptHistoricalPackets: assert "key_type" in data["detail"].lower() +class TestUndecryptedTextPacketStreaming: + @pytest.mark.asyncio + async def test_count_undecrypted_text_messages_uses_batched_streaming(self, test_db): + """Counting undecrypted DM packets should stream batches and filter by payload type.""" + + class FakeCursor: + def __init__(self): + self._batches = [ + [ + {"id": 1, "data": b"\x09\x00dm", "timestamp": 1000}, + {"id": 2, "data": b"\x15\x00chan", "timestamp": 1001}, + ], + [{"id": 3, "data": b"\x09\x00dm2", "timestamp": 1002}], + [], + ] + self.fetchall_called = False + + async def fetchmany(self, size): + assert size > 0 + return self._batches.pop(0) + + async def close(self): + return None + + async def fetchall(self): + self.fetchall_called = True + raise AssertionError("fetchall() should not be used") + + fake_cursor = FakeCursor() + + with patch.object(test_db.conn, "execute", new=AsyncMock(return_value=fake_cursor)): + count = await RawPacketRepository.count_undecrypted_text_messages(batch_size=2) + + assert fake_cursor.fetchall_called is False + assert count == 2 + + class TestRunHistoricalChannelDecryption: """Test the _run_historical_channel_decryption background task.""" diff --git a/tests/test_statistics.py b/tests/test_statistics.py index 6a50531..79d5ff2 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -1,6 +1,7 @@ """Tests for the statistics repository and endpoint.""" import time +from types import SimpleNamespace from unittest.mock import AsyncMock, patch import pytest @@ -349,6 +350,52 @@ class TestPathHashWidthStats: assert breakdown["double_byte_pct"] == pytest.approx(100 / 3, rel=1e-3) assert breakdown["triple_byte_pct"] == pytest.approx(100 / 3, rel=1e-3) + @pytest.mark.asyncio + async def test_path_hash_width_scan_uses_batched_fetchmany(self, test_db): + """Hash-width stats should stream batches instead of calling fetchall().""" + + class FakeCursor: + def __init__(self): + self._batches = [ + [{"data": b"a"}, {"data": b"b"}], + [{"data": b"c"}], + [], + ] + self.fetchall_called = False + + async def fetchmany(self, size): + assert size > 0 + return self._batches.pop(0) + + async def fetchall(self): + self.fetchall_called = True + raise AssertionError("fetchall() should not be used") + + fake_cursor = FakeCursor() + + def fake_parse(raw_packet: bytes): + hash_sizes = { + b"a": 1, + b"b": 2, + b"c": 3, + } + hash_size = hash_sizes.get(raw_packet) + if hash_size is None: + return None + return SimpleNamespace(hash_size=hash_size) + + with ( + patch.object(test_db.conn, "execute", new=AsyncMock(return_value=fake_cursor)), + patch("app.repository.settings.parse_packet_envelope", side_effect=fake_parse), + ): + breakdown = await StatisticsRepository._path_hash_width_24h() + + assert fake_cursor.fetchall_called is False + assert breakdown["total_packets"] == 3 + assert breakdown["single_byte"] == 1 + assert breakdown["double_byte"] == 1 + assert breakdown["triple_byte"] == 1 + class TestStatisticsEndpoint: @pytest.mark.asyncio