import logging import sqlite3 import time from hashlib import sha256 from app.database import db from app.decoder import PayloadType, extract_payload, get_packet_payload_type logger = logging.getLogger(__name__) class RawPacketRepository: @staticmethod async def create(data: bytes, timestamp: int | None = None) -> tuple[int, bool]: """ Create a raw packet with payload-based deduplication. Returns (packet_id, is_new) tuple: - is_new=True: New packet stored, packet_id is the new row ID - is_new=False: Duplicate payload detected, packet_id is the existing row ID Deduplication is based on the SHA-256 hash of the packet payload (excluding routing/path information). """ ts = timestamp if timestamp is not None else int(time.time()) # Compute payload hash for deduplication payload = extract_payload(data) if payload: payload_hash = sha256(payload).digest() else: # For malformed packets, hash the full data payload_hash = sha256(data).digest() # Check if this payload already exists cursor = await db.conn.execute( "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,) ) existing = await cursor.fetchone() if existing: # Duplicate - return existing packet ID logger.debug( "Duplicate payload detected (hash=%s..., existing_id=%d)", payload_hash.hex()[:12], existing["id"], ) return (existing["id"], False) # New packet - insert with hash try: cursor = await db.conn.execute( "INSERT INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)", (ts, data, payload_hash), ) await db.conn.commit() assert cursor.lastrowid is not None # INSERT always returns a row ID return (cursor.lastrowid, True) except sqlite3.IntegrityError: # Race condition: another insert with same payload_hash happened between # our SELECT and INSERT. This is expected for duplicate packets arriving # close together. Query again to get the existing ID. logger.debug( "Duplicate packet detected via race condition (payload_hash=%s), dropping", payload_hash.hex()[:16], ) cursor = await db.conn.execute( "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,) ) existing = await cursor.fetchone() if existing: return (existing["id"], False) # This shouldn't happen, but if it does, re-raise raise @staticmethod async def get_undecrypted_count() -> int: """Get count of undecrypted packets (those without a linked message).""" cursor = await db.conn.execute( "SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL" ) row = await cursor.fetchone() return row["count"] if row else 0 @staticmethod async def get_oldest_undecrypted() -> int | None: """Get timestamp of oldest undecrypted packet, or None if none exist.""" cursor = await db.conn.execute( "SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL" ) row = await cursor.fetchone() return row["oldest"] if row and row["oldest"] is not None else None @staticmethod async def get_all_undecrypted() -> list[tuple[int, bytes, int]]: """Get all undecrypted packets as (id, data, timestamp) tuples.""" cursor = await db.conn.execute( "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC" ) rows = await cursor.fetchall() return [(row["id"], bytes(row["data"]), row["timestamp"]) for row in rows] @staticmethod async def mark_decrypted(packet_id: int, message_id: int) -> None: """Link a raw packet to its decrypted message.""" await db.conn.execute( "UPDATE raw_packets SET message_id = ? WHERE id = ?", (message_id, packet_id), ) await db.conn.commit() @staticmethod async def prune_old_undecrypted(max_age_days: int) -> int: """Delete undecrypted packets older than max_age_days. Returns count deleted.""" cutoff = int(time.time()) - (max_age_days * 86400) cursor = await db.conn.execute( "DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?", (cutoff,), ) await db.conn.commit() return cursor.rowcount @staticmethod async def purge_linked_to_messages() -> int: """Delete raw packets that are already linked to a stored message.""" cursor = await db.conn.execute("DELETE FROM raw_packets WHERE message_id IS NOT NULL") await db.conn.commit() return cursor.rowcount @staticmethod async def get_undecrypted_text_messages() -> list[tuple[int, bytes, int]]: """Get all undecrypted TEXT_MESSAGE packets as (id, data, timestamp) tuples. Filters raw packets to only include those with PayloadType.TEXT_MESSAGE (0x02). These are direct messages that can be decrypted with contact ECDH keys. """ cursor = await db.conn.execute( "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC" ) rows = await cursor.fetchall() # Filter for TEXT_MESSAGE packets result = [] for row in rows: data = bytes(row["data"]) payload_type = get_packet_payload_type(data) if payload_type == PayloadType.TEXT_MESSAGE: result.append((row["id"], data, row["timestamp"])) return result