diff --git a/app/radio.py b/app/radio.py index c467e3a..68766f6 100644 --- a/app/radio.py +++ b/app/radio.py @@ -135,6 +135,7 @@ class RadioManager: self.path_hash_mode_supported: bool = False self._channel_slot_by_key: OrderedDict[str, int] = OrderedDict() self._channel_key_by_slot: dict[int, str] = {} + self._pending_message_channel_key_by_slot: dict[int, str] = {} async def _acquire_operation_lock( self, @@ -232,6 +233,18 @@ class RadioManager: self._channel_slot_by_key.clear() self._channel_key_by_slot.clear() + def remember_pending_message_channel_slot(self, channel_key: str, slot: int) -> None: + """Remember a channel key for later queued-message recovery.""" + self._pending_message_channel_key_by_slot[slot] = channel_key.upper() + + def get_pending_message_channel_key(self, slot: int) -> str | None: + """Return the last remembered channel key for a radio slot.""" + return self._pending_message_channel_key_by_slot.get(slot) + + def clear_pending_message_channel_slots(self) -> None: + """Drop any queued-message recovery slot metadata.""" + self._pending_message_channel_key_by_slot.clear() + def channel_slot_reuse_enabled(self) -> bool: """Return whether this transport can safely reuse cached channel slots.""" if settings.force_channel_slot_reconfigure: @@ -477,6 +490,7 @@ class RadioManager: self.path_hash_mode = 0 self.path_hash_mode_supported = False self.reset_channel_send_cache() + self.clear_pending_message_channel_slots() logger.debug("Radio disconnected") async def reconnect(self, *, broadcast_on_success: bool = True) -> bool: diff --git a/app/radio_sync.py b/app/radio_sync.py index 140ddf8..82cec7a 100644 --- a/app/radio_sync.py +++ b/app/radio_sync.py @@ -28,8 +28,9 @@ from app.repository import ( ContactRepository, ) from app.services.contact_reconciliation import reconcile_contact_messages +from app.services.messages import create_fallback_channel_message from app.services.radio_runtime import radio_runtime as radio_manager -from app.websocket import broadcast_error +from app.websocket import broadcast_error, broadcast_event logger = logging.getLogger(__name__) @@ -326,6 +327,7 @@ async def sync_and_offload_channels(mc: MeshCore, max_channels: int | None = Non if key_hex is None: continue + radio_manager.remember_pending_message_channel_slot(key_hex, idx) synced += 1 logger.debug("Synced channel %s: %s", key_hex[:8], result.payload.get("channel_name")) @@ -352,6 +354,87 @@ async def sync_and_offload_channels(mc: MeshCore, max_channels: int | None = Non return {"synced": synced, "cleared": cleared} +def _split_channel_sender_and_text(text: str) -> tuple[str | None, str]: + """Parse the canonical MeshCore ": " channel text format.""" + sender = None + message_text = text + colon_idx = text.find(": ") + if 0 < colon_idx < 50: + potential_sender = text[:colon_idx] + if not any(char in potential_sender for char in ":[]\x00"): + sender = potential_sender + message_text = text[colon_idx + 2 :] + return sender, message_text + + +async def _resolve_channel_for_pending_message( + mc: MeshCore, + channel_idx: int, +) -> tuple[str | None, str | None]: + """Resolve a pending channel message's slot to a channel key and name.""" + try: + result = await mc.commands.get_channel(channel_idx) + except Exception as exc: + logger.debug("Failed to fetch channel slot %s for pending message: %s", channel_idx, exc) + else: + if result.type == EventType.CHANNEL_INFO: + key_hex = await upsert_channel_from_radio_slot(result.payload, on_radio=False) + if key_hex is not None: + radio_manager.remember_pending_message_channel_slot(key_hex, channel_idx) + return key_hex, result.payload.get("channel_name") or None + + current_slot_map = getattr(radio_manager, "_channel_key_by_slot", {}) + cached_key = current_slot_map.get(channel_idx) + if cached_key is None: + cached_key = radio_manager.get_pending_message_channel_key(channel_idx) + if cached_key is None: + return None, None + + channel = await ChannelRepository.get_by_key(cached_key) + return cached_key, channel.name if channel else None + + +async def _store_pending_channel_message(mc: MeshCore, payload: dict) -> None: + """Persist a CHANNEL_MSG_RECV event pulled via get_msg().""" + channel_idx = payload.get("channel_idx") + if channel_idx is None: + logger.warning("Pending channel message missing channel_idx; dropping payload") + return + + try: + normalized_channel_idx = int(channel_idx) + except (TypeError, ValueError): + logger.warning("Pending channel message had invalid channel_idx=%r", channel_idx) + return + + channel_key, channel_name = await _resolve_channel_for_pending_message( + mc, normalized_channel_idx + ) + if channel_key is None: + logger.warning( + "Could not resolve channel slot %d for pending message; message cannot be stored", + normalized_channel_idx, + ) + return + + received_at = int(time.time()) + sender_timestamp = payload.get("sender_timestamp") or received_at + sender_name, message_text = _split_channel_sender_and_text(payload.get("text", "")) + + await create_fallback_channel_message( + conversation_key=channel_key, + message_text=message_text, + sender_timestamp=sender_timestamp, + received_at=received_at, + path=payload.get("path"), + path_len=payload.get("path_len"), + txt_type=payload.get("txt_type", 0), + sender_name=sender_name, + channel_name=channel_name, + broadcast_fn=broadcast_event, + ) + + async def ensure_default_channels() -> None: """ Ensure default channels exist in the database. @@ -421,6 +504,8 @@ async def drain_pending_messages(mc: MeshCore) -> int: logger.debug("Error during message drain: %s", result.payload) break elif result.type in (EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV): + if result.type == EventType.CHANNEL_MSG_RECV: + await _store_pending_channel_message(mc, result.payload) count += 1 # Small delay between fetches @@ -456,6 +541,8 @@ async def poll_for_messages(mc: MeshCore) -> int: elif result.type == EventType.ERROR: return 0 elif result.type in (EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV): + if result.type == EventType.CHANNEL_MSG_RECV: + await _store_pending_channel_message(mc, result.payload) count += 1 # If we got a message, there might be more - drain them count += await drain_pending_messages(mc) diff --git a/app/services/messages.py b/app/services/messages.py index 0832e3b..cd9b256 100644 --- a/app/services/messages.py +++ b/app/services/messages.py @@ -127,7 +127,7 @@ async def increment_ack_and_broadcast( async def handle_duplicate_message( *, - packet_id: int, + packet_id: int | None, msg_type: str, conversation_key: str, text: str, @@ -179,7 +179,8 @@ async def handle_duplicate_message( broadcast_fn=broadcast_fn, ) - await RawPacketRepository.mark_decrypted(packet_id, existing_msg.id) + if packet_id is not None: + await RawPacketRepository.mark_decrypted(packet_id, existing_msg.id) async def create_message_from_decrypted( @@ -396,6 +397,73 @@ async def create_fallback_direct_message( return message +async def create_fallback_channel_message( + *, + conversation_key: str, + message_text: str, + sender_timestamp: int, + received_at: int, + path: str | None, + path_len: int | None, + txt_type: int, + sender_name: str | None, + channel_name: str | None, + broadcast_fn: BroadcastFn, + message_repository=MessageRepository, +) -> Message | None: + """Store and broadcast a CHANNEL_MSG_RECV fallback channel message.""" + conversation_key_normalized = conversation_key.upper() + text = f"{sender_name}: {message_text}" if sender_name else message_text + + resolved_sender_key: str | None = None + if sender_name: + candidates = await ContactRepository.get_by_name(sender_name) + if len(candidates) == 1: + resolved_sender_key = candidates[0].public_key + + msg_id = await message_repository.create( + msg_type="CHAN", + text=text, + conversation_key=conversation_key_normalized, + sender_timestamp=sender_timestamp, + received_at=received_at, + path=path, + path_len=path_len, + txt_type=txt_type, + sender_name=sender_name, + sender_key=resolved_sender_key, + ) + if msg_id is None: + await handle_duplicate_message( + packet_id=None, + msg_type="CHAN", + conversation_key=conversation_key_normalized, + text=text, + sender_timestamp=sender_timestamp, + path=path, + received_at=received_at, + path_len=path_len, + broadcast_fn=broadcast_fn, + ) + return None + + message = build_message_model( + message_id=msg_id, + msg_type="CHAN", + conversation_key=conversation_key_normalized, + text=text, + sender_timestamp=sender_timestamp, + received_at=received_at, + paths=build_message_paths(path, received_at, path_len), + txt_type=txt_type, + sender_name=sender_name, + sender_key=resolved_sender_key, + channel_name=channel_name, + ) + broadcast_message(message=message, broadcast_fn=broadcast_fn) + return message + + async def create_outgoing_direct_message( *, conversation_key: str, diff --git a/app/services/radio_lifecycle.py b/app/services/radio_lifecycle.py index 4071830..c601a9a 100644 --- a/app/services/radio_lifecycle.py +++ b/app/services/radio_lifecycle.py @@ -137,6 +137,7 @@ async def run_post_connect_setup(radio_manager) -> None: drained = await drain_pending_messages(mc) if drained > 0: logger.info("Drained %d pending message(s)", drained) + radio_manager.clear_pending_message_channel_slots() await mc.start_auto_message_fetching() logger.info("Auto message fetching started") diff --git a/tests/test_radio_sync.py b/tests/test_radio_sync.py index d50e4aa..5966754 100644 --- a/tests/test_radio_sync.py +++ b/tests/test_radio_sync.py @@ -43,6 +43,7 @@ def reset_sync_state(): prev_connection_info = radio_manager._connection_info prev_slot_by_key = radio_manager._channel_slot_by_key.copy() prev_key_by_slot = radio_manager._channel_key_by_slot.copy() + prev_pending_channel_key_by_slot = radio_manager._pending_message_channel_key_by_slot.copy() radio_sync._polling_pause_count = 0 radio_sync._last_contact_sync = 0.0 @@ -55,6 +56,7 @@ def reset_sync_state(): radio_manager._connection_info = prev_connection_info radio_manager._channel_slot_by_key = prev_slot_by_key radio_manager._channel_key_by_slot = prev_key_by_slot + radio_manager._pending_message_channel_key_by_slot = prev_pending_channel_key_by_slot KEY_A = "aa" * 32 @@ -1091,6 +1093,120 @@ class TestSyncAndOffloadChannels: assert radio_manager.get_cached_channel_slot("AA" * 16) is None + @pytest.mark.asyncio + async def test_remembers_channel_slot_for_pending_message_recovery(self, test_db): + """Offload snapshots slot-to-key mapping for the later startup drain.""" + from app.radio_sync import sync_and_offload_channels + + channel_key = "11" * 16 + channel_result = MagicMock() + channel_result.type = EventType.CHANNEL_INFO + channel_result.payload = { + "channel_name": "#queued", + "channel_secret": bytes.fromhex(channel_key), + } + + empty_result = MagicMock() + empty_result.type = EventType.ERROR + + mock_mc = MagicMock() + mock_mc.commands.get_channel = AsyncMock(side_effect=[channel_result] + [empty_result] * 39) + mock_mc.commands.set_channel = AsyncMock(return_value=MagicMock(type=EventType.OK)) + + await sync_and_offload_channels(mock_mc) + + assert radio_manager.get_pending_message_channel_key(0) == channel_key.upper() + + +class TestPendingChannelMessageFallback: + """Queued CHANNEL_MSG_RECV events should be persisted instead of dropped.""" + + @pytest.mark.asyncio + async def test_drain_pending_messages_uses_snapshotted_slot_mapping_after_offload( + self, test_db + ): + """Startup drain can still store room traffic even after slots were cleared.""" + from app.radio_sync import drain_pending_messages + + channel_key = "22" * 16 + await ChannelRepository.upsert(key=channel_key, name="#queued") + radio_manager.remember_pending_message_channel_slot(channel_key, 3) + + channel_message = MagicMock() + channel_message.type = EventType.CHANNEL_MSG_RECV + channel_message.payload = { + "channel_idx": 3, + "text": "Alice: hello from queue", + "sender_timestamp": 1700000000, + "txt_type": 0, + "path": "aabb", + "path_len": 2, + } + + no_more = MagicMock() + no_more.type = EventType.NO_MORE_MSGS + no_more.payload = {} + + empty_slot = MagicMock() + empty_slot.type = EventType.ERROR + empty_slot.payload = {"error": "slot empty"} + + mock_mc = MagicMock() + mock_mc.commands.get_msg = AsyncMock(side_effect=[channel_message, no_more]) + mock_mc.commands.get_channel = AsyncMock(return_value=empty_slot) + + with patch("app.radio_sync.broadcast_event") as mock_broadcast: + drained = await drain_pending_messages(mock_mc) + + assert drained == 1 + stored = await MessageRepository.get_all(msg_type="CHAN", conversation_key=channel_key) + assert len(stored) == 1 + assert stored[0].text == "Alice: hello from queue" + assert stored[0].sender_name == "Alice" + assert stored[0].conversation_key == channel_key + assert stored[0].paths is not None + assert stored[0].paths[0].path == "aabb" + + mock_broadcast.assert_called_once() + + @pytest.mark.asyncio + async def test_poll_for_messages_stores_first_pending_channel_message(self, test_db): + """Single-pass polling stores the first queued channel message before draining.""" + from app.radio_sync import poll_for_messages + + channel_key = "33" * 16 + channel_result = MagicMock() + channel_result.type = EventType.CHANNEL_INFO + channel_result.payload = { + "channel_name": "#poll", + "channel_secret": bytes.fromhex(channel_key), + } + + channel_message = MagicMock() + channel_message.type = EventType.CHANNEL_MSG_RECV + channel_message.payload = { + "channel_idx": 1, + "text": "Bob: polled message", + "sender_timestamp": 1700000010, + "txt_type": 0, + } + + no_more = MagicMock() + no_more.type = EventType.NO_MORE_MSGS + no_more.payload = {} + + mock_mc = MagicMock() + mock_mc.commands.get_msg = AsyncMock(side_effect=[channel_message, no_more]) + mock_mc.commands.get_channel = AsyncMock(return_value=channel_result) + + with patch("app.radio_sync.broadcast_event"): + count = await poll_for_messages(mock_mc) + + assert count == 1 + stored = await MessageRepository.get_all(msg_type="CHAN", conversation_key=channel_key) + assert len(stored) == 1 + assert stored[0].text == "Bob: polled message" + class TestEnsureDefaultChannels: """Test ensure_default_channels: create/fix the Public channel."""