import logging
from hashlib import sha256
from sqlite3 import OperationalError

import aiosqlite
from fastapi import APIRouter, BackgroundTasks, HTTPException, Response, status
from pydantic import BaseModel, Field

from app.database import db
from app.decoder import parse_packet, try_decrypt_packet_with_channel_key
from app.models import RawPacketDecryptedInfo, RawPacketDetail
from app.packet_processor import create_message_from_decrypted, run_historical_dm_decryption
from app.repository import ChannelRepository, MessageRepository, RawPacketRepository
from app.websocket import broadcast_success

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/packets", tags=["packets"])


class DecryptRequest(BaseModel):
    """Request body for ``POST /packets/decrypt/historical``.

    ``key_type`` selects which group of fields is used: ``channel_key`` /
    ``channel_name`` for channel decryption, ``private_key`` /
    ``contact_public_key`` for contact (DM) decryption.
    """

    key_type: str = Field(description="Type of key: 'channel' or 'contact'")
    channel_key: str | None = Field(
        default=None, description="Channel key as hex (16 bytes = 32 chars)"
    )
    channel_name: str | None = Field(
        default=None, description="Channel name (for hashtag channels, key derived from name)"
    )
    # Fields for contact (DM) decryption
    private_key: str | None = Field(
        default=None,
        description="Our private key as hex (64 bytes = 128 chars, Ed25519 seed + pubkey)",
    )
    contact_public_key: str | None = Field(
        default=None, description="Contact's public key as hex (32 bytes = 64 chars)"
    )


class DecryptResult(BaseModel):
    """Response for a historical-decryption request."""

    started: bool  # True if a background job was scheduled
    total_packets: int  # number of candidate packets at scheduling time
    message: str  # human-readable summary


def _bad_request(detail: str) -> HTTPException:
    """Build (but do not raise) a 400 Bad Request exception with ``detail``."""
    return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)


def _parse_hex_key(hex_value: str, expected_len: int, label: str) -> bytes:
    """Decode a client-supplied hex key and enforce its byte length.

    Args:
        hex_value: Hex string from the request. ``bytes.fromhex`` tolerates
            whitespace, so callers must derive canonical hex from the returned
            bytes rather than from this raw string.
        expected_len: Required decoded length in bytes.
        label: Lower-case key name used in error messages (e.g. "channel key").

    Returns:
        The decoded key bytes.

    Raises:
        HTTPException: 400 if the value is not valid hex or has the wrong length.
    """
    try:
        key_bytes = bytes.fromhex(hex_value)
    except ValueError:
        raise _bad_request(f"Invalid hex string for {label}") from None
    if len(key_bytes) != expected_len:
        raise _bad_request(
            f"{label.capitalize()} must be {expected_len} bytes ({expected_len * 2} hex chars)"
        )
    return key_bytes


async def _run_historical_channel_decryption(
    channel_key_bytes: bytes, channel_key_hex: str, display_name: str | None = None
) -> None:
    """Background task to decrypt historical packets with a channel key.

    Tries ``channel_key_bytes`` against every undecrypted raw packet. For each
    packet that decrypts, it stores a message via ``create_message_from_decrypted``
    with ``realtime=False``. When at least one new message was created, it
    broadcasts a success notification.

    Args:
        channel_key_bytes: 16-byte channel key used for decryption.
        channel_key_hex: Upper-case hex form of the key, stored as the channel key.
        display_name: Optional channel name for storage and the notification.
    """
    packets = await RawPacketRepository.get_all_undecrypted()
    total = len(packets)
    decrypted_count = 0

    if total == 0:
        logger.info("No undecrypted packets to process")
        return

    logger.info("Starting historical channel decryption of %d packets", total)

    for packet_id, packet_data, packet_timestamp in packets:
        result = try_decrypt_packet_with_channel_key(packet_data, channel_key_bytes)
        if result is None:
            continue

        # Extract path from the raw packet for storage
        packet_info = parse_packet(packet_data)
        path_hex = packet_info.path.hex() if packet_info else None

        msg_id = await create_message_from_decrypted(
            packet_id=packet_id,
            channel_key=channel_key_hex,
            channel_name=display_name,
            sender=result.sender,
            message_text=result.message,
            timestamp=result.timestamp,
            received_at=packet_timestamp,
            path=path_hex,
            path_len=packet_info.path_length if packet_info else None,
            realtime=False,  # Historical decryption should not trigger fanout
        )
        # A None id means no new message was recorded for this packet.
        if msg_id is not None:
            decrypted_count += 1

    logger.info(
        "Historical channel decryption complete: %d/%d packets decrypted", decrypted_count, total
    )

    # Notify frontend
    if decrypted_count > 0:
        name = display_name or channel_key_hex[:12]
        broadcast_success(
            f"Historical decrypt complete for {name}",
            f"Decrypted {decrypted_count} message{'s' if decrypted_count != 1 else ''}",
        )


@router.get("/undecrypted/count")
async def get_undecrypted_count() -> dict:
    """Get the count of undecrypted packets."""
    count = await RawPacketRepository.get_undecrypted_count()
    return {"count": count}


@router.get("/{packet_id}", response_model=RawPacketDetail)
async def get_raw_packet(packet_id: int) -> RawPacketDetail:
    """Fetch one stored raw packet by row ID for on-demand inspection.

    If the packet is linked to a stored message, the response also carries the
    sender and channel/contact info taken from that message.

    Raises:
        HTTPException: 404 if no raw packet has ``packet_id``.
    """
    packet_row = await RawPacketRepository.get_by_id(packet_id)
    if packet_row is None:
        raise HTTPException(status_code=404, detail="Raw packet not found")

    stored_packet_id, packet_data, packet_timestamp, message_id = packet_row
    packet_info = parse_packet(packet_data)
    payload_type_name = packet_info.payload_type.name if packet_info else "Unknown"

    decrypted_info: RawPacketDecryptedInfo | None = None
    if message_id is not None:
        message = await MessageRepository.get_by_id(message_id)
        if message is not None:
            if message.type == "CHAN":
                channel = await ChannelRepository.get_by_key(message.conversation_key)
                decrypted_info = RawPacketDecryptedInfo(
                    channel_name=channel.name if channel else None,
                    sender=message.sender_name,
                    channel_key=message.conversation_key,
                    contact_key=message.sender_key,
                )
            else:
                decrypted_info = RawPacketDecryptedInfo(
                    sender=message.sender_name,
                    contact_key=message.conversation_key,
                )

    return RawPacketDetail(
        id=stored_packet_id,
        timestamp=packet_timestamp,
        data=packet_data.hex(),
        payload_type=payload_type_name,
        decrypted=message_id is not None,
        decrypted_info=decrypted_info,
    )


@router.post("/decrypt/historical", response_model=DecryptResult)
async def decrypt_historical_packets(
    request: DecryptRequest, background_tasks: BackgroundTasks, response: Response
) -> DecryptResult:
    """
    Attempt to decrypt historical packets with the provided key.
    Runs in the background. Multiple decrypt jobs can run concurrently.

    Returns 202 Accepted when a background job was scheduled. Returns 200 with
    ``started=False`` when there is nothing to process.

    Raises:
        HTTPException: 400 for a missing, malformed or wrong-length key, or an
            unknown ``key_type``.
    """
    if request.key_type == "channel":
        # Channel decryption
        if request.channel_key:
            channel_key_bytes = _parse_hex_key(request.channel_key, 16, "channel key")
        elif request.channel_name:
            channel_key_bytes = sha256(request.channel_name.encode("utf-8")).digest()[:16]
        else:
            raise _bad_request("Must provide channel_key or channel_name")
        # Canonicalise from the decoded bytes: the raw request string may
        # contain whitespace that bytes.fromhex() silently accepted.
        channel_key_hex = channel_key_bytes.hex().upper()

        # Get count and lookup channel name for display
        count = await RawPacketRepository.get_undecrypted_count()
        if count == 0:
            return DecryptResult(
                started=False, total_packets=0, message="No undecrypted packets to process"
            )

        # Try to find channel name for display
        channel = await ChannelRepository.get_by_key(channel_key_hex)
        display_name = channel.name if channel else request.channel_name

        background_tasks.add_task(
            _run_historical_channel_decryption, channel_key_bytes, channel_key_hex, display_name
        )

        response.status_code = status.HTTP_202_ACCEPTED
        return DecryptResult(
            started=True,
            total_packets=count,
            message=f"Started channel decryption of {count} packets in background",
        )

    elif request.key_type == "contact":
        # DM decryption
        if not request.private_key:
            raise _bad_request("Must provide private_key for contact decryption")
        if not request.contact_public_key:
            raise _bad_request("Must provide contact_public_key for contact decryption")

        private_key_bytes = _parse_hex_key(request.private_key, 64, "private key")
        contact_public_key_bytes = _parse_hex_key(
            request.contact_public_key, 32, "contact public key"
        )
        # Canonical lower-case hex derived from the decoded bytes (see channel branch).
        contact_public_key_hex = contact_public_key_bytes.hex()

        packets = await RawPacketRepository.get_undecrypted_text_messages()
        count = len(packets)
        if count == 0:
            return DecryptResult(
                started=False,
                total_packets=0,
                message="No undecrypted TEXT_MESSAGE packets to process",
            )

        # Try to find contact name for display
        from app.repository import ContactRepository

        contact = await ContactRepository.get_by_key(contact_public_key_hex)
        display_name = contact.name if contact else None

        background_tasks.add_task(
            run_historical_dm_decryption,
            private_key_bytes,
            contact_public_key_bytes,
            contact_public_key_hex,
            display_name,
        )

        response.status_code = status.HTTP_202_ACCEPTED
        return DecryptResult(
            started=True,
            total_packets=count,
            message=f"Started DM decryption of {count} TEXT_MESSAGE packets in background",
        )

    raise _bad_request("key_type must be 'channel' or 'contact'")


class MaintenanceRequest(BaseModel):
    """Request body for ``POST /packets/maintenance``; every task is opt-in."""

    prune_undecrypted_days: int | None = Field(
        default=None, ge=1, description="Delete undecrypted packets older than this many days"
    )
    purge_linked_raw_packets: bool = Field(
        default=False,
        description="Delete raw packets already linked to a stored message",
    )


class MaintenanceResult(BaseModel):
    """Outcome of a maintenance run."""

    packets_deleted: int  # total rows removed across all pruning tasks
    vacuumed: bool  # False when VACUUM was skipped or failed


@router.post("/maintenance", response_model=MaintenanceResult)
async def run_maintenance(request: MaintenanceRequest) -> MaintenanceResult:
    """
    Run packet maintenance tasks and reclaim disk space.

    - Optionally deletes undecrypted packets older than the specified number of days
    - Optionally deletes raw packets already linked to stored messages
    - Runs VACUUM to reclaim disk space
    """
    deleted = 0

    if request.prune_undecrypted_days is not None:
        logger.info(
            "Running maintenance: pruning undecrypted packets older than %d days",
            request.prune_undecrypted_days,
        )
        pruned_undecrypted = await RawPacketRepository.prune_old_undecrypted(
            request.prune_undecrypted_days
        )
        deleted += pruned_undecrypted
        logger.info("Deleted %d old undecrypted packets", pruned_undecrypted)

    if request.purge_linked_raw_packets:
        logger.info("Running maintenance: purging raw packets linked to stored messages")
        purged_linked = await RawPacketRepository.purge_linked_to_messages()
        deleted += purged_linked
        logger.info("Deleted %d linked raw packets", purged_linked)

    # Run VACUUM to reclaim space on a dedicated connection.
    # VACUUM requires exclusive access — if the main connection is actively
    # writing (background sync, message processing, etc.) it fails with
    # SQLITE_BUSY. This is expected; we just report vacuumed=False.
    vacuumed = False
    try:
        async with aiosqlite.connect(db.db_path) as vacuum_conn:
            await vacuum_conn.executescript("VACUUM;")
        vacuumed = True
        logger.info("Database vacuumed")
    except OperationalError as e:
        # NOTE(review): OperationalError also covers non-busy failures; the
        # message assumes the common case.
        logger.warning("VACUUM skipped (database busy): %s", e)
    except Exception as e:
        # Keep maintenance best-effort, but preserve the traceback for diagnosis.
        logger.exception("VACUUM failed unexpectedly: %s", e)

    return MaintenanceResult(packets_deleted=deleted, vacuumed=vacuumed)