Improve DB streaming perf for cracking and statistics

This commit is contained in:
Jack Kingsman
2026-03-30 21:31:59 -07:00
parent 5c60559cb8
commit 43abcd07b2
8 changed files with 203 additions and 59 deletions
+38 -14
View File
@@ -1,6 +1,7 @@
import logging
import sqlite3
import time
from collections.abc import AsyncIterator
from hashlib import sha256
from app.database import db
@@ -8,6 +9,8 @@ from app.decoder import PayloadType, extract_payload, get_packet_payload_type
logger = logging.getLogger(__name__)
# Default number of rows requested per fetchmany() call when streaming
# undecrypted packets, so a scan holds one bounded batch of rows at a time.
UNDECRYPTED_PACKET_BATCH_SIZE = 500
class RawPacketRepository:
@staticmethod
@@ -100,6 +103,40 @@ class RawPacketRepository:
rows = await cursor.fetchall()
return [(row["id"], bytes(row["data"]), row["timestamp"]) for row in rows]
@staticmethod
async def stream_undecrypted_text_messages(
    batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
) -> AsyncIterator[tuple[int, bytes, int]]:
    """Lazily yield undecrypted TEXT_MESSAGE packets, oldest timestamp first.

    Raw packets that have no linked message (``message_id IS NULL``) are read
    from the cursor ``batch_size`` rows at a time. Rows whose payload type is
    not ``PayloadType.TEXT_MESSAGE`` are skipped.

    Args:
        batch_size: Number of rows requested per ``fetchmany`` call.

    Yields:
        ``(id, data, timestamp)`` for each matching packet, with ``data``
        converted to ``bytes``.
    """
    packet_cursor = await db.conn.execute(
        "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC"
    )
    try:
        # An empty batch signals that the result set is exhausted.
        while batch := await packet_cursor.fetchmany(batch_size):
            for record in batch:
                raw = bytes(record["data"])
                if get_packet_payload_type(raw) != PayloadType.TEXT_MESSAGE:
                    continue
                yield (record["id"], raw, record["timestamp"])
    finally:
        # Release the cursor on exhaustion, on error, and when the consumer
        # stops iterating early and the generator is closed.
        await packet_cursor.close()
@staticmethod
async def count_undecrypted_text_messages(
    batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
) -> int:
    """Return how many undecrypted TEXT_MESSAGE packets are stored.

    The packets are consumed one at a time from
    ``stream_undecrypted_text_messages`` and only a running tally is kept,
    so the full result set is never held in memory.

    Args:
        batch_size: Forwarded to the streaming query as its fetch size.

    Returns:
        The number of packets the stream yielded.
    """
    packet_stream = RawPacketRepository.stream_undecrypted_text_messages(
        batch_size=batch_size
    )
    total = 0
    async for _ in packet_stream:
        total += 1
    return total
@staticmethod
async def mark_decrypted(packet_id: int, message_id: int) -> None:
"""Link a raw packet to its decrypted message."""
@@ -158,17 +195,4 @@ class RawPacketRepository:
Filters raw packets to only include those with PayloadType.TEXT_MESSAGE (0x02).
These are direct messages that can be decrypted with contact ECDH keys.
"""
cursor = await db.conn.execute(
"SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC"
)
rows = await cursor.fetchall()
# Filter for TEXT_MESSAGE packets
result = []
for row in rows:
data = bytes(row["data"])
payload_type = get_packet_payload_type(data)
if payload_type == PayloadType.TEXT_MESSAGE:
result.append((row["id"], data, row["timestamp"]))
return result
return [packet async for packet in RawPacketRepository.stream_undecrypted_text_messages()]
+40 -30
View File
@@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)
# Time-window lengths in seconds: one hour, 24 hours, and seven days.
SECONDS_1H = 3600
SECONDS_24H = 86400
SECONDS_7D = 604800
# Number of raw-packet rows requested per fetchmany() call while tallying
# packet statistics, so the scan holds one bounded batch at a time.
RAW_PACKET_STATS_BATCH_SIZE = 500
class AppSettingsRepository:
@@ -246,6 +247,26 @@ class AppSettingsRepository:
class StatisticsRepository:
@staticmethod
async def get_database_message_totals() -> dict[str, int]:
    """Fetch the DM, channel-message, and outgoing-message totals in one query.

    A single SELECT over ``messages`` computes three conditional sums:
    rows with ``type = 'PRIV'``, rows with ``type = 'CHAN'``, and rows with
    ``outgoing = 1``.

    Returns:
        A dict with the keys ``total_dms``, ``total_channel_messages`` and
        ``total_outgoing``. A falsy aggregate value (such as NULL) is
        reported as 0.
    """
    totals_cursor = await db.conn.execute(
        """
SELECT
SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms,
SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages,
SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing
FROM messages
"""
    )
    totals_row = await totals_cursor.fetchone()
    # Narrows the Optional row type before it is indexed below.
    assert totals_row is not None
    return {
        column: totals_row[column] or 0
        for column in ("total_dms", "total_channel_messages", "total_outgoing")
    }
@staticmethod
async def _activity_counts(*, contact_type: int, exclude: bool = False) -> dict[str, int]:
"""Get time-windowed counts for contacts/repeaters heard."""
@@ -311,22 +332,26 @@ class StatisticsRepository:
"SELECT data FROM raw_packets WHERE timestamp >= ?",
(now - SECONDS_24H,),
)
rows = await cursor.fetchall()
single_byte = 0
double_byte = 0
triple_byte = 0
for row in rows:
envelope = parse_packet_envelope(bytes(row["data"]))
if envelope is None:
continue
if envelope.hash_size == 1:
single_byte += 1
elif envelope.hash_size == 2:
double_byte += 1
elif envelope.hash_size == 3:
triple_byte += 1
while True:
rows = await cursor.fetchmany(RAW_PACKET_STATS_BATCH_SIZE)
if not rows:
break
for row in rows:
envelope = parse_packet_envelope(bytes(row["data"]))
if envelope is None:
continue
if envelope.hash_size == 1:
single_byte += 1
elif envelope.hash_size == 2:
double_byte += 1
elif envelope.hash_size == 3:
triple_byte += 1
total_packets = single_byte + double_byte + triple_byte
if total_packets == 0:
@@ -409,22 +434,7 @@ class StatisticsRepository:
decrypted_packets = pkt_row["decrypted"] or 0
undecrypted_packets = total_packets - decrypted_packets
# Message type counts
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'PRIV'")
row = await cursor.fetchone()
assert row is not None
total_dms: int = row["cnt"]
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'CHAN'")
row = await cursor.fetchone()
assert row is not None
total_channel_messages: int = row["cnt"]
# Outgoing count
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE outgoing = 1")
row = await cursor.fetchone()
assert row is not None
total_outgoing: int = row["cnt"]
message_totals = await StatisticsRepository.get_database_message_totals()
# Activity windows
contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True)
@@ -440,9 +450,9 @@ class StatisticsRepository:
"total_packets": total_packets,
"decrypted_packets": decrypted_packets,
"undecrypted_packets": undecrypted_packets,
"total_dms": total_dms,
"total_channel_messages": total_channel_messages,
"total_outgoing": total_outgoing,
"total_dms": message_totals["total_dms"],
"total_channel_messages": message_totals["total_channel_messages"],
"total_outgoing": message_totals["total_outgoing"],
"contacts_heard": contacts_heard,
"repeaters_heard": repeaters_heard,
"known_channels_active": known_channels_active,