diff --git a/app/routers/server_control.py b/app/routers/server_control.py index eda007e..fdd715d 100644 --- a/app/routers/server_control.py +++ b/app/routers/server_control.py @@ -13,7 +13,11 @@ from app.models import ( Contact, RepeaterLoginResponse, ) -from app.radio_sync import _store_pending_channel_message, _store_pending_direct_message +from app.radio_sync import ( + _store_pending_channel_message, + _store_pending_direct_message, + drain_pending_messages, +) from app.routers.contacts import _ensure_on_radio from app.services.radio_runtime import radio_runtime as radio_manager @@ -83,59 +87,130 @@ def extract_response_text(event) -> str: return text +async def _flush_pending_messages(mc) -> None: + """Drain the radio's pending-message buffer before issuing a CLI command. + + A CLI response that arrived after a previous command already returned can + sit buffered in the radio. Without this flush, the next command's fetch + could pull that stale response and mis-attribute it as the new command's + answer (the firmware does not correlate responses to requests). Draining + first routes any real DMs/channel messages to storage and lets stale CLI + responses (txt_type=1) be dropped by ``event_handlers.on_contact_message``, + so they cannot be returned as this command's answer. + + This shrinks — but cannot fully eliminate — same-contact mis-attribution: + a reply still in flight when we send can only be told apart from the new + answer by a protocol-level request id, which the wire format lacks.
+ """ + try: + drained = await drain_pending_messages(mc) + if drained: + logger.debug("Flushed %d buffered message(s) before CLI send", drained) + except Exception: + logger.debug("Pre-send message flush failed", exc_info=True) + + async def fetch_contact_cli_response( mc, target_pubkey_prefix: str, timeout: float = 20.0, ) -> "Event | None": - """Fetch a CLI response from a specific contact via a validated get_msg() loop.""" - deadline = _monotonic() + timeout + """Fetch a CLI response (txt_type=1) from a specific contact. - while _monotonic() < deadline: - try: - result = await mc.commands.get_msg(timeout=2.0) - except TimeoutError: - continue - except Exception as exc: - logger.debug("get_msg() exception: %s", exc) - await asyncio.sleep(1.0) - continue + CLI responses arrive as ``CONTACT_MSG_RECV`` events, and the dispatcher + clones every such event to *all* subscribers. The permanent handler in + ``event_handlers.on_contact_message`` can therefore consume (and drop) a + response in the gap between this loop's ``get_msg`` polls, producing a + spurious timeout even though the response was delivered. - if result.type == EventType.NO_MORE_MSGS: - await asyncio.sleep(1.0) - continue + To close that race we hold a request-scoped subscription for the target's + CLI responses for the whole window. Whichever path observes the response + first wins — ``get_msg``'s return value on the happy path, or the + subscription when ``get_msg`` misses it — and the subscription is torn down + in ``finally`` so nothing outlives this call (no global state, so a late or + duplicate response cannot leak into an unrelated later fetch). - if result.type == EventType.ERROR: - logger.debug("get_msg() error: %s", result.payload) - await asyncio.sleep(1.0) - continue + ``get_msg`` is still polled to pump the radio into delivering buffered + frames and to route any unrelated DMs/channel messages to storage. 
+ """ + loop = asyncio.get_running_loop() + response_future: asyncio.Future = loop.create_future() - if result.type == EventType.CONTACT_MSG_RECV: - msg_prefix = result.payload.get("pubkey_prefix", "") - txt_type = result.payload.get("txt_type", 0) - if msg_prefix == target_pubkey_prefix and txt_type == 1: - return result - logger.debug( - "Storing non-target DM (from=%s, txt_type=%d) consumed while waiting for %s", - msg_prefix, - txt_type, - target_pubkey_prefix, - ) - await _store_pending_direct_message(result) - continue + def _capture(event: "Event") -> None: + # Dispatcher invokes sync callbacks inline with a cloned event; the + # attribute filter guarantees this only fires for the target's CLI + # responses, so we resolve with the first one seen. + if not response_future.done(): + response_future.set_result(event) - if result.type == EventType.CHANNEL_MSG_RECV: - logger.debug( - "Storing channel message (channel_idx=%s) consumed during CLI fetch", - result.payload.get("channel_idx"), - ) - await _store_pending_channel_message(mc, result.payload) - continue + subscription = mc.subscribe( + EventType.CONTACT_MSG_RECV, + _capture, + attribute_filters={"pubkey_prefix": target_pubkey_prefix, "txt_type": 1}, + ) - logger.debug("Unexpected event type %s during CLI fetch, skipping", result.type) + try: + deadline = _monotonic() + timeout - logger.warning("No CLI response from contact %s within %.1fs", target_pubkey_prefix, timeout) - return None + while _monotonic() < deadline: + if response_future.done(): + return response_future.result() + + try: + result = await mc.commands.get_msg(timeout=2.0) + except TimeoutError: + continue + except Exception as exc: + logger.debug("get_msg() exception: %s", exc) + await asyncio.sleep(1.0) + continue + + if result.type == EventType.NO_MORE_MSGS: + # The subscription may have captured a late delivery the radio + # didn't hand back through this poll; prefer it over sleeping. 
+ if response_future.done(): + return response_future.result() + await asyncio.sleep(1.0) + continue + + if result.type == EventType.ERROR: + logger.debug("get_msg() error: %s", result.payload) + await asyncio.sleep(1.0) + continue + + if result.type == EventType.CONTACT_MSG_RECV: + msg_prefix = result.payload.get("pubkey_prefix", "") + txt_type = result.payload.get("txt_type", 0) + if msg_prefix == target_pubkey_prefix and txt_type == 1: + return result + logger.debug( + "Storing non-target DM (from=%s, txt_type=%d) consumed while waiting for %s", + msg_prefix, + txt_type, + target_pubkey_prefix, + ) + await _store_pending_direct_message(result) + continue + + if result.type == EventType.CHANNEL_MSG_RECV: + logger.debug( + "Storing channel message (channel_idx=%s) consumed during CLI fetch", + result.payload.get("channel_idx"), + ) + await _store_pending_channel_message(mc, result.payload) + continue + + logger.debug("Unexpected event type %s during CLI fetch, skipping", result.type) + + # Final grace check in case a delivery raced the deadline. + if response_future.done(): + return response_future.result() + logger.warning( + "No CLI response from contact %s within %.1fs", target_pubkey_prefix, timeout + ) + return None + finally: + subscription.unsubscribe() async def prepare_authenticated_contact_connection( @@ -252,6 +327,10 @@ async def batch_cli_fetch( await _ensure_on_radio(mc, contact) await asyncio.sleep(1.0) # settle after add_contact + # Clear any stale buffered CLI response from a prior command so it + # cannot be pulled and mis-attributed to this one. 
+ await _flush_pending_messages(mc) + send_result = await mc.commands.send_cmd(contact.public_key, cmd) if send_result.type == EventType.ERROR: logger.debug("Command '%s' send error: %s", cmd, send_result.payload) @@ -286,6 +365,10 @@ async def send_contact_cli_command( await _ensure_on_radio(mc, contact) await asyncio.sleep(1.0) + # Clear any stale buffered CLI response from a prior command so it + # cannot be pulled and mis-attributed to this one. + await _flush_pending_messages(mc) + logger.info("Sending command to %s %s: %s", label, contact.public_key[:12], command) send_result = await mc.commands.send_cmd(contact.public_key, command) diff --git a/tests/test_cli_stale_response_flush.py b/tests/test_cli_stale_response_flush.py new file mode 100644 index 0000000..74bd06c --- /dev/null +++ b/tests/test_cli_stale_response_flush.py @@ -0,0 +1,147 @@ +"""Tests for the pre-send CLI buffer flush. + +These exercise the *real* ``_flush_pending_messages`` (unlike the route tests in +``test_repeater_routes.py``, which neutralize it), including the regression +guard that a stale buffered CLI response is not mis-attributed to a later +command. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from meshcore import EventType + +from app.models import CommandRequest +from app.radio import radio_manager +from app.repository import ContactRepository +from app.routers import server_control +from app.routers.repeaters import send_repeater_command + +KEY_A = "aa" * 32 + +# Patch target for the wall-clock wrapper used by fetch_contact_cli_response. 
+_MONOTONIC = "app.routers.server_control._monotonic" + + +@pytest.fixture(autouse=True) +def _reset_radio_state(): + """Save/restore radio_manager state so tests don't leak.""" + prev = radio_manager._meshcore + prev_lock = radio_manager._operation_lock + yield + radio_manager._meshcore = prev + radio_manager._operation_lock = prev_lock + + +def _radio_result(event_type=EventType.OK, payload=None): + result = MagicMock() + result.type = event_type + result.payload = payload or {} + return result + + +def _advancing_clock(start=0.0, step=0.1): + t = start + + def _tick(): + nonlocal t + val = t + t += step + return val + + return _tick + + +def _mock_mc(): + mc = MagicMock() + mc.commands = MagicMock() + mc.commands.send_cmd = AsyncMock(return_value=_radio_result(EventType.OK)) + mc.commands.get_msg = AsyncMock(return_value=_radio_result(EventType.NO_MORE_MSGS)) + mc.commands.add_contact = AsyncMock(return_value=_radio_result(EventType.OK)) + mc.subscribe = MagicMock(return_value=MagicMock(unsubscribe=MagicMock())) + mc.stop_auto_message_fetching = AsyncMock() + mc.start_auto_message_fetching = AsyncMock() + return mc + + +async def _insert_contact(public_key: str, name: str = "Repeater", contact_type: int = 2): + await ContactRepository.upsert( + { + "public_key": public_key, + "name": name, + "type": contact_type, + "flags": 0, + "direct_path": None, + "direct_path_len": -1, + "direct_path_hash_mode": -1, + "last_advert": None, + "lat": None, + "lon": None, + "last_seen": None, + "on_radio": False, + "last_contacted": None, + "first_seen": None, + } + ) + + +class TestFlushPendingMessages: + @pytest.mark.asyncio + async def test_flush_drains_pending_buffer(self): + mc = _mock_mc() + with patch( + "app.routers.server_control.drain_pending_messages", + new_callable=AsyncMock, + return_value=2, + ) as drain: + await server_control._flush_pending_messages(mc) + + drain.assert_awaited_once_with(mc) + + @pytest.mark.asyncio + async def 
test_flush_swallows_drain_errors(self): + """A flaky radio mid-flush must not abort the command.""" + mc = _mock_mc() + with patch( + "app.routers.server_control.drain_pending_messages", + new_callable=AsyncMock, + side_effect=RuntimeError("radio gone"), + ): + # Must not raise. + await server_control._flush_pending_messages(mc) + + +class TestStaleResponseRegression: + @pytest.mark.asyncio + async def test_stale_buffered_cli_response_is_flushed_not_returned(self, test_db): + """A stale CLI response buffered before the command is drained, so the + fetch returns the fresh response rather than the leftover one. + + Without the pre-send flush, the fetch loop would pull ``stale`` (same + contact, txt_type=1) and return it as the answer to ``get lat``. + """ + mc = _mock_mc() + await _insert_contact(KEY_A) + + stale = _radio_result( + EventType.CONTACT_MSG_RECV, + {"pubkey_prefix": KEY_A[:12], "text": "stale-name", "txt_type": 1}, + ) + no_more = _radio_result(EventType.NO_MORE_MSGS) + fresh = _radio_result( + EventType.CONTACT_MSG_RECV, + {"pubkey_prefix": KEY_A[:12], "text": "fresh-lat", "txt_type": 1}, + ) + # Flush drains [stale, no_more]; the subsequent fetch then sees [fresh]. 
+ mc.commands.get_msg = AsyncMock(side_effect=[stale, no_more, fresh]) + + with ( + patch("app.routers.repeaters.radio_manager.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + patch(_MONOTONIC, side_effect=_advancing_clock()), + patch("app.routers.server_control.asyncio.sleep", new_callable=AsyncMock), + ): + response = await send_repeater_command(KEY_A, CommandRequest(command="get lat")) + + assert response.command == "get lat" + assert response.response == "fresh-lat" diff --git a/tests/test_repeater_routes.py b/tests/test_repeater_routes.py index f759af5..8853871 100644 --- a/tests/test_repeater_routes.py +++ b/tests/test_repeater_routes.py @@ -47,6 +47,24 @@ def _reset_radio_state(): radio_manager._operation_lock = prev_lock +@pytest.fixture(autouse=True) +def _no_op_pre_send_flush(): + """Neutralize the pre-send buffer flush for command/batch route tests. + + ``_flush_pending_messages`` drains ``mc.commands.get_msg``, which the tests + in this module mock to return fetch responses; flushing here would consume + them. The flush behavior and its stale-response regression guard are covered + in ``test_cli_stale_response_flush.py``, which exercises the real flush. + Tests in ``TestFetchContactCliResponse`` call ``fetch_contact_cli_response`` + directly and never reach the flush, so this patch is a harmless no-op there. 
+ """ + with patch( + "app.routers.server_control._flush_pending_messages", + new_callable=AsyncMock, + ): + yield + + def _radio_result(event_type=EventType.OK, payload=None): result = MagicMock() result.type = event_type @@ -285,6 +303,101 @@ class TestFetchContactCliResponse: assert mc.commands.get_msg.await_count == 21 assert store_dm.await_count == 20 + @pytest.mark.asyncio + async def test_subscription_captures_response_when_get_msg_misses_it(self): + """The drop race: get_msg never returns the response, but the + request-scoped subscription captures the cloned push event.""" + mc = _mock_mc() + captured: dict = {} + + def _fake_subscribe(event_type, callback, attribute_filters=None): + captured["cb"] = callback + return MagicMock(unsubscribe=MagicMock()) + + mc.subscribe = MagicMock(side_effect=_fake_subscribe) + + push_event = _radio_result( + EventType.CONTACT_MSG_RECV, + {"pubkey_prefix": "aaaaaaaaaaaa", "text": "ver 1.0", "txt_type": 1}, + ) + calls = {"n": 0} + + async def _fake_get_msg(timeout=None): + calls["n"] += 1 + if calls["n"] == 1: + # Simulate the dispatcher pushing the response to subscribers + # (captured here by our request-scoped subscription) rather + # than returning it as this get_msg's result.
+ captured["cb"](push_event) + return _radio_result(EventType.NO_MORE_MSGS) + + mc.commands.get_msg = _fake_get_msg + + with ( + patch(_MONOTONIC, side_effect=_advancing_clock()), + patch("app.routers.server_control.asyncio.sleep", new_callable=AsyncMock), + ): + result = await fetch_contact_cli_response(mc, "aaaaaaaaaaaa", timeout=5.0) + + assert result is not None + assert result.payload["text"] == "ver 1.0" + + @pytest.mark.asyncio + async def test_subscription_uses_target_and_cli_filter(self): + """The scoped subscription filters on the target prefix and txt_type=1.""" + mc = _mock_mc() + mc.commands.get_msg = AsyncMock( + return_value=_radio_result( + EventType.CONTACT_MSG_RECV, + {"pubkey_prefix": "aaaaaaaaaaaa", "text": "ok", "txt_type": 1}, + ) + ) + + with patch(_MONOTONIC, side_effect=_advancing_clock()): + await fetch_contact_cli_response(mc, "aaaaaaaaaaaa", timeout=5.0) + + args, kwargs = mc.subscribe.call_args + assert args[0] == EventType.CONTACT_MSG_RECV + assert kwargs["attribute_filters"] == { + "pubkey_prefix": "aaaaaaaaaaaa", + "txt_type": 1, + } + + @pytest.mark.asyncio + async def test_unsubscribes_on_success(self): + mc = _mock_mc() + sub = MagicMock(unsubscribe=MagicMock()) + mc.subscribe = MagicMock(return_value=sub) + mc.commands.get_msg = AsyncMock( + return_value=_radio_result( + EventType.CONTACT_MSG_RECV, + {"pubkey_prefix": "aaaaaaaaaaaa", "text": "ok", "txt_type": 1}, + ) + ) + + with patch(_MONOTONIC, side_effect=_advancing_clock()): + result = await fetch_contact_cli_response(mc, "aaaaaaaaaaaa", timeout=5.0) + + assert result is not None + sub.unsubscribe.assert_called_once() + + @pytest.mark.asyncio + async def test_unsubscribes_on_timeout(self): + mc = _mock_mc() + sub = MagicMock(unsubscribe=MagicMock()) + mc.subscribe = MagicMock(return_value=sub) + mc.commands.get_msg = AsyncMock(return_value=_radio_result(EventType.NO_MORE_MSGS)) + times = iter([100.0, 100.5, 101.0, 103.0]) + + with ( + patch(_MONOTONIC, side_effect=times), + 
patch("app.routers.server_control.asyncio.sleep", new_callable=AsyncMock), + ): + result = await fetch_contact_cli_response(mc, "aaaaaaaaaaaa", timeout=2.0) + + assert result is None + sub.unsubscribe.assert_called_once() + class TestRepeaterCommandRoute: @pytest.mark.asyncio