# Remote-Terminal-for-MeshCore/app/repository/raw_packets.py
import logging
import time
from collections.abc import AsyncIterator
from hashlib import sha256
from app.database import db
from app.decoder import PayloadType, extract_payload, get_packet_payload_type
# Module-level logger, named after this module per project convention.
logger = logging.getLogger(__name__)
# Rows fetched per SELECT when scanning undecrypted packets (keyset pagination
# batch size used by the streaming helpers below).
UNDECRYPTED_PACKET_BATCH_SIZE = 500
class RawPacketRepository:
    @staticmethod
    async def create(data: bytes, timestamp: int | None = None) -> tuple[int, bool]:
        """
        Create a raw packet with payload-based deduplication.

        Returns a ``(packet_id, is_new)`` tuple:
        - ``is_new=True``: new packet stored, ``packet_id`` is the new row ID
        - ``is_new=False``: duplicate payload detected, ``packet_id`` is the
          existing row ID

        Deduplication keys on the SHA-256 hash of the packet payload
        (excluding routing/path information).
        """
        when = int(time.time()) if timestamp is None else timestamp
        # Hash the extracted payload when possible; for malformed packets
        # (no extractable payload) fall back to hashing the whole frame.
        extracted = extract_payload(data)
        digest = sha256(extracted if extracted else data).digest()
        async with db.tx() as conn:
            async with conn.execute(
                "INSERT OR IGNORE INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)",
                (when, data, digest),
            ) as cursor:
                inserted = cursor.rowcount
                new_id = cursor.lastrowid
            if inserted > 0:
                assert new_id is not None
                return (new_id, True)
            # Duplicate payload — look up the existing row (same transaction).
            async with conn.execute(
                "SELECT id FROM raw_packets WHERE payload_hash = ?", (digest,)
            ) as cursor:
                existing = await cursor.fetchone()
            assert existing is not None
            return (existing["id"], False)
@staticmethod
async def get_undecrypted_count() -> int:
"""Get count of undecrypted packets (those without a linked message)."""
async with db.readonly() as conn:
async with conn.execute(
"SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL"
) as cursor:
row = await cursor.fetchone()
return row["count"] if row else 0
@staticmethod
async def get_oldest_undecrypted() -> int | None:
"""Get timestamp of oldest undecrypted packet, or None if none exist."""
async with db.readonly() as conn:
async with conn.execute(
"SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL"
) as cursor:
row = await cursor.fetchone()
return row["oldest"] if row and row["oldest"] is not None else None
@staticmethod
async def _stream_undecrypted_rows(
batch_size: int,
) -> AsyncIterator[tuple[int, bytes, int]]:
"""Internal: keyset-paginated scan of every undecrypted raw packet.
Yields ``(id, data, timestamp)`` for each row across all batches.
Lock is acquired per batch only — concurrent writes can interleave
at batch boundaries rather than being blocked for the full scan.
Each batch opens a fresh cursor and consumes it fully with
``fetchall()`` before releasing, so no prepared statement is alive
at a yield boundary.
``last_id`` advances per row, not per yield, so external filters
(see ``stream_undecrypted_text_messages``) that drop rows do not
cause a re-scan of skipped IDs.
"""
last_id = -1
while True:
async with db.readonly() as conn:
async with conn.execute(
"SELECT id, data, timestamp FROM raw_packets "
"WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?",
(last_id, batch_size),
) as cursor:
rows = await cursor.fetchall()
if not rows:
return
for row in rows:
last_id = row["id"]
yield (row["id"], bytes(row["data"]), row["timestamp"])
@staticmethod
async def stream_all_undecrypted(
batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
) -> AsyncIterator[tuple[int, bytes, int]]:
"""Yield all undecrypted packets as (id, data, timestamp) in bounded batches."""
async for row in RawPacketRepository._stream_undecrypted_rows(batch_size):
yield row
@staticmethod
async def stream_undecrypted_text_messages(
batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
) -> AsyncIterator[tuple[int, bytes, int]]:
"""Yield undecrypted TEXT_MESSAGE packets in bounded-size batches.
Filters the shared scan to rows whose payload parses as a text
message. Non-matching rows still advance the keyset cursor so they
aren't re-fetched on subsequent batches.
"""
async for packet_id, data, timestamp in RawPacketRepository._stream_undecrypted_rows(
batch_size
):
if get_packet_payload_type(data) == PayloadType.TEXT_MESSAGE:
yield (packet_id, data, timestamp)
@staticmethod
async def count_undecrypted_text_messages(
batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
) -> int:
"""Count undecrypted TEXT_MESSAGE packets without materializing them all."""
count = 0
async for _packet in RawPacketRepository.stream_undecrypted_text_messages(
batch_size=batch_size
):
count += 1
return count
@staticmethod
async def mark_decrypted(packet_id: int, message_id: int) -> None:
"""Link a raw packet to its decrypted message."""
async with db.tx() as conn:
async with conn.execute(
"UPDATE raw_packets SET message_id = ? WHERE id = ?",
(message_id, packet_id),
):
pass
@staticmethod
async def get_linked_message_id(packet_id: int) -> int | None:
"""Return the linked message ID for a raw packet, if any."""
async with db.readonly() as conn:
async with conn.execute(
"SELECT message_id FROM raw_packets WHERE id = ?",
(packet_id,),
) as cursor:
row = await cursor.fetchone()
if not row:
return None
return row["message_id"]
@staticmethod
async def get_by_id(packet_id: int) -> tuple[int, bytes, int, int | None] | None:
"""Return a raw packet row as (id, data, timestamp, message_id)."""
async with db.readonly() as conn:
async with conn.execute(
"SELECT id, data, timestamp, message_id FROM raw_packets WHERE id = ?",
(packet_id,),
) as cursor:
row = await cursor.fetchone()
if not row:
return None
return (row["id"], bytes(row["data"]), row["timestamp"], row["message_id"])
@staticmethod
async def prune_old_undecrypted(max_age_days: int) -> int:
"""Delete undecrypted packets older than max_age_days. Returns count deleted."""
cutoff = int(time.time()) - (max_age_days * 86400)
async with db.tx() as conn:
async with conn.execute(
"DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?",
(cutoff,),
) as cursor:
rowcount = cursor.rowcount
return rowcount
@staticmethod
async def purge_linked_to_messages() -> int:
"""Delete raw packets that are already linked to a stored message."""
async with db.tx() as conn:
async with conn.execute(
"DELETE FROM raw_packets WHERE message_id IS NOT NULL"
) as cursor:
rowcount = cursor.rowcount
return rowcount