mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-05-01 02:53:00 +02:00
Improve DB streaming perf for cracking and statistics
This commit is contained in:
@@ -122,20 +122,20 @@ async def run_historical_dm_decryption(
|
||||
"""Background task to decrypt historical DM packets with contact's key."""
|
||||
from app.websocket import broadcast_success
|
||||
|
||||
packets = await RawPacketRepository.get_undecrypted_text_messages()
|
||||
total = len(packets)
|
||||
total = 0
|
||||
decrypted_count = 0
|
||||
|
||||
if total == 0:
|
||||
logger.info("No undecrypted TEXT_MESSAGE packets to process")
|
||||
return
|
||||
|
||||
logger.info("Starting historical DM decryption of %d TEXT_MESSAGE packets", total)
|
||||
logger.info("Starting historical DM decryption scan for undecrypted TEXT_MESSAGE packets")
|
||||
|
||||
# Derive our public key from the private key
|
||||
our_public_key_bytes = derive_public_key(private_key_bytes)
|
||||
|
||||
for packet_id, packet_data, packet_timestamp in packets:
|
||||
async for (
|
||||
packet_id,
|
||||
packet_data,
|
||||
packet_timestamp,
|
||||
) in RawPacketRepository.stream_undecrypted_text_messages():
|
||||
total += 1
|
||||
# Note: passing our_public_key=None disables the outbound hash check in
|
||||
# try_decrypt_dm (only the inbound check src_hash == their_first_byte runs).
|
||||
# For the 255/256 case where our first byte differs from the contact's,
|
||||
@@ -187,6 +187,10 @@ async def run_historical_dm_decryption(
|
||||
if msg_id is not None:
|
||||
decrypted_count += 1
|
||||
|
||||
if total == 0:
|
||||
logger.info("No undecrypted TEXT_MESSAGE packets to process")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Historical DM decryption complete: %d/%d packets decrypted",
|
||||
decrypted_count,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from hashlib import sha256
|
||||
|
||||
from app.database import db
|
||||
@@ -8,6 +9,8 @@ from app.decoder import PayloadType, extract_payload, get_packet_payload_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UNDECRYPTED_PACKET_BATCH_SIZE = 500
|
||||
|
||||
|
||||
class RawPacketRepository:
|
||||
@staticmethod
|
||||
@@ -100,6 +103,40 @@ class RawPacketRepository:
|
||||
rows = await cursor.fetchall()
|
||||
return [(row["id"], bytes(row["data"]), row["timestamp"]) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
async def stream_undecrypted_text_messages(
|
||||
batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
|
||||
) -> AsyncIterator[tuple[int, bytes, int]]:
|
||||
"""Yield undecrypted TEXT_MESSAGE packets in bounded-size batches."""
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC"
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
rows = await cursor.fetchmany(batch_size)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
data = bytes(row["data"])
|
||||
payload_type = get_packet_payload_type(data)
|
||||
if payload_type == PayloadType.TEXT_MESSAGE:
|
||||
yield (row["id"], data, row["timestamp"])
|
||||
finally:
|
||||
await cursor.close()
|
||||
|
||||
@staticmethod
|
||||
async def count_undecrypted_text_messages(
|
||||
batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
|
||||
) -> int:
|
||||
"""Count undecrypted TEXT_MESSAGE packets without materializing them all."""
|
||||
count = 0
|
||||
async for _packet in RawPacketRepository.stream_undecrypted_text_messages(
|
||||
batch_size=batch_size
|
||||
):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
async def mark_decrypted(packet_id: int, message_id: int) -> None:
|
||||
"""Link a raw packet to its decrypted message."""
|
||||
@@ -158,17 +195,4 @@ class RawPacketRepository:
|
||||
Filters raw packets to only include those with PayloadType.TEXT_MESSAGE (0x02).
|
||||
These are direct messages that can be decrypted with contact ECDH keys.
|
||||
"""
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC"
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
# Filter for TEXT_MESSAGE packets
|
||||
result = []
|
||||
for row in rows:
|
||||
data = bytes(row["data"])
|
||||
payload_type = get_packet_payload_type(data)
|
||||
if payload_type == PayloadType.TEXT_MESSAGE:
|
||||
result.append((row["id"], data, row["timestamp"]))
|
||||
|
||||
return result
|
||||
return [packet async for packet in RawPacketRepository.stream_undecrypted_text_messages()]
|
||||
|
||||
@@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
SECONDS_1H = 3600
|
||||
SECONDS_24H = 86400
|
||||
SECONDS_7D = 604800
|
||||
RAW_PACKET_STATS_BATCH_SIZE = 500
|
||||
|
||||
|
||||
class AppSettingsRepository:
|
||||
@@ -246,6 +247,26 @@ class AppSettingsRepository:
|
||||
|
||||
|
||||
class StatisticsRepository:
|
||||
@staticmethod
|
||||
async def get_database_message_totals() -> dict[str, int]:
|
||||
"""Return message totals needed by lightweight debug surfaces."""
|
||||
cursor = await db.conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms,
|
||||
SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages,
|
||||
SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing
|
||||
FROM messages
|
||||
"""
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
return {
|
||||
"total_dms": row["total_dms"] or 0,
|
||||
"total_channel_messages": row["total_channel_messages"] or 0,
|
||||
"total_outgoing": row["total_outgoing"] or 0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _activity_counts(*, contact_type: int, exclude: bool = False) -> dict[str, int]:
|
||||
"""Get time-windowed counts for contacts/repeaters heard."""
|
||||
@@ -311,22 +332,26 @@ class StatisticsRepository:
|
||||
"SELECT data FROM raw_packets WHERE timestamp >= ?",
|
||||
(now - SECONDS_24H,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
single_byte = 0
|
||||
double_byte = 0
|
||||
triple_byte = 0
|
||||
|
||||
for row in rows:
|
||||
envelope = parse_packet_envelope(bytes(row["data"]))
|
||||
if envelope is None:
|
||||
continue
|
||||
if envelope.hash_size == 1:
|
||||
single_byte += 1
|
||||
elif envelope.hash_size == 2:
|
||||
double_byte += 1
|
||||
elif envelope.hash_size == 3:
|
||||
triple_byte += 1
|
||||
while True:
|
||||
rows = await cursor.fetchmany(RAW_PACKET_STATS_BATCH_SIZE)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
envelope = parse_packet_envelope(bytes(row["data"]))
|
||||
if envelope is None:
|
||||
continue
|
||||
if envelope.hash_size == 1:
|
||||
single_byte += 1
|
||||
elif envelope.hash_size == 2:
|
||||
double_byte += 1
|
||||
elif envelope.hash_size == 3:
|
||||
triple_byte += 1
|
||||
|
||||
total_packets = single_byte + double_byte + triple_byte
|
||||
if total_packets == 0:
|
||||
@@ -409,22 +434,7 @@ class StatisticsRepository:
|
||||
decrypted_packets = pkt_row["decrypted"] or 0
|
||||
undecrypted_packets = total_packets - decrypted_packets
|
||||
|
||||
# Message type counts
|
||||
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'PRIV'")
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
total_dms: int = row["cnt"]
|
||||
|
||||
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'CHAN'")
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
total_channel_messages: int = row["cnt"]
|
||||
|
||||
# Outgoing count
|
||||
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE outgoing = 1")
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
total_outgoing: int = row["cnt"]
|
||||
message_totals = await StatisticsRepository.get_database_message_totals()
|
||||
|
||||
# Activity windows
|
||||
contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True)
|
||||
@@ -440,9 +450,9 @@ class StatisticsRepository:
|
||||
"total_packets": total_packets,
|
||||
"decrypted_packets": decrypted_packets,
|
||||
"undecrypted_packets": undecrypted_packets,
|
||||
"total_dms": total_dms,
|
||||
"total_channel_messages": total_channel_messages,
|
||||
"total_outgoing": total_outgoing,
|
||||
"total_dms": message_totals["total_dms"],
|
||||
"total_channel_messages": message_totals["total_channel_messages"],
|
||||
"total_outgoing": message_totals["total_outgoing"],
|
||||
"contacts_heard": contacts_heard,
|
||||
"repeaters_heard": repeaters_heard,
|
||||
"known_channels_active": known_channels_active,
|
||||
|
||||
@@ -265,7 +265,7 @@ async def _probe_radio() -> DebugRadioProbe:
|
||||
async def debug_support_snapshot() -> DebugSnapshotResponse:
|
||||
"""Return a support/debug snapshot with recent logs and live radio state."""
|
||||
health_data = await build_health_data(radio_runtime.is_connected, radio_runtime.connection_info)
|
||||
statistics = await StatisticsRepository.get_all()
|
||||
message_totals = await StatisticsRepository.get_database_message_totals()
|
||||
radio_probe = await _probe_radio()
|
||||
channels_with_incoming_messages = (
|
||||
await MessageRepository.count_channels_with_incoming_messages()
|
||||
@@ -291,9 +291,9 @@ async def debug_support_snapshot() -> DebugSnapshotResponse:
|
||||
},
|
||||
),
|
||||
database=DebugDatabaseInfo(
|
||||
total_dms=statistics["total_dms"],
|
||||
total_channel_messages=statistics["total_channel_messages"],
|
||||
total_outgoing=statistics["total_outgoing"],
|
||||
total_dms=message_totals["total_dms"],
|
||||
total_channel_messages=message_totals["total_channel_messages"],
|
||||
total_outgoing=message_totals["total_outgoing"],
|
||||
),
|
||||
radio_probe=radio_probe,
|
||||
logs=[*LOG_COPY_BOUNDARY_PREFIX, *get_recent_log_lines(limit=1000)],
|
||||
|
||||
@@ -210,8 +210,7 @@ async def decrypt_historical_packets(
|
||||
except ValueError:
|
||||
raise _bad_request("Invalid hex string for contact public key") from None
|
||||
|
||||
packets = await RawPacketRepository.get_undecrypted_text_messages()
|
||||
count = len(packets)
|
||||
count = await RawPacketRepository.count_undecrypted_text_messages()
|
||||
if count == 0:
|
||||
return DecryptResult(
|
||||
started=False,
|
||||
|
||||
@@ -190,6 +190,29 @@ class TestDebugEndpoint:
|
||||
assert payload["database"]["total_channel_messages"] == 1
|
||||
assert payload["database"]["total_outgoing"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_support_snapshot_uses_lightweight_message_totals(self, test_db, client):
|
||||
"""Debug snapshot should not call the full statistics aggregation."""
|
||||
with (
|
||||
patch(
|
||||
"app.routers.debug.StatisticsRepository.get_all",
|
||||
new=AsyncMock(side_effect=AssertionError("get_all should not be called")),
|
||||
),
|
||||
patch(
|
||||
"app.routers.debug.StatisticsRepository.get_database_message_totals",
|
||||
new=AsyncMock(
|
||||
return_value={
|
||||
"total_dms": 0,
|
||||
"total_channel_messages": 0,
|
||||
"total_outgoing": 0,
|
||||
}
|
||||
),
|
||||
),
|
||||
):
|
||||
response = await client.get("/api/debug")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestRadioDisconnectedHandler:
|
||||
"""Test that RadioDisconnectedError maps to 503."""
|
||||
|
||||
@@ -5,7 +5,7 @@ undecrypted count endpoint, and the maintenance endpoint.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -305,6 +305,43 @@ class TestDecryptHistoricalPackets:
|
||||
assert "key_type" in data["detail"].lower()
|
||||
|
||||
|
||||
class TestUndecryptedTextPacketStreaming:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_undecrypted_text_messages_uses_batched_streaming(self, test_db):
|
||||
"""Counting undecrypted DM packets should stream batches and filter by payload type."""
|
||||
|
||||
class FakeCursor:
|
||||
def __init__(self):
|
||||
self._batches = [
|
||||
[
|
||||
{"id": 1, "data": b"\x09\x00dm", "timestamp": 1000},
|
||||
{"id": 2, "data": b"\x15\x00chan", "timestamp": 1001},
|
||||
],
|
||||
[{"id": 3, "data": b"\x09\x00dm2", "timestamp": 1002}],
|
||||
[],
|
||||
]
|
||||
self.fetchall_called = False
|
||||
|
||||
async def fetchmany(self, size):
|
||||
assert size > 0
|
||||
return self._batches.pop(0)
|
||||
|
||||
async def close(self):
|
||||
return None
|
||||
|
||||
async def fetchall(self):
|
||||
self.fetchall_called = True
|
||||
raise AssertionError("fetchall() should not be used")
|
||||
|
||||
fake_cursor = FakeCursor()
|
||||
|
||||
with patch.object(test_db.conn, "execute", new=AsyncMock(return_value=fake_cursor)):
|
||||
count = await RawPacketRepository.count_undecrypted_text_messages(batch_size=2)
|
||||
|
||||
assert fake_cursor.fetchall_called is False
|
||||
assert count == 2
|
||||
|
||||
|
||||
class TestRunHistoricalChannelDecryption:
|
||||
"""Test the _run_historical_channel_decryption background task."""
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for the statistics repository and endpoint."""
|
||||
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -349,6 +350,52 @@ class TestPathHashWidthStats:
|
||||
assert breakdown["double_byte_pct"] == pytest.approx(100 / 3, rel=1e-3)
|
||||
assert breakdown["triple_byte_pct"] == pytest.approx(100 / 3, rel=1e-3)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_hash_width_scan_uses_batched_fetchmany(self, test_db):
|
||||
"""Hash-width stats should stream batches instead of calling fetchall()."""
|
||||
|
||||
class FakeCursor:
|
||||
def __init__(self):
|
||||
self._batches = [
|
||||
[{"data": b"a"}, {"data": b"b"}],
|
||||
[{"data": b"c"}],
|
||||
[],
|
||||
]
|
||||
self.fetchall_called = False
|
||||
|
||||
async def fetchmany(self, size):
|
||||
assert size > 0
|
||||
return self._batches.pop(0)
|
||||
|
||||
async def fetchall(self):
|
||||
self.fetchall_called = True
|
||||
raise AssertionError("fetchall() should not be used")
|
||||
|
||||
fake_cursor = FakeCursor()
|
||||
|
||||
def fake_parse(raw_packet: bytes):
|
||||
hash_sizes = {
|
||||
b"a": 1,
|
||||
b"b": 2,
|
||||
b"c": 3,
|
||||
}
|
||||
hash_size = hash_sizes.get(raw_packet)
|
||||
if hash_size is None:
|
||||
return None
|
||||
return SimpleNamespace(hash_size=hash_size)
|
||||
|
||||
with (
|
||||
patch.object(test_db.conn, "execute", new=AsyncMock(return_value=fake_cursor)),
|
||||
patch("app.repository.settings.parse_packet_envelope", side_effect=fake_parse),
|
||||
):
|
||||
breakdown = await StatisticsRepository._path_hash_width_24h()
|
||||
|
||||
assert fake_cursor.fetchall_called is False
|
||||
assert breakdown["total_packets"] == 3
|
||||
assert breakdown["single_byte"] == 1
|
||||
assert breakdown["double_byte"] == 1
|
||||
assert breakdown["triple_byte"] == 1
|
||||
|
||||
|
||||
class TestStatisticsEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user