Add fallback polling message persistence for channel messages

This commit is contained in:
Jack Kingsman
2026-03-13 11:05:49 -07:00
parent 70d28e53a9
commit 9c2b6f0744
5 changed files with 289 additions and 3 deletions

View File

@@ -135,6 +135,7 @@ class RadioManager:
self.path_hash_mode_supported: bool = False
self._channel_slot_by_key: OrderedDict[str, int] = OrderedDict()
self._channel_key_by_slot: dict[int, str] = {}
self._pending_message_channel_key_by_slot: dict[int, str] = {}
async def _acquire_operation_lock(
self,
@@ -232,6 +233,18 @@ class RadioManager:
self._channel_slot_by_key.clear()
self._channel_key_by_slot.clear()
def remember_pending_message_channel_slot(self, channel_key: str, slot: int) -> None:
    """Record which channel key a radio slot held, for queued-message recovery.

    The key is stored upper-cased. Calling this again for the same ``slot``
    replaces the previously remembered key.
    """
    normalized_key = channel_key.upper()
    self._pending_message_channel_key_by_slot[slot] = normalized_key
def get_pending_message_channel_key(self, slot: int) -> str | None:
"""Return the last remembered channel key for a radio slot."""
return self._pending_message_channel_key_by_slot.get(slot)
def clear_pending_message_channel_slots(self) -> None:
    """Forget every slot-to-key entry kept for queued-message recovery.

    The existing mapping object is emptied in place rather than replaced.
    """
    pending = self._pending_message_channel_key_by_slot
    pending.clear()
def channel_slot_reuse_enabled(self) -> bool:
"""Return whether this transport can safely reuse cached channel slots."""
if settings.force_channel_slot_reconfigure:
@@ -477,6 +490,7 @@ class RadioManager:
self.path_hash_mode = 0
self.path_hash_mode_supported = False
self.reset_channel_send_cache()
self.clear_pending_message_channel_slots()
logger.debug("Radio disconnected")
async def reconnect(self, *, broadcast_on_success: bool = True) -> bool:

View File

@@ -28,8 +28,9 @@ from app.repository import (
ContactRepository,
)
from app.services.contact_reconciliation import reconcile_contact_messages
from app.services.messages import create_fallback_channel_message
from app.services.radio_runtime import radio_runtime as radio_manager
from app.websocket import broadcast_error
from app.websocket import broadcast_error, broadcast_event
logger = logging.getLogger(__name__)
@@ -326,6 +327,7 @@ async def sync_and_offload_channels(mc: MeshCore, max_channels: int | None = Non
if key_hex is None:
continue
radio_manager.remember_pending_message_channel_slot(key_hex, idx)
synced += 1
logger.debug("Synced channel %s: %s", key_hex[:8], result.payload.get("channel_name"))
@@ -352,6 +354,87 @@ async def sync_and_offload_channels(mc: MeshCore, max_channels: int | None = Non
return {"synced": synced, "cleared": cleared}
def _split_channel_sender_and_text(text: str) -> tuple[str | None, str]:
"""Parse the canonical MeshCore "<sender>: <message>" channel text format."""
sender = None
message_text = text
colon_idx = text.find(": ")
if 0 < colon_idx < 50:
potential_sender = text[:colon_idx]
if not any(char in potential_sender for char in ":[]\x00"):
sender = potential_sender
message_text = text[colon_idx + 2 :]
return sender, message_text
async def _resolve_channel_for_pending_message(
    mc: MeshCore,
    channel_idx: int,
) -> tuple[str | None, str | None]:
    """Map a pending channel message's radio slot to ``(channel key, channel name)``.

    The radio is queried for the slot first. When it answers with a
    CHANNEL_INFO event and ``upsert_channel_from_radio_slot`` yields a key,
    that key is remembered for the slot and returned together with the
    radio-reported channel name (empty names become ``None``).

    Otherwise the key is taken from the manager's current slot map, then from
    the remembered pending-message mapping, and the name comes from the
    channel record looked up by that key.

    Returns:
        ``(None, None)`` when no key can be determined for the slot.
    """
    try:
        slot_info = await mc.commands.get_channel(channel_idx)
    except Exception as exc:
        # Best effort: a failed radio query just means we rely on cached mappings.
        logger.debug("Failed to fetch channel slot %s for pending message: %s", channel_idx, exc)
    else:
        if slot_info.type == EventType.CHANNEL_INFO:
            radio_key = await upsert_channel_from_radio_slot(slot_info.payload, on_radio=False)
            if radio_key is not None:
                radio_manager.remember_pending_message_channel_slot(radio_key, channel_idx)
                radio_name = slot_info.payload.get("channel_name") or None
                return radio_key, radio_name

    # Radio gave nothing usable: prefer the live slot map, then the snapshot.
    live_slot_keys = getattr(radio_manager, "_channel_key_by_slot", {})
    fallback_key = live_slot_keys.get(channel_idx)
    if fallback_key is None:
        fallback_key = radio_manager.get_pending_message_channel_key(channel_idx)
        if fallback_key is None:
            return None, None

    stored_channel = await ChannelRepository.get_by_key(fallback_key)
    stored_name = stored_channel.name if stored_channel else None
    return fallback_key, stored_name
async def _store_pending_channel_message(mc: MeshCore, payload: dict) -> None:
    """Persist a CHANNEL_MSG_RECV event pulled via get_msg().

    The payload's ``channel_idx`` is resolved to a channel key, the text is
    split into sender and message, and the result is handed to
    ``create_fallback_channel_message`` with ``broadcast_event`` as the
    broadcast function.

    The message is dropped with a warning (nothing is stored) when:
      * ``channel_idx`` is missing,
      * ``channel_idx`` cannot be converted to ``int``, or
      * the slot cannot be resolved to a channel key.

    Args:
        mc: MeshCore connection used to resolve the channel slot.
        payload: Event payload; reads ``channel_idx``, ``text``,
            ``sender_timestamp``, ``txt_type``, ``path`` and ``path_len``.
    """
    channel_idx = payload.get("channel_idx")
    if channel_idx is None:
        logger.warning("Pending channel message missing channel_idx; dropping payload")
        return

    try:
        normalized_channel_idx = int(channel_idx)
    except (TypeError, ValueError):
        logger.warning("Pending channel message had invalid channel_idx=%r", channel_idx)
        return

    channel_key, channel_name = await _resolve_channel_for_pending_message(
        mc, normalized_channel_idx
    )
    if channel_key is None:
        logger.warning(
            "Could not resolve channel slot %d for pending message; message cannot be stored",
            normalized_channel_idx,
        )
        return

    received_at = int(time.time())
    # A missing or zero sender timestamp falls back to the local receive time.
    sender_timestamp = payload.get("sender_timestamp") or received_at

    # dict.get's default only applies when the key is absent, so an explicit
    # None would otherwise reach the splitter and fail on ``.find``. Treat a
    # None/empty text as an empty message, mirroring the timestamp guard above.
    raw_text = payload.get("text") or ""
    sender_name, message_text = _split_channel_sender_and_text(raw_text)

    await create_fallback_channel_message(
        conversation_key=channel_key,
        message_text=message_text,
        sender_timestamp=sender_timestamp,
        received_at=received_at,
        path=payload.get("path"),
        path_len=payload.get("path_len"),
        # Same None-vs-missing guard as for the text: always pass an int.
        txt_type=payload.get("txt_type") or 0,
        sender_name=sender_name,
        channel_name=channel_name,
        broadcast_fn=broadcast_event,
    )
async def ensure_default_channels() -> None:
"""
Ensure default channels exist in the database.
@@ -421,6 +504,8 @@ async def drain_pending_messages(mc: MeshCore) -> int:
logger.debug("Error during message drain: %s", result.payload)
break
elif result.type in (EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV):
if result.type == EventType.CHANNEL_MSG_RECV:
await _store_pending_channel_message(mc, result.payload)
count += 1
# Small delay between fetches
@@ -456,6 +541,8 @@ async def poll_for_messages(mc: MeshCore) -> int:
elif result.type == EventType.ERROR:
return 0
elif result.type in (EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV):
if result.type == EventType.CHANNEL_MSG_RECV:
await _store_pending_channel_message(mc, result.payload)
count += 1
# If we got a message, there might be more - drain them
count += await drain_pending_messages(mc)

View File

@@ -127,7 +127,7 @@ async def increment_ack_and_broadcast(
async def handle_duplicate_message(
*,
packet_id: int,
packet_id: int | None,
msg_type: str,
conversation_key: str,
text: str,
@@ -179,7 +179,8 @@ async def handle_duplicate_message(
broadcast_fn=broadcast_fn,
)
await RawPacketRepository.mark_decrypted(packet_id, existing_msg.id)
if packet_id is not None:
await RawPacketRepository.mark_decrypted(packet_id, existing_msg.id)
async def create_message_from_decrypted(
@@ -396,6 +397,73 @@ async def create_fallback_direct_message(
return message
async def create_fallback_channel_message(
    *,
    conversation_key: str,
    message_text: str,
    sender_timestamp: int,
    received_at: int,
    path: str | None,
    path_len: int | None,
    txt_type: int,
    sender_name: str | None,
    channel_name: str | None,
    broadcast_fn: BroadcastFn,
    message_repository=MessageRepository,
) -> Message | None:
    """Store a CHANNEL_MSG_RECV fallback message and broadcast it.

    The conversation key is upper-cased before use. When ``sender_name`` is
    given, the stored text is ``"<sender_name>: <message_text>"`` and the
    sender key is filled in only if exactly one contact has that name.

    Returns:
        The newly created message, or ``None`` when
        ``message_repository.create`` returned no id. In that case the payload
        is passed to ``handle_duplicate_message`` with ``packet_id=None``.
    """
    channel_key = conversation_key.upper()

    stored_text = message_text
    sender_key: str | None = None
    if sender_name:
        stored_text = f"{sender_name}: {message_text}"
        # Only attribute the message to a contact when the name is unambiguous.
        matches = await ContactRepository.get_by_name(sender_name)
        if len(matches) == 1:
            sender_key = matches[0].public_key

    new_id = await message_repository.create(
        msg_type="CHAN",
        text=stored_text,
        conversation_key=channel_key,
        sender_timestamp=sender_timestamp,
        received_at=received_at,
        path=path,
        path_len=path_len,
        txt_type=txt_type,
        sender_name=sender_name,
        sender_key=sender_key,
    )

    if new_id is not None:
        created = build_message_model(
            message_id=new_id,
            msg_type="CHAN",
            conversation_key=channel_key,
            text=stored_text,
            sender_timestamp=sender_timestamp,
            received_at=received_at,
            paths=build_message_paths(path, received_at, path_len),
            txt_type=txt_type,
            sender_name=sender_name,
            sender_key=sender_key,
            channel_name=channel_name,
        )
        broadcast_message(message=created, broadcast_fn=broadcast_fn)
        return created

    # No id came back: hand off to the duplicate path. There is no raw packet
    # behind a polled message, hence packet_id=None.
    await handle_duplicate_message(
        packet_id=None,
        msg_type="CHAN",
        conversation_key=channel_key,
        text=stored_text,
        sender_timestamp=sender_timestamp,
        path=path,
        received_at=received_at,
        path_len=path_len,
        broadcast_fn=broadcast_fn,
    )
    return None
async def create_outgoing_direct_message(
*,
conversation_key: str,

View File

@@ -137,6 +137,7 @@ async def run_post_connect_setup(radio_manager) -> None:
drained = await drain_pending_messages(mc)
if drained > 0:
logger.info("Drained %d pending message(s)", drained)
radio_manager.clear_pending_message_channel_slots()
await mc.start_auto_message_fetching()
logger.info("Auto message fetching started")

View File

@@ -43,6 +43,7 @@ def reset_sync_state():
prev_connection_info = radio_manager._connection_info
prev_slot_by_key = radio_manager._channel_slot_by_key.copy()
prev_key_by_slot = radio_manager._channel_key_by_slot.copy()
prev_pending_channel_key_by_slot = radio_manager._pending_message_channel_key_by_slot.copy()
radio_sync._polling_pause_count = 0
radio_sync._last_contact_sync = 0.0
@@ -55,6 +56,7 @@ def reset_sync_state():
radio_manager._connection_info = prev_connection_info
radio_manager._channel_slot_by_key = prev_slot_by_key
radio_manager._channel_key_by_slot = prev_key_by_slot
radio_manager._pending_message_channel_key_by_slot = prev_pending_channel_key_by_slot
KEY_A = "aa" * 32
@@ -1091,6 +1093,120 @@ class TestSyncAndOffloadChannels:
assert radio_manager.get_cached_channel_slot("AA" * 16) is None
@pytest.mark.asyncio
async def test_remembers_channel_slot_for_pending_message_recovery(self, test_db):
    """Offloading records the synced slot's channel key for the later startup drain."""
    from app.radio_sync import sync_and_offload_channels

    channel_key = "11" * 16

    # Slot 0 holds a real channel; the remaining 39 lookups report an error.
    occupied_slot = MagicMock(
        type=EventType.CHANNEL_INFO,
        payload={
            "channel_name": "#queued",
            "channel_secret": bytes.fromhex(channel_key),
        },
    )
    vacant_slot = MagicMock(type=EventType.ERROR)
    slot_responses = [occupied_slot, *([vacant_slot] * 39)]

    mock_mc = MagicMock()
    mock_mc.commands.get_channel = AsyncMock(side_effect=slot_responses)
    mock_mc.commands.set_channel = AsyncMock(return_value=MagicMock(type=EventType.OK))

    await sync_and_offload_channels(mock_mc)

    remembered_key = radio_manager.get_pending_message_channel_key(0)
    assert remembered_key == channel_key.upper()
class TestPendingChannelMessageFallback:
    """Channel messages fetched via get_msg() must be stored rather than discarded."""

    @pytest.mark.asyncio
    async def test_drain_pending_messages_uses_snapshotted_slot_mapping_after_offload(
        self, test_db
    ):
        """The drain stores a channel message via the remembered slot key when the radio slot lookup fails."""
        from app.radio_sync import drain_pending_messages

        channel_key = "22" * 16
        await ChannelRepository.upsert(key=channel_key, name="#queued")
        radio_manager.remember_pending_message_channel_slot(channel_key, 3)

        queued_message = MagicMock(
            type=EventType.CHANNEL_MSG_RECV,
            payload={
                "channel_idx": 3,
                "text": "Alice: hello from queue",
                "sender_timestamp": 1700000000,
                "txt_type": 0,
                "path": "aabb",
                "path_len": 2,
            },
        )
        queue_empty = MagicMock(type=EventType.NO_MORE_MSGS, payload={})
        # The radio reports an error for the slot, forcing the remembered-key fallback.
        slot_empty = MagicMock(type=EventType.ERROR, payload={"error": "slot empty"})

        mock_mc = MagicMock()
        mock_mc.commands.get_msg = AsyncMock(side_effect=[queued_message, queue_empty])
        mock_mc.commands.get_channel = AsyncMock(return_value=slot_empty)

        with patch("app.radio_sync.broadcast_event") as mock_broadcast:
            drained = await drain_pending_messages(mock_mc)

        assert drained == 1

        stored = await MessageRepository.get_all(msg_type="CHAN", conversation_key=channel_key)
        assert len(stored) == 1

        message = stored[0]
        assert message.text == "Alice: hello from queue"
        assert message.sender_name == "Alice"
        assert message.conversation_key == channel_key
        assert message.paths is not None
        assert message.paths[0].path == "aabb"
        mock_broadcast.assert_called_once()

    @pytest.mark.asyncio
    async def test_poll_for_messages_stores_first_pending_channel_message(self, test_db):
        """A poll persists the channel message it fetches first, before draining the rest."""
        from app.radio_sync import poll_for_messages

        channel_key = "33" * 16

        slot_info = MagicMock(
            type=EventType.CHANNEL_INFO,
            payload={
                "channel_name": "#poll",
                "channel_secret": bytes.fromhex(channel_key),
            },
        )
        polled_message = MagicMock(
            type=EventType.CHANNEL_MSG_RECV,
            payload={
                "channel_idx": 1,
                "text": "Bob: polled message",
                "sender_timestamp": 1700000010,
                "txt_type": 0,
            },
        )
        queue_empty = MagicMock(type=EventType.NO_MORE_MSGS, payload={})

        mock_mc = MagicMock()
        mock_mc.commands.get_msg = AsyncMock(side_effect=[polled_message, queue_empty])
        mock_mc.commands.get_channel = AsyncMock(return_value=slot_info)

        with patch("app.radio_sync.broadcast_event"):
            count = await poll_for_messages(mock_mc)

        assert count == 1

        stored = await MessageRepository.get_all(msg_type="CHAN", conversation_key=channel_key)
        assert len(stored) == 1
        assert stored[0].text == "Bob: polled message"
class TestEnsureDefaultChannels:
"""Test ensure_default_channels: create/fix the Public channel."""