Add DM decryption on new contact advert

This commit is contained in:
Jack Kingsman
2026-01-18 23:13:45 -08:00
parent 71ac7f1c6e
commit 43b7e94b0a
12 changed files with 290 additions and 234 deletions

View File

@@ -20,6 +20,7 @@ from app.decoder import (
DecryptedDirectMessage,
PacketInfo,
PayloadType,
derive_public_key,
parse_advertisement,
parse_packet,
try_decrypt_dm,
@@ -33,7 +34,7 @@ from app.repository import (
MessageRepository,
RawPacketRepository,
)
from app.websocket import broadcast_event
from app.websocket import broadcast_error, broadcast_event
logger = logging.getLogger(__name__)
@@ -291,6 +292,125 @@ async def create_dm_message_from_decrypted(
return msg_id
async def run_historical_dm_decryption(
private_key_bytes: bytes,
contact_public_key_bytes: bytes,
contact_public_key_hex: str,
display_name: str | None = None,
) -> None:
"""Background task to decrypt historical DM packets with contact's key."""
from app.websocket import broadcast_success
packets = await RawPacketRepository.get_undecrypted_text_messages()
total = len(packets)
decrypted_count = 0
if total == 0:
logger.info("No undecrypted TEXT_MESSAGE packets to process")
return
logger.info("Starting historical DM decryption of %d TEXT_MESSAGE packets", total)
# Derive our public key from the private key
our_public_key_bytes = derive_public_key(private_key_bytes)
for packet_id, packet_data, packet_timestamp in packets:
# Don't pass our_public_key - we want to decrypt both incoming AND outgoing messages.
result = try_decrypt_dm(
packet_data,
private_key_bytes,
contact_public_key_bytes,
our_public_key=None,
)
if result is not None:
# Determine direction by checking src_hash
src_hash = result.src_hash.lower()
our_first_byte = format(our_public_key_bytes[0], "02x").lower()
outgoing = src_hash == our_first_byte
# Extract path from the raw packet for storage
packet_info = parse_packet(packet_data)
path_hex = packet_info.path.hex() if packet_info else None
msg_id = await create_dm_message_from_decrypted(
packet_id=packet_id,
decrypted=result,
their_public_key=contact_public_key_hex,
our_public_key=our_public_key_bytes.hex(),
received_at=packet_timestamp,
path=path_hex,
outgoing=outgoing,
)
if msg_id is not None:
decrypted_count += 1
logger.info(
"Historical DM decryption complete: %d/%d packets decrypted",
decrypted_count,
total,
)
# Notify frontend
if decrypted_count > 0:
name = display_name or contact_public_key_hex[:12]
broadcast_success(
f"Historical decrypt complete for {name}",
f"Decrypted {decrypted_count} message{'s' if decrypted_count != 1 else ''}",
)
async def start_historical_dm_decryption(
background_tasks,
contact_public_key_hex: str,
display_name: str | None = None,
) -> None:
"""Start historical DM decryption using the stored private key."""
if not has_private_key():
logger.warning(
"Cannot start historical DM decryption: private key not available. "
"Ensure radio firmware has ENABLE_PRIVATE_KEY_EXPORT=1."
)
broadcast_error(
"Cannot decrypt historical DMs",
"Private key not available. Radio firmware may need ENABLE_PRIVATE_KEY_EXPORT=1.",
)
return
private_key_bytes = get_private_key()
if private_key_bytes is None:
return
try:
contact_public_key_bytes = bytes.fromhex(contact_public_key_hex)
except ValueError:
logger.warning(
"Cannot start historical DM decryption: invalid contact key %s",
contact_public_key_hex,
)
return
logger.info("Starting historical DM decryption for contact %s", contact_public_key_hex[:12])
if background_tasks is None:
asyncio.create_task(
run_historical_dm_decryption(
private_key_bytes,
contact_public_key_bytes,
contact_public_key_hex.lower(),
display_name,
)
)
else:
background_tasks.add_task(
run_historical_dm_decryption,
private_key_bytes,
contact_public_key_bytes,
contact_public_key_hex.lower(),
display_name,
)
async def process_raw_packet(
raw_bytes: bytes,
timestamp: int | None = None,
@@ -539,6 +659,10 @@ async def _process_advertisement(
},
)
# For new contacts, attempt to decrypt any historical DMs we may have stored
if existing is None:
await start_historical_dm_decryption(None, advert.public_key, advert.name)
# If this is not a repeater, trigger recent contacts sync to radio
# This ensures we can auto-ACK DMs from recent contacts
if contact_type != CONTACT_TYPE_REPEATER:

View File

@@ -579,7 +579,7 @@ class RawPacketRepository:
if existing:
# Duplicate - return existing packet ID
logger.info(
logger.debug(
"Duplicate payload detected (hash=%s..., existing_id=%d)",
payload_hash[:12],
existing["id"],

View File

@@ -16,10 +16,10 @@ from app.models import (
TelemetryRequest,
TelemetryResponse,
)
from app.packet_processor import start_historical_dm_decryption
from app.radio import radio_manager
from app.radio_sync import pause_polling
from app.repository import ContactRepository
from app.routers.packets import _run_historical_dm_decryption
logger = logging.getLogger(__name__)
@@ -83,7 +83,7 @@ async def create_contact(
"""
# Validate hex format
try:
contact_public_key_bytes = bytes.fromhex(request.public_key)
bytes.fromhex(request.public_key)
except ValueError as e:
raise HTTPException(status_code=400, detail="Invalid public key: must be valid hex") from e
@@ -112,8 +112,8 @@ async def create_contact(
# Trigger historical decryption if requested (even for existing contacts)
if request.try_historical:
await _start_historical_dm_decryption(
background_tasks, contact_public_key_bytes, request.public_key
await start_historical_dm_decryption(
background_tasks, request.public_key, request.name or existing.name
)
return existing
@@ -138,45 +138,11 @@ async def create_contact(
# Trigger historical decryption if requested
if request.try_historical:
await _start_historical_dm_decryption(
background_tasks, contact_public_key_bytes, request.public_key
)
await start_historical_dm_decryption(background_tasks, request.public_key, request.name)
return Contact(**contact_data)
async def _start_historical_dm_decryption(
background_tasks: BackgroundTasks,
contact_public_key_bytes: bytes,
contact_public_key_hex: str,
) -> None:
"""Start historical DM decryption using the stored private key."""
from app.keystore import get_private_key, has_private_key
from app.websocket import broadcast_error
if not has_private_key():
logger.warning(
"Cannot start historical DM decryption: private key not available. "
"Ensure radio firmware has ENABLE_PRIVATE_KEY_EXPORT=1."
)
broadcast_error(
"Cannot decrypt historical DMs",
"Private key not available. Radio firmware may need ENABLE_PRIVATE_KEY_EXPORT=1.",
)
return
private_key_bytes = get_private_key()
assert private_key_bytes is not None # Guaranteed by has_private_key check
logger.info("Starting historical DM decryption for contact %s", contact_public_key_hex[:12])
background_tasks.add_task(
_run_historical_dm_decryption,
private_key_bytes,
contact_public_key_bytes,
contact_public_key_hex.lower(),
)
@router.get("/{public_key}", response_model=Contact)
async def get_contact(public_key: str) -> Contact:
"""Get a specific contact by public key or prefix."""

View File

@@ -5,14 +5,10 @@ from fastapi import APIRouter, BackgroundTasks
from pydantic import BaseModel, Field
from app.database import db
from app.decoder import (
derive_public_key,
parse_packet,
try_decrypt_dm,
try_decrypt_packet_with_channel_key,
)
from app.packet_processor import create_dm_message_from_decrypted, create_message_from_decrypted
from app.repository import RawPacketRepository
from app.decoder import parse_packet, try_decrypt_packet_with_channel_key
from app.packet_processor import create_message_from_decrypted, run_historical_dm_decryption
from app.repository import ChannelRepository, RawPacketRepository
from app.websocket import broadcast_success
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/packets", tags=["packets"])
@@ -42,42 +38,24 @@ class DecryptResult(BaseModel):
message: str
class DecryptProgress(BaseModel):
total: int
processed: int
decrypted: int
in_progress: bool
# Global state for tracking decryption progress
_decrypt_progress: DecryptProgress | None = None
async def _run_historical_decryption(channel_key_bytes: bytes, channel_key_hex: str) -> None:
async def _run_historical_channel_decryption(
channel_key_bytes: bytes, channel_key_hex: str, display_name: str | None = None
) -> None:
"""Background task to decrypt historical packets with a channel key."""
global _decrypt_progress
packets = await RawPacketRepository.get_all_undecrypted()
total = len(packets)
processed = 0
decrypted_count = 0
_decrypt_progress = DecryptProgress(total=total, processed=0, decrypted=0, in_progress=True)
if total == 0:
logger.info("No undecrypted packets to process")
return
logger.info("Starting historical decryption of %d packets", total)
logger.info("Starting historical channel decryption of %d packets", total)
for packet_id, packet_data, packet_timestamp in packets:
result = try_decrypt_packet_with_channel_key(packet_data, channel_key_bytes)
if result is not None:
# Successfully decrypted - use shared logic to store message
logger.debug(
"Decrypted packet %d: sender=%s, message=%s",
packet_id,
result.sender,
result.message[:50] if result.message else "",
)
# Extract path from the raw packet for storage
packet_info = parse_packet(packet_data)
path_hex = packet_info.path.hex() if packet_info else None
@@ -88,102 +66,25 @@ async def _run_historical_decryption(channel_key_bytes: bytes, channel_key_hex:
sender=result.sender,
message_text=result.message,
timestamp=result.timestamp,
received_at=packet_timestamp, # Use original packet timestamp for correct ordering
path=path_hex,
)
if msg_id is not None:
decrypted_count += 1
processed += 1
_decrypt_progress = DecryptProgress(
total=total, processed=processed, decrypted=decrypted_count, in_progress=True
)
_decrypt_progress = DecryptProgress(
total=total, processed=processed, decrypted=decrypted_count, in_progress=False
)
logger.info("Historical decryption complete: %d/%d packets decrypted", decrypted_count, total)
async def _run_historical_dm_decryption(
private_key_bytes: bytes,
contact_public_key_bytes: bytes,
contact_public_key_hex: str,
) -> None:
"""Background task to decrypt historical DM packets with contact's key."""
global _decrypt_progress
# Get only TEXT_MESSAGE packets (undecrypted)
packets = await RawPacketRepository.get_undecrypted_text_messages()
total = len(packets)
processed = 0
decrypted_count = 0
_decrypt_progress = DecryptProgress(total=total, processed=0, decrypted=0, in_progress=True)
logger.info("Starting historical DM decryption of %d TEXT_MESSAGE packets", total)
# Derive our public key from the private key using Ed25519 scalar multiplication.
# Note: MeshCore stores the scalar directly (not a seed), so we use noclamp variant.
# See derive_public_key() for details on the MeshCore key format.
our_public_key_bytes = derive_public_key(private_key_bytes)
for packet_id, packet_data, packet_timestamp in packets:
# Don't pass our_public_key - we want to decrypt both incoming AND outgoing messages.
# The our_public_key filter in try_decrypt_dm only matches incoming (dest_hash == us),
# which would skip outgoing messages (where dest_hash == contact).
result = try_decrypt_dm(
packet_data,
private_key_bytes,
contact_public_key_bytes,
our_public_key=None,
)
if result is not None:
# Successfully decrypted - determine if inbound or outbound by checking src_hash
src_hash = result.src_hash.lower()
our_first_byte = format(our_public_key_bytes[0], "02x").lower()
outgoing = src_hash == our_first_byte
logger.debug(
"Decrypted DM packet %d: message=%s (outgoing=%s)",
packet_id,
result.message[:50] if result.message else "",
outgoing,
)
# Extract path from the raw packet for storage
packet_info = parse_packet(packet_data)
path_hex = packet_info.path.hex() if packet_info else None
msg_id = await create_dm_message_from_decrypted(
packet_id=packet_id,
decrypted=result,
their_public_key=contact_public_key_hex,
our_public_key=our_public_key_bytes.hex(),
received_at=packet_timestamp,
path=path_hex,
outgoing=outgoing,
)
if msg_id is not None:
decrypted_count += 1
processed += 1
_decrypt_progress = DecryptProgress(
total=total, processed=processed, decrypted=decrypted_count, in_progress=True
)
_decrypt_progress = DecryptProgress(
total=total, processed=processed, decrypted=decrypted_count, in_progress=False
)
logger.info(
"Historical DM decryption complete: %d/%d packets decrypted", decrypted_count, total
"Historical channel decryption complete: %d/%d packets decrypted", decrypted_count, total
)
# Notify frontend
if decrypted_count > 0:
name = display_name or channel_key_hex[:12]
broadcast_success(
f"Historical decrypt complete for {name}",
f"Decrypted {decrypted_count} message{'s' if decrypted_count != 1 else ''}",
)
@router.get("/undecrypted/count")
async def get_undecrypted_count() -> dict:
@@ -198,25 +99,11 @@ async def decrypt_historical_packets(
) -> DecryptResult:
"""
Attempt to decrypt historical packets with the provided key.
Runs in the background to avoid blocking.
Runs in the background. Multiple decrypt jobs can run concurrently.
"""
global _decrypt_progress
# Check if decryption is already in progress
if _decrypt_progress and _decrypt_progress.in_progress:
return DecryptResult(
started=False,
total_packets=_decrypt_progress.total,
message=f"Decryption already in progress: {_decrypt_progress.processed}/{_decrypt_progress.total}",
)
# Determine the channel key
channel_key_bytes: bytes | None = None
channel_key_hex: str | None = None
if request.key_type == "channel":
# Channel decryption
if request.channel_key:
# Direct key provided
try:
channel_key_bytes = bytes.fromhex(request.channel_key)
if len(channel_key_bytes) != 16:
@@ -233,7 +120,6 @@ async def decrypt_historical_packets(
message="Invalid hex string for channel key",
)
elif request.channel_name:
# Derive key from channel name (hashtag channel)
channel_key_bytes = sha256(request.channel_name.encode("utf-8")).digest()[:16]
channel_key_hex = channel_key_bytes.hex().upper()
else:
@@ -242,8 +128,30 @@ async def decrypt_historical_packets(
total_packets=0,
message="Must provide channel_key or channel_name",
)
# Get count and lookup channel name for display
count = await RawPacketRepository.get_undecrypted_count()
if count == 0:
return DecryptResult(
started=False, total_packets=0, message="No undecrypted packets to process"
)
# Try to find channel name for display
channel = await ChannelRepository.get_by_key(channel_key_hex)
display_name = channel.name if channel else request.channel_name
background_tasks.add_task(
_run_historical_channel_decryption, channel_key_bytes, channel_key_hex, display_name
)
return DecryptResult(
started=True,
total_packets=count,
message=f"Started channel decryption of {count} packets in background",
)
elif request.key_type == "contact":
# Validate required fields for contact decryption
# DM decryption
if not request.private_key:
return DecryptResult(
started=False,
@@ -257,7 +165,6 @@ async def decrypt_historical_packets(
message="Must provide contact_public_key for contact decryption",
)
# Parse private key
try:
private_key_bytes = bytes.fromhex(request.private_key)
if len(private_key_bytes) != 64:
@@ -273,7 +180,6 @@ async def decrypt_historical_packets(
message="Invalid hex string for private key",
)
# Parse contact public key
try:
contact_public_key_bytes = bytes.fromhex(request.contact_public_key)
if len(contact_public_key_bytes) != 32:
@@ -290,7 +196,6 @@ async def decrypt_historical_packets(
message="Invalid hex string for contact public key",
)
# Get count of undecrypted TEXT_MESSAGE packets
packets = await RawPacketRepository.get_undecrypted_text_messages()
count = len(packets)
if count == 0:
@@ -300,12 +205,18 @@ async def decrypt_historical_packets(
message="No undecrypted TEXT_MESSAGE packets to process",
)
# Start background decryption
# Try to find contact name for display
from app.repository import ContactRepository
contact = await ContactRepository.get_by_key_or_prefix(contact_public_key_hex)
display_name = contact.name if contact else None
background_tasks.add_task(
_run_historical_dm_decryption,
run_historical_dm_decryption,
private_key_bytes,
contact_public_key_bytes,
contact_public_key_hex,
display_name,
)
return DecryptResult(
@@ -313,36 +224,14 @@ async def decrypt_historical_packets(
total_packets=count,
message=f"Started DM decryption of {count} TEXT_MESSAGE packets in background",
)
else:
return DecryptResult(
started=False,
total_packets=0,
message="key_type must be 'channel' or 'contact'",
)
# Get count of undecrypted packets
count = await RawPacketRepository.get_undecrypted_count()
if count == 0:
return DecryptResult(
started=False, total_packets=0, message="No undecrypted packets to process"
)
# Start background decryption
background_tasks.add_task(_run_historical_decryption, channel_key_bytes, channel_key_hex)
return DecryptResult(
started=True,
total_packets=count,
message=f"Started decryption of {count} packets in background",
started=False,
total_packets=0,
message="key_type must be 'channel' or 'contact'",
)
@router.get("/decrypt/progress", response_model=DecryptProgress | None)
async def get_decrypt_progress() -> DecryptProgress | None:
"""Get the current progress of historical decryption."""
return _decrypt_progress
class MaintenanceRequest(BaseModel):
prune_undecrypted_days: int = Field(
ge=1, description="Delete undecrypted packets older than this many days"

View File

@@ -86,6 +86,17 @@ def broadcast_error(message: str, details: str | None = None) -> None:
asyncio.create_task(ws_manager.broadcast("error", data))
def broadcast_success(message: str, details: str | None = None) -> None:
"""Broadcast a success notification to all connected clients.
This appears as a toast notification in the frontend.
"""
data = {"message": message}
if details:
data["details"] = details
asyncio.create_task(ws_manager.broadcast("success", data))
def broadcast_health(radio_connected: bool, serial_port: str | None = None) -> None:
"""Broadcast health status change to all connected clients."""
from app.repository import RawPacketRepository

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -13,7 +13,7 @@
<link rel="shortcut icon" href="/favicon.ico" />
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
<link rel="manifest" href="/site.webmanifest" />
<script type="module" crossorigin src="/assets/index-CSUvhn5B.js"></script>
<script type="module" crossorigin src="/assets/index-Cjj7DnBW.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-C-fUaa04.css">
</head>
<body>

View File

@@ -137,6 +137,11 @@ export function App() {
description: error.details,
});
},
onSuccess: (success: { message: string; details?: string }) => {
toast.success(success.message, {
description: success.details,
});
},
onContacts: (data: Contact[]) => setContacts(data),
onChannels: (data: Channel[]) => setChannels(data),
onMessage: (msg: Message) => {

View File

@@ -11,6 +11,11 @@ interface ErrorEvent {
details?: string;
}
interface SuccessEvent {
message: string;
details?: string;
}
interface UseWebSocketOptions {
onHealth?: (health: HealthStatus) => void;
onContacts?: (contacts: Contact[]) => void;
@@ -20,6 +25,7 @@ interface UseWebSocketOptions {
onRawPacket?: (packet: RawPacket) => void;
onMessageAcked?: (messageId: number, ackCount: number, paths?: MessagePath[]) => void;
onError?: (error: ErrorEvent) => void;
onSuccess?: (success: SuccessEvent) => void;
}
export function useWebSocket(options: UseWebSocketOptions) {
@@ -94,6 +100,9 @@ export function useWebSocket(options: UseWebSocketOptions) {
case 'error':
options.onError?.(msg.data as ErrorEvent);
break;
case 'success':
options.onSuccess?.(msg.data as SuccessEvent);
break;
case 'pong':
// Heartbeat response, ignore
break;

View File

@@ -9,7 +9,7 @@ between backend and frontend - both sides test against the same data.
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -256,6 +256,58 @@ class TestAdvertisementPipeline:
# Empty path stored as None or ""
assert contact.last_path in (None, "")
@pytest.mark.asyncio
async def test_advertisement_triggers_historical_decrypt_for_new_contact(
self, test_db, captured_broadcasts
):
"""New contact via advertisement starts historical DM decryption."""
from app.packet_processor import process_raw_packet
fixture = FIXTURES["advertisement_with_gps"]
packet_bytes = bytes.fromhex(fixture["raw_packet_hex"])
expected = fixture["expected_ws_event"]["data"]
broadcasts, mock_broadcast = captured_broadcasts
with patch("app.packet_processor.broadcast_event", mock_broadcast):
with patch(
"app.packet_processor.start_historical_dm_decryption", new=AsyncMock()
) as mock_start:
await process_raw_packet(packet_bytes, timestamp=1700000000)
mock_start.assert_awaited_once_with(None, expected["public_key"], expected["name"])
@pytest.mark.asyncio
async def test_advertisement_skips_historical_decrypt_for_existing_contact(
self, test_db, captured_broadcasts
):
"""Existing contact via advertisement does not start historical DM decryption."""
from app.packet_processor import process_raw_packet
fixture = FIXTURES["advertisement_chat_node"]
packet_bytes = bytes.fromhex(fixture["raw_packet_hex"])
expected = fixture["expected_ws_event"]["data"]
await ContactRepository.upsert(
{
"public_key": expected["public_key"],
"name": "Existing",
"type": 0,
"lat": None,
"lon": None,
}
)
broadcasts, mock_broadcast = captured_broadcasts
with patch("app.packet_processor.broadcast_event", mock_broadcast):
with patch(
"app.packet_processor.start_historical_dm_decryption", new=AsyncMock()
) as mock_start:
await process_raw_packet(packet_bytes, timestamp=1700000000)
assert mock_start.await_count == 0
@pytest.mark.asyncio
async def test_advertisement_keeps_shorter_path_within_window(
self, test_db, captured_broadcasts