Improve DB streaming performance for historical DM decryption and statistics

This commit is contained in:
Jack Kingsman
2026-03-30 21:31:59 -07:00
parent 5c60559cb8
commit 43abcd07b2
8 changed files with 203 additions and 59 deletions

View File

@@ -122,20 +122,20 @@ async def run_historical_dm_decryption(
"""Background task to decrypt historical DM packets with contact's key."""
from app.websocket import broadcast_success
packets = await RawPacketRepository.get_undecrypted_text_messages()
total = len(packets)
total = 0
decrypted_count = 0
if total == 0:
logger.info("No undecrypted TEXT_MESSAGE packets to process")
return
logger.info("Starting historical DM decryption of %d TEXT_MESSAGE packets", total)
logger.info("Starting historical DM decryption scan for undecrypted TEXT_MESSAGE packets")
# Derive our public key from the private key
our_public_key_bytes = derive_public_key(private_key_bytes)
for packet_id, packet_data, packet_timestamp in packets:
async for (
packet_id,
packet_data,
packet_timestamp,
) in RawPacketRepository.stream_undecrypted_text_messages():
total += 1
# Note: passing our_public_key=None disables the outbound hash check in
# try_decrypt_dm (only the inbound check src_hash == their_first_byte runs).
# For the 255/256 case where our first byte differs from the contact's,
@@ -187,6 +187,10 @@ async def run_historical_dm_decryption(
if msg_id is not None:
decrypted_count += 1
if total == 0:
logger.info("No undecrypted TEXT_MESSAGE packets to process")
return
logger.info(
"Historical DM decryption complete: %d/%d packets decrypted",
decrypted_count,

View File

@@ -1,6 +1,7 @@
import logging
import sqlite3
import time
from collections.abc import AsyncIterator
from hashlib import sha256
from app.database import db
@@ -8,6 +9,8 @@ from app.decoder import PayloadType, extract_payload, get_packet_payload_type
logger = logging.getLogger(__name__)
UNDECRYPTED_PACKET_BATCH_SIZE = 500
class RawPacketRepository:
@staticmethod
@@ -100,6 +103,40 @@ class RawPacketRepository:
rows = await cursor.fetchall()
return [(row["id"], bytes(row["data"]), row["timestamp"]) for row in rows]
@staticmethod
async def stream_undecrypted_text_messages(
    batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
) -> AsyncIterator[tuple[int, bytes, int]]:
    """Stream undecrypted TEXT_MESSAGE packets without loading every row at once.

    Candidate rows (raw packets with no linked message) are fetched in batches
    of ``batch_size`` and filtered by payload type; matches are yielded as
    ``(id, data, timestamp)`` tuples in ascending timestamp order.
    """
    cursor = await db.conn.execute(
        "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC"
    )
    try:
        # Pull bounded batches so memory stays flat even on large backlogs.
        batch = await cursor.fetchmany(batch_size)
        while batch:
            for row in batch:
                raw = bytes(row["data"])
                if get_packet_payload_type(raw) == PayloadType.TEXT_MESSAGE:
                    yield (row["id"], raw, row["timestamp"])
            batch = await cursor.fetchmany(batch_size)
    finally:
        # Always release the cursor, even if the consumer abandons the stream.
        await cursor.close()
@staticmethod
async def count_undecrypted_text_messages(
    batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
) -> int:
    """Count undecrypted TEXT_MESSAGE packets without materializing them all.

    Delegates to the batched stream so only ``batch_size`` rows are resident
    at any moment; the stream itself applies the payload-type filter.
    """
    total = 0
    stream = RawPacketRepository.stream_undecrypted_text_messages(
        batch_size=batch_size
    )
    async for _ in stream:
        total += 1
    return total
@staticmethod
async def mark_decrypted(packet_id: int, message_id: int) -> None:
"""Link a raw packet to its decrypted message."""
@@ -158,17 +195,4 @@ class RawPacketRepository:
Filters raw packets to only include those with PayloadType.TEXT_MESSAGE (0x02).
These are direct messages that can be decrypted with contact ECDH keys.
"""
cursor = await db.conn.execute(
"SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC"
)
rows = await cursor.fetchall()
# Filter for TEXT_MESSAGE packets
result = []
for row in rows:
data = bytes(row["data"])
payload_type = get_packet_payload_type(data)
if payload_type == PayloadType.TEXT_MESSAGE:
result.append((row["id"], data, row["timestamp"]))
return result
return [packet async for packet in RawPacketRepository.stream_undecrypted_text_messages()]

View File

@@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)
SECONDS_1H = 3600
SECONDS_24H = 86400
SECONDS_7D = 604800
RAW_PACKET_STATS_BATCH_SIZE = 500
class AppSettingsRepository:
@@ -246,6 +247,26 @@ class AppSettingsRepository:
class StatisticsRepository:
@staticmethod
async def get_database_message_totals() -> dict[str, int]:
    """Return DM, channel-message, and outgoing totals in a single query.

    Intended for lightweight debug surfaces that only need these three
    counters and should not pay for the full statistics aggregation.
    """
    cursor = await db.conn.execute(
        """
        SELECT
            SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms,
            SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages,
            SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing
        FROM messages
        """
    )
    row = await cursor.fetchone()
    assert row is not None
    # SUM over an empty table yields NULL; coalesce each counter to 0.
    return {
        key: row[key] or 0
        for key in ("total_dms", "total_channel_messages", "total_outgoing")
    }
@staticmethod
async def _activity_counts(*, contact_type: int, exclude: bool = False) -> dict[str, int]:
"""Get time-windowed counts for contacts/repeaters heard."""
@@ -311,22 +332,26 @@ class StatisticsRepository:
"SELECT data FROM raw_packets WHERE timestamp >= ?",
(now - SECONDS_24H,),
)
rows = await cursor.fetchall()
single_byte = 0
double_byte = 0
triple_byte = 0
for row in rows:
envelope = parse_packet_envelope(bytes(row["data"]))
if envelope is None:
continue
if envelope.hash_size == 1:
single_byte += 1
elif envelope.hash_size == 2:
double_byte += 1
elif envelope.hash_size == 3:
triple_byte += 1
while True:
rows = await cursor.fetchmany(RAW_PACKET_STATS_BATCH_SIZE)
if not rows:
break
for row in rows:
envelope = parse_packet_envelope(bytes(row["data"]))
if envelope is None:
continue
if envelope.hash_size == 1:
single_byte += 1
elif envelope.hash_size == 2:
double_byte += 1
elif envelope.hash_size == 3:
triple_byte += 1
total_packets = single_byte + double_byte + triple_byte
if total_packets == 0:
@@ -409,22 +434,7 @@ class StatisticsRepository:
decrypted_packets = pkt_row["decrypted"] or 0
undecrypted_packets = total_packets - decrypted_packets
# Message type counts
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'PRIV'")
row = await cursor.fetchone()
assert row is not None
total_dms: int = row["cnt"]
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'CHAN'")
row = await cursor.fetchone()
assert row is not None
total_channel_messages: int = row["cnt"]
# Outgoing count
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE outgoing = 1")
row = await cursor.fetchone()
assert row is not None
total_outgoing: int = row["cnt"]
message_totals = await StatisticsRepository.get_database_message_totals()
# Activity windows
contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True)
@@ -440,9 +450,9 @@ class StatisticsRepository:
"total_packets": total_packets,
"decrypted_packets": decrypted_packets,
"undecrypted_packets": undecrypted_packets,
"total_dms": total_dms,
"total_channel_messages": total_channel_messages,
"total_outgoing": total_outgoing,
"total_dms": message_totals["total_dms"],
"total_channel_messages": message_totals["total_channel_messages"],
"total_outgoing": message_totals["total_outgoing"],
"contacts_heard": contacts_heard,
"repeaters_heard": repeaters_heard,
"known_channels_active": known_channels_active,

View File

@@ -265,7 +265,7 @@ async def _probe_radio() -> DebugRadioProbe:
async def debug_support_snapshot() -> DebugSnapshotResponse:
"""Return a support/debug snapshot with recent logs and live radio state."""
health_data = await build_health_data(radio_runtime.is_connected, radio_runtime.connection_info)
statistics = await StatisticsRepository.get_all()
message_totals = await StatisticsRepository.get_database_message_totals()
radio_probe = await _probe_radio()
channels_with_incoming_messages = (
await MessageRepository.count_channels_with_incoming_messages()
@@ -291,9 +291,9 @@ async def debug_support_snapshot() -> DebugSnapshotResponse:
},
),
database=DebugDatabaseInfo(
total_dms=statistics["total_dms"],
total_channel_messages=statistics["total_channel_messages"],
total_outgoing=statistics["total_outgoing"],
total_dms=message_totals["total_dms"],
total_channel_messages=message_totals["total_channel_messages"],
total_outgoing=message_totals["total_outgoing"],
),
radio_probe=radio_probe,
logs=[*LOG_COPY_BOUNDARY_PREFIX, *get_recent_log_lines(limit=1000)],

View File

@@ -210,8 +210,7 @@ async def decrypt_historical_packets(
except ValueError:
raise _bad_request("Invalid hex string for contact public key") from None
packets = await RawPacketRepository.get_undecrypted_text_messages()
count = len(packets)
count = await RawPacketRepository.count_undecrypted_text_messages()
if count == 0:
return DecryptResult(
started=False,

View File

@@ -190,6 +190,29 @@ class TestDebugEndpoint:
assert payload["database"]["total_channel_messages"] == 1
assert payload["database"]["total_outgoing"] == 1
@pytest.mark.asyncio
async def test_support_snapshot_uses_lightweight_message_totals(self, test_db, client):
    """Debug snapshot should not call the full statistics aggregation."""
    # Guard mock: blows up if the heavyweight aggregation path is touched.
    get_all_guard = AsyncMock(side_effect=AssertionError("get_all should not be called"))
    totals_stub = AsyncMock(
        return_value={
            "total_dms": 0,
            "total_channel_messages": 0,
            "total_outgoing": 0,
        }
    )
    with (
        patch("app.routers.debug.StatisticsRepository.get_all", new=get_all_guard),
        patch(
            "app.routers.debug.StatisticsRepository.get_database_message_totals",
            new=totals_stub,
        ),
    ):
        response = await client.get("/api/debug")
        assert response.status_code == 200
class TestRadioDisconnectedHandler:
"""Test that RadioDisconnectedError maps to 503."""

View File

@@ -5,7 +5,7 @@ undecrypted count endpoint, and the maintenance endpoint.
"""
import time
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
import pytest
@@ -305,6 +305,43 @@ class TestDecryptHistoricalPackets:
assert "key_type" in data["detail"].lower()
class TestUndecryptedTextPacketStreaming:
@pytest.mark.asyncio
async def test_count_undecrypted_text_messages_uses_batched_streaming(self, test_db):
    """Counting undecrypted DM packets should stream batches and filter by payload type."""

    class StubCursor:
        """Cursor double: serves fixed batches and forbids fetchall()."""

        def __init__(self):
            self.fetchall_called = False
            # Two DM packets (0x09) and one channel packet (0x15), then EOF.
            self._pending = [
                [
                    {"id": 1, "data": b"\x09\x00dm", "timestamp": 1000},
                    {"id": 2, "data": b"\x15\x00chan", "timestamp": 1001},
                ],
                [{"id": 3, "data": b"\x09\x00dm2", "timestamp": 1002}],
                [],
            ]

        async def fetchmany(self, size):
            assert size > 0
            return self._pending.pop(0)

        async def close(self):
            return None

        async def fetchall(self):
            self.fetchall_called = True
            raise AssertionError("fetchall() should not be used")

    stub = StubCursor()
    with patch.object(test_db.conn, "execute", new=AsyncMock(return_value=stub)):
        counted = await RawPacketRepository.count_undecrypted_text_messages(batch_size=2)
    assert stub.fetchall_called is False
    assert counted == 2
class TestRunHistoricalChannelDecryption:
"""Test the _run_historical_channel_decryption background task."""

View File

@@ -1,6 +1,7 @@
"""Tests for the statistics repository and endpoint."""
import time
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
@@ -349,6 +350,52 @@ class TestPathHashWidthStats:
assert breakdown["double_byte_pct"] == pytest.approx(100 / 3, rel=1e-3)
assert breakdown["triple_byte_pct"] == pytest.approx(100 / 3, rel=1e-3)
@pytest.mark.asyncio
async def test_path_hash_width_scan_uses_batched_fetchmany(self, test_db):
    """Hash-width stats should stream batches instead of calling fetchall()."""

    class StubCursor:
        """Cursor double: serves fixed batches and forbids fetchall()."""

        def __init__(self):
            self.fetchall_called = False
            self._pending = [
                [{"data": b"a"}, {"data": b"b"}],
                [{"data": b"c"}],
                [],
            ]

        async def fetchmany(self, size):
            assert size > 0
            return self._pending.pop(0)

        async def fetchall(self):
            self.fetchall_called = True
            raise AssertionError("fetchall() should not be used")

    stub = StubCursor()
    # Map each fake payload to a hash width; unknown payloads parse to None.
    width_by_payload = {b"a": 1, b"b": 2, b"c": 3}

    def fake_parse(raw_packet: bytes):
        width = width_by_payload.get(raw_packet)
        return None if width is None else SimpleNamespace(hash_size=width)

    with (
        patch.object(test_db.conn, "execute", new=AsyncMock(return_value=stub)),
        patch("app.repository.settings.parse_packet_envelope", side_effect=fake_parse),
    ):
        breakdown = await StatisticsRepository._path_hash_width_24h()
    assert stub.fetchall_called is False
    assert breakdown["total_packets"] == 3
    assert breakdown["single_byte"] == 1
    assert breakdown["double_byte"] == 1
    assert breakdown["triple_byte"] == 1
class TestStatisticsEndpoint:
@pytest.mark.asyncio