From 152eab99db32a6f882d187726a121e89a14f0436 Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Mon, 23 Feb 2026 19:11:58 -0800 Subject: [PATCH] More stable MC object reference and proper radio disconnection detection --- app/main.py | 12 +++- app/radio.py | 27 +++++++-- app/radio_sync.py | 4 +- app/routers/channels.py | 4 +- app/routers/contacts.py | 31 +++++----- app/routers/messages.py | 30 +++++----- app/routers/radio.py | 17 +++--- tests/test_api.py | 39 +++++++++++++ tests/test_contacts_router.py | 29 +++++++-- tests/test_radio_operation.py | 50 +++++++++++++++- tests/test_radio_router.py | 47 +++++++++++---- tests/test_radio_sync.py | 107 ++++++++++++++++++---------------- tests/test_repeater_routes.py | 39 ++++++++++++- tests/test_send_messages.py | 32 +++++++++- 14 files changed, 342 insertions(+), 126 deletions(-) diff --git a/app/main.py b/app/main.py index 0cecd60..62372b4 100644 --- a/app/main.py +++ b/app/main.py @@ -2,13 +2,14 @@ import logging from contextlib import asynccontextmanager from pathlib import Path -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from app.config import setup_logging from app.database import db from app.frontend_static import register_frontend_static_routes -from app.radio import radio_manager +from app.radio import RadioDisconnectedError, radio_manager from app.radio_sync import ( stop_message_polling, stop_periodic_advert, @@ -75,6 +76,13 @@ app.add_middleware( allow_headers=["*"], ) + +@app.exception_handler(RadioDisconnectedError) +async def radio_disconnected_handler(request: Request, exc: RadioDisconnectedError): + """Return 503 when a radio disconnect race occurs during an operation.""" + return JSONResponse(status_code=503, content={"detail": "Radio not connected"}) + + # API routes - all prefixed with /api for production compatibility app.include_router(health.router, prefix="/api") app.include_router(radio.router, prefix="/api") diff --git a/app/radio.py b/app/radio.py index b0faa9c..116089c 100644 --- a/app/radio.py +++ b/app/radio.py @@ -20,6 +20,10 @@ class RadioOperationBusyError(RadioOperationError): """Raised when a non-blocking radio operation cannot acquire the lock.""" +class RadioDisconnectedError(RadioOperationError): + """Raised when the radio disconnects between pre-check and lock acquisition.""" + + @asynccontextmanager async def _noop_context(): """No-op async context manager for optional nesting.""" @@ -166,37 +170,48 @@ class RadioManager: pause_polling: bool = False, suspend_auto_fetch: bool = False, blocking: bool = True, - meshcore: MeshCore | None = None, ): """Acquire shared radio lock and optionally pause polling / auto-fetch. + After acquiring the lock, resolves the current MeshCore instance and + yields it. Callers get a fresh reference via ``async with ... as mc:``, + avoiding stale-reference bugs when a reconnect swaps ``_meshcore`` + between the pre-check and the lock acquisition. + Args: name: Human-readable operation name for logs/errors. pause_polling: Pause fallback message polling while held. suspend_auto_fetch: Stop MeshCore auto message fetching while held. blocking: If False, fail immediately when lock is held. - meshcore: Optional explicit MeshCore instance for auto-fetch control. + + Raises: + RadioDisconnectedError: If the radio disconnected before the lock + was acquired (``_meshcore`` is ``None``). """ await self._acquire_operation_lock(name, blocking=blocking) + mc = self._meshcore + if mc is None: + self._release_operation_lock(name) + raise RadioDisconnectedError("Radio disconnected") + poll_context = _noop_context() if pause_polling: from app.radio_sync import pause_polling as pause_polling_context poll_context = pause_polling_context() - mc = meshcore or self._meshcore auto_fetch_paused = False try: async with poll_context: - if suspend_auto_fetch and mc is not None: + if suspend_auto_fetch: await mc.stop_auto_message_fetching() auto_fetch_paused = True - yield + yield mc finally: try: - if auto_fetch_paused and mc is not None: + if auto_fetch_paused: try: await mc.start_auto_message_fetching() except Exception as e: diff --git a/app/radio_sync.py b/app/radio_sync.py index 86252ae..5bfea74 100644 --- a/app/radio_sync.py +++ b/app/radio_sync.py @@ -573,13 +573,11 @@ async def sync_recent_contacts_to_radio(force: bool = False) -> dict: logger.debug("Cannot sync contacts to radio: not connected") return {"loaded": 0, "error": "Radio not connected"} - mc = radio_manager.meshcore - try: async with radio_manager.radio_operation( "sync_recent_contacts_to_radio", blocking=False, - ): + ) as mc: _last_contact_sync = now # Build prioritized contact list: diff --git a/app/routers/channels.py b/app/routers/channels.py index b2f4c2a..6b95554 100644 --- a/app/routers/channels.py +++ b/app/routers/channels.py @@ -85,12 +85,12 @@ async def create_channel(request: CreateChannelRequest) -> Channel: @router.post("/sync") async def sync_channels_from_radio(max_channels: int = Query(default=40, ge=1, le=40)) -> dict: """Sync channels from the radio to the database.""" - mc = require_connected() + require_connected() logger.info("Syncing channels from radio (checking %d slots)", max_channels) count = 0 - async with radio_manager.radio_operation("sync_channels_from_radio"): + async with radio_manager.radio_operation("sync_channels_from_radio") as mc: for idx in range(max_channels): result = await mc.commands.get_channel(idx) diff --git a/app/routers/contacts.py b/app/routers/contacts.py index 2a374f7..7d37f60 100644 --- a/app/routers/contacts.py +++ b/app/routers/contacts.py @@ -264,11 +264,11 @@ async def get_contact(public_key: str) -> Contact: @router.post("/sync") async def sync_contacts_from_radio() -> dict: """Sync contacts from the radio to the database.""" - mc = require_connected() + require_connected() logger.info("Syncing contacts from radio") - async with radio_manager.radio_operation("sync_contacts_from_radio", meshcore=mc): + async with radio_manager.radio_operation("sync_contacts_from_radio") as mc: result = await mc.commands.get_contacts() if result.type == EventType.ERROR: @@ -294,11 +294,11 @@ async def sync_contacts_from_radio() -> dict: @router.post("/{public_key}/remove-from-radio") async def remove_contact_from_radio(public_key: str) -> dict: """Remove a contact from the radio (keeps it in database).""" - mc = require_connected() + require_connected() contact = await _resolve_contact_or_404(public_key) - async with radio_manager.radio_operation("remove_contact_from_radio", meshcore=mc): + async with radio_manager.radio_operation("remove_contact_from_radio") as mc: # Get the contact from radio radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) if not radio_contact: @@ -322,11 +322,11 @@ async def remove_contact_from_radio(public_key: str) -> dict: @router.post("/{public_key}/add-to-radio") async def add_contact_to_radio(public_key: str) -> dict: """Add a contact from the database to the radio.""" - mc = require_connected() + require_connected() contact = await _resolve_contact_or_404(public_key, "Contact not found in database") - async with radio_manager.radio_operation("add_contact_to_radio", meshcore=mc): + async with radio_manager.radio_operation("add_contact_to_radio") as mc: # Check if already on radio radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) if radio_contact: @@ -361,9 +361,8 @@ async def delete_contact(public_key: str) -> dict: contact = await _resolve_contact_or_404(public_key) # Remove from radio if connected and contact is on radio - if radio_manager.is_connected and radio_manager.meshcore: - mc = radio_manager.meshcore - async with radio_manager.radio_operation("delete_contact_from_radio", meshcore=mc): + if radio_manager.is_connected: + async with radio_manager.radio_operation("delete_contact_from_radio") as mc: radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) if radio_contact: logger.info( @@ -385,7 +384,7 @@ async def request_telemetry(public_key: str, request: TelemetryRequest) -> Telem The contact must be a repeater (type=2). If not on the radio, it will be added. Uses login + status request with retry logic. """ - mc = require_connected() + require_connected() # Get contact from database contact = await _resolve_contact_or_404(public_key) @@ -399,10 +398,9 @@ async def request_telemetry(public_key: str, request: TelemetryRequest) -> Telem async with radio_manager.radio_operation( "request_telemetry", - meshcore=mc, pause_polling=True, suspend_auto_fetch=True, - ): + ) as mc: # Prepare connection (add/remove dance + login) await prepare_repeater_connection(mc, contact, request.password) @@ -552,7 +550,7 @@ async def send_repeater_command(public_key: str, request: CommandRequest) -> Com - reboot - ver """ - mc = require_connected() + require_connected() # Get contact from database contact = await _resolve_contact_or_404(public_key) @@ -566,10 +564,9 @@ async def send_repeater_command(public_key: str, request: CommandRequest) -> Com async with radio_manager.radio_operation( "send_repeater_command", - meshcore=mc, pause_polling=True, suspend_auto_fetch=True, - ): + ) as mc: # Add contact to radio with path from DB logger.info("Adding repeater %s to radio", contact.public_key[:12]) await mc.commands.add_contact(contact.to_radio_dict()) @@ -621,7 +618,7 @@ async def request_trace(public_key: str) -> TraceResponse: (no intermediate repeaters). The radio firmware requires at least one node in the path. """ - mc = require_connected() + require_connected() contact = await _resolve_contact_or_404(public_key) @@ -631,7 +628,7 @@ async def request_trace(public_key: str) -> TraceResponse: # Trace does not need auto-fetch suspension: response arrives as TRACE_DATA # from the reader loop, not via get_msg(). - async with radio_manager.radio_operation("request_trace", pause_polling=True): + async with radio_manager.radio_operation("request_trace", pause_polling=True) as mc: # Ensure contact is on radio so the trace can reach them await mc.commands.add_contact(contact.to_radio_dict()) diff --git a/app/routers/messages.py b/app/routers/messages.py index 314ce22..96b4c22 100644 --- a/app/routers/messages.py +++ b/app/routers/messages.py @@ -43,7 +43,7 @@ async def list_messages( @router.post("/direct", response_model=Message) async def send_direct_message(request: SendDirectMessageRequest) -> Message: """Send a direct message to a contact.""" - mc = require_connected() + require_connected() # First check our database for the contact from app.repository import ContactRepository @@ -69,7 +69,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message: # so we can't rely on it to know if the firmware has the contact. # add_contact is idempotent - updates if exists, adds if not. contact_data = db_contact.to_radio_dict() - async with radio_manager.radio_operation("send_direct_message"): + async with radio_manager.radio_operation("send_direct_message") as mc: logger.debug("Ensuring contact %s is on radio before sending", db_contact.public_key[:12]) add_result = await mc.commands.add_contact(contact_data) if add_result.type == EventType.ERROR: @@ -163,7 +163,7 @@ TEMP_RADIO_SLOT = 0 @router.post("/channel", response_model=Message) async def send_channel_message(request: SendChannelMessageRequest) -> Message: """Send a message to a channel.""" - mc = require_connected() + require_connected() # Get channel info from our database from app.decoder import calculate_channel_hash @@ -192,12 +192,14 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message: expected_hash, ) channel_key_upper = request.channel_key.upper() - radio_name = mc.self_info.get("name", "") if mc.self_info else "" - text_with_sender = f"{radio_name}: {request.text}" if radio_name else request.text message_id: int | None = None now: int | None = None + radio_name: str = "" + text_with_sender: str = request.text - async with radio_manager.radio_operation("send_channel_message"): + async with radio_manager.radio_operation("send_channel_message") as mc: + radio_name = mc.self_info.get("name", "") if mc.self_info else "" + text_with_sender = f"{radio_name}: {request.text}" if radio_name else request.text # Load the channel to a temporary radio slot before sending set_result = await mc.commands.set_channel( channel_idx=TEMP_RADIO_SLOT, @@ -318,7 +320,7 @@ async def resend_channel_message( When new_timestamp=True: resend with a fresh timestamp so repeaters treat it as a new packet. Creates a new message row in the database. No time window restriction. """ - mc = require_connected() + require_connected() from app.repository import ChannelRepository @@ -352,12 +354,6 @@ async def resend_channel_message( else: timestamp_bytes = msg.sender_timestamp.to_bytes(4, "little") - # Strip sender prefix: DB stores "RadioName: message" but radio needs "message" - radio_name = mc.self_info.get("name", "") if mc.self_info else "" - text_to_send = msg.text - if radio_name and text_to_send.startswith(f"{radio_name}: "): - text_to_send = text_to_send[len(f"{radio_name}: ") :] - try: key_bytes = bytes.fromhex(msg.conversation_key) except ValueError: @@ -365,7 +361,13 @@ async def resend_channel_message( status_code=400, detail=f"Invalid channel key format: {msg.conversation_key}" ) from None - async with radio_manager.radio_operation("resend_channel_message"): + async with radio_manager.radio_operation("resend_channel_message") as mc: + # Strip sender prefix: DB stores "RadioName: message" but radio needs "message" + radio_name = mc.self_info.get("name", "") if mc.self_info else "" + text_to_send = msg.text + if radio_name and text_to_send.startswith(f"{radio_name}: "): + text_to_send = text_to_send[len(f"{radio_name}: ") :] + set_result = await mc.commands.set_channel( channel_idx=TEMP_RADIO_SLOT, channel_name=db_channel.name, diff --git a/app/routers/radio.py b/app/routers/radio.py index de6fdb7..00fd530 100644 --- a/app/routers/radio.py +++ b/app/routers/radio.py @@ -70,9 +70,9 @@ async def get_radio_config() -> RadioConfigResponse: @router.patch("/config", response_model=RadioConfigResponse) async def update_radio_config(update: RadioConfigUpdate) -> RadioConfigResponse: """Update radio configuration. Only provided fields will be updated.""" - mc = require_connected() + require_connected() - async with radio_manager.radio_operation("update_radio_config"): + async with radio_manager.radio_operation("update_radio_config") as mc: if update.name is not None: logger.info("Setting radio name to %s", update.name) await mc.commands.set_name(update.name) @@ -117,7 +117,7 @@ async def update_radio_config(update: RadioConfigUpdate) -> RadioConfigResponse: @router.put("/private-key") async def set_private_key(update: PrivateKeyUpdate) -> dict: """Set the radio's private key. This is write-only.""" - mc = require_connected() + require_connected() try: key_bytes = bytes.fromhex(update.private_key) @@ -125,7 +125,7 @@ async def set_private_key(update: PrivateKeyUpdate) -> dict: raise HTTPException(status_code=400, detail="Invalid hex string for private key") from None logger.info("Importing private key") - async with radio_manager.radio_operation("import_private_key", meshcore=mc): + async with radio_manager.radio_operation("import_private_key") as mc: result = await mc.commands.import_private_key(key_bytes) if result.type == EventType.ERROR: @@ -167,13 +167,10 @@ async def reboot_radio() -> dict: If not connected: attempts to reconnect (same as /reconnect endpoint). """ # If connected, send reboot command - if radio_manager.is_connected and radio_manager.meshcore: + if radio_manager.is_connected: logger.info("Rebooting radio") - async with radio_manager.radio_operation( - "reboot_radio", - meshcore=radio_manager.meshcore, - ): - await radio_manager.meshcore.commands.reboot() + async with radio_manager.radio_operation("reboot_radio") as mc: + await mc.commands.reboot() return { "status": "ok", "message": "Reboot command sent. Radio will reconnect automatically.", diff --git a/tests/test_api.py b/tests/test_api.py index 9cb92cd..66782ab 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -12,6 +12,7 @@ import httpx import pytest from app.database import Database +from app.radio import radio_manager from app.repository import ( ChannelRepository, ContactRepository, @@ -20,6 +21,16 @@ from app.repository import ( ) +@pytest.fixture(autouse=True) +def _reset_radio_state(): + """Save/restore radio_manager state so tests don't leak.""" + prev = radio_manager._meshcore + prev_lock = radio_manager._operation_lock + yield + radio_manager._meshcore = prev + radio_manager._operation_lock = prev_lock + + @pytest.fixture async def test_db(): """Create an in-memory test database with schema + migrations.""" @@ -109,6 +120,29 @@ class TestHealthEndpoint: assert data["connection_info"] is None +class TestRadioDisconnectedHandler: + """Test that RadioDisconnectedError maps to 503.""" + + @pytest.mark.asyncio + async def test_disconnect_race_returns_503(self, test_db, client): + """If radio disconnects between require_connected() and lock acquisition, return 503.""" + pub_key = "ab" * 32 + await _insert_contact(pub_key, "Alice") + + # require_connected() passes, but _meshcore is None when radio_operation() checks + radio_manager._meshcore = None + with patch("app.dependencies.radio_manager") as mock_rm: + mock_rm.is_connected = True + mock_rm.meshcore = MagicMock() + + response = await client.post( + "/api/messages/direct", json={"destination": pub_key, "text": "Hi"} + ) + + assert response.status_code == 503 + assert "not connected" in response.json()["detail"].lower() + + class TestMessagesEndpoint: """Test message-related endpoints.""" @@ -161,6 +195,7 @@ class TestMessagesEndpoint: coro.close() return MagicMock() + radio_manager._meshcore = mock_mc with ( patch("app.dependencies.radio_manager") as mock_rm, patch("app.bot.run_bot_for_message", new=AsyncMock()), @@ -204,6 +239,7 @@ class TestMessagesEndpoint: coro.close() return MagicMock() + radio_manager._meshcore = mock_mc with ( patch("app.dependencies.radio_manager") as mock_rm, patch("app.decoder.calculate_channel_hash", return_value="abcd"), @@ -260,6 +296,7 @@ class TestMessagesEndpoint: return_value=MagicMock(type=MagicMock(name="OK"), payload={"expected_ack": b"\x00\x01"}) ) + radio_manager._meshcore = mock_mc with ( patch("app.dependencies.radio_manager") as mock_rm, patch("app.routers.messages.MessageRepository") as mock_msg_repo, @@ -296,6 +333,7 @@ class TestMessagesEndpoint: return_value=MagicMock(type=MagicMock(name="OK"), payload={}) ) + radio_manager._meshcore = mock_mc with ( patch("app.dependencies.radio_manager") as mock_rm, patch("app.routers.messages.MessageRepository") as mock_msg_repo, @@ -355,6 +393,7 @@ class TestMessagesEndpoint: return_value=MagicMock(type=EventType.MSG_SENT, payload={}) ) + radio_manager._meshcore = mock_mc with patch("app.dependencies.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc diff --git a/tests/test_contacts_router.py b/tests/test_contacts_router.py index 76c6073..29dded3 100644 --- a/tests/test_contacts_router.py +++ b/tests/test_contacts_router.py @@ -14,6 +14,7 @@ import pytest from meshcore import EventType from app.database import Database +from app.radio import radio_manager from app.repository import ContactRepository, MessageRepository # Sample 64-char hex public keys for testing @@ -22,9 +23,24 @@ KEY_B = "bb" * 32 # bbbb...bb KEY_C = "cc" * 32 # cccc...cc -@asynccontextmanager -async def _noop_radio_operation(*_args, **_kwargs): +def _noop_radio_operation(mc=None): + """Factory for a no-op radio_operation context manager that yields mc.""" + + @asynccontextmanager + async def _ctx(*_args, **_kwargs): + yield mc + + return _ctx + + +@pytest.fixture(autouse=True) +def _reset_radio_state(): + """Save/restore radio_manager state so tests don't leak.""" + prev = radio_manager._meshcore + prev_lock = radio_manager._operation_lock yield + radio_manager._meshcore = prev + radio_manager._operation_lock = prev_lock @pytest.fixture @@ -225,7 +241,7 @@ class TestDeleteContact: with patch("app.routers.contacts.radio_manager") as mock_rm: mock_rm.is_connected = False mock_rm.meshcore = None - mock_rm.radio_operation = _noop_radio_operation + mock_rm.radio_operation = _noop_radio_operation() response = await client.delete(f"/api/contacts/{KEY_A}") @@ -255,7 +271,7 @@ class TestDeleteContact: with patch("app.routers.contacts.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc - mock_rm.radio_operation = _noop_radio_operation + mock_rm.radio_operation = _noop_radio_operation(mock_mc) response = await client.delete(f"/api/contacts/{KEY_A}") @@ -277,6 +293,7 @@ class TestSyncContacts: } mock_mc.commands.get_contacts = AsyncMock(return_value=mock_result) + radio_manager._meshcore = mock_mc with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc @@ -318,6 +335,7 @@ class TestSyncContacts: mock_result.payload = {KEY_A: {"adv_name": "Alice", "type": 1, "flags": 0}} mock_mc.commands.get_contacts = AsyncMock(return_value=mock_result) + radio_manager._meshcore = mock_mc with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc @@ -345,6 +363,7 @@ class TestAddRemoveRadio: mock_result.type = EventType.OK mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) + radio_manager._meshcore = mock_mc with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc @@ -366,6 +385,7 @@ class TestAddRemoveRadio: mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # On radio + radio_manager._meshcore = mock_mc with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc @@ -386,6 +406,7 @@ class TestAddRemoveRadio: mock_result.type = EventType.OK mock_mc.commands.remove_contact = AsyncMock(return_value=mock_result) + radio_manager._meshcore = mock_mc with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc diff --git a/tests/test_radio_operation.py b/tests/test_radio_operation.py index 7729267..248447d 100644 --- a/tests/test_radio_operation.py +++ b/tests/test_radio_operation.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from app.radio import RadioOperationBusyError, radio_manager +from app.radio import RadioDisconnectedError, RadioOperationBusyError, radio_manager from app.radio_sync import is_polling_paused @@ -14,7 +14,9 @@ def reset_radio_operation_state(): """Reset shared radio operation lock state before/after each test.""" prev_meshcore = radio_manager._meshcore radio_manager._operation_lock = None - radio_manager._meshcore = None + # Default to a non-None MagicMock so radio_operation() doesn't raise + # RadioDisconnectedError for tests that only exercise locking. + radio_manager._meshcore = MagicMock() import app.radio_sync as radio_sync @@ -123,3 +125,47 @@ class TestRadioOperationLock: assert is_polling_paused() assert not is_polling_paused() + + +class TestRadioOperationYield: + """Validate that radio_operation() yields the current meshcore instance.""" + + @pytest.mark.asyncio + async def test_radio_operation_yields_current_meshcore(self): + """The yielded value is the current _meshcore at lock-acquisition time.""" + mc = MagicMock() + radio_manager._meshcore = mc + + async with radio_manager.radio_operation("test_yield") as yielded: + assert yielded is mc + + @pytest.mark.asyncio + async def test_radio_operation_raises_when_disconnected_after_lock(self): + """RadioDisconnectedError is raised when _meshcore is None after acquiring the lock.""" + radio_manager._meshcore = None + + with pytest.raises(RadioDisconnectedError): + async with radio_manager.radio_operation("test_disconnected"): + pass # pragma: no cover + + # Lock must be released even after the error + radio_manager._meshcore = MagicMock() + async with radio_manager.radio_operation("after_error", blocking=False): + pass + + @pytest.mark.asyncio + async def test_radio_operation_yields_fresh_reference_after_swap(self): + """If _meshcore is swapped between pre-check and lock acquisition, + the yielded value is the new (current) instance, not the old one.""" + old_mc = MagicMock(name="old") + new_mc = MagicMock(name="new") + + # Start with old_mc + radio_manager._meshcore = old_mc + + # Simulate a reconnect swapping _meshcore before the caller enters the block + radio_manager._meshcore = new_mc + + async with radio_manager.radio_operation("test_swap") as yielded: + assert yielded is new_mc + assert yielded is not old_mc diff --git a/tests/test_radio_router.py b/tests/test_radio_router.py index 639e47c..dd97870 100644 --- a/tests/test_radio_router.py +++ b/tests/test_radio_router.py @@ -8,7 +8,7 @@ import pytest from fastapi import HTTPException from meshcore import EventType -from app.radio import RadioManager +from app.radio import RadioManager, radio_manager from app.routers.radio import ( PrivateKeyUpdate, RadioConfigResponse, @@ -30,9 +30,24 @@ def _radio_result(event_type=EventType.OK, payload=None): return result -@asynccontextmanager -async def _noop_radio_operation(*_args, **_kwargs): +def _noop_radio_operation(mc=None): + """Factory for a no-op radio_operation context manager that yields mc.""" + + @asynccontextmanager + async def _ctx(*_args, **_kwargs): + yield mc + + return _ctx + + +@pytest.fixture(autouse=True) +def _reset_radio_state(): + """Save/restore radio_manager state so tests don't leak.""" + prev = radio_manager._meshcore + prev_lock = radio_manager._operation_lock yield + radio_manager._meshcore = prev + radio_manager._operation_lock = prev_lock def _mock_meshcore_with_info(): @@ -100,6 +115,7 @@ class TestUpdateRadioConfig: with ( patch("app.routers.radio.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.routers.radio.sync_radio_time", new_callable=AsyncMock) as mock_sync_time, patch( "app.routers.radio.get_radio_config", new_callable=AsyncMock, return_value=expected @@ -132,7 +148,10 @@ class TestPrivateKeyImport: mc.commands.import_private_key = AsyncMock( return_value=_radio_result(EventType.ERROR, {"error": "failed"}) ) - with patch("app.routers.radio.require_connected", return_value=mc): + with ( + patch("app.routers.radio.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): with pytest.raises(HTTPException) as exc: await set_private_key(PrivateKeyUpdate(private_key="aa" * 64)) @@ -142,6 +161,7 @@ class TestPrivateKeyImport: class TestAdvertise: @pytest.mark.asyncio async def test_raises_when_send_fails(self): + radio_manager._meshcore = MagicMock() with ( patch("app.routers.radio.require_connected"), patch( @@ -170,6 +190,7 @@ class TestAdvertise: return True isolated_manager = RadioManager() + isolated_manager._meshcore = MagicMock() with ( patch("app.routers.radio.require_connected"), patch("app.routers.radio.radio_manager", isolated_manager), @@ -187,17 +208,19 @@ class TestAdvertise: class TestRebootAndReconnect: @pytest.mark.asyncio async def test_reboot_connected_sends_reboot_command(self): + mock_mc = MagicMock() + mock_mc.commands.reboot = AsyncMock() + mock_rm = MagicMock() mock_rm.is_connected = True - mock_rm.meshcore = MagicMock() - mock_rm.meshcore.commands.reboot = AsyncMock() - mock_rm.radio_operation = _noop_radio_operation + mock_rm.meshcore = mock_mc + mock_rm.radio_operation = _noop_radio_operation(mock_mc) with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() assert result["status"] == "ok" - mock_rm.meshcore.commands.reboot.assert_awaited_once() + mock_mc.commands.reboot.assert_awaited_once() @pytest.mark.asyncio async def test_reboot_returns_pending_when_reconnect_in_progress(self): @@ -205,7 +228,7 @@ class TestRebootAndReconnect: mock_rm.is_connected = False mock_rm.meshcore = None mock_rm.is_reconnecting = True - mock_rm.radio_operation = _noop_radio_operation + mock_rm.radio_operation = _noop_radio_operation() with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() @@ -221,7 +244,7 @@ class TestRebootAndReconnect: mock_rm.is_reconnecting = False mock_rm.reconnect = AsyncMock(return_value=True) mock_rm.post_connect_setup = AsyncMock() - mock_rm.radio_operation = _noop_radio_operation + mock_rm.radio_operation = _noop_radio_operation() with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() @@ -235,7 +258,7 @@ class TestRebootAndReconnect: async def test_reconnect_returns_already_connected(self): mock_rm = MagicMock() mock_rm.is_connected = True - mock_rm.radio_operation = _noop_radio_operation + mock_rm.radio_operation = _noop_radio_operation() with patch("app.routers.radio.radio_manager", mock_rm): result = await reconnect_radio() @@ -249,7 +272,7 @@ class TestRebootAndReconnect: mock_rm.is_connected = False mock_rm.is_reconnecting = False mock_rm.reconnect = AsyncMock(return_value=False) - mock_rm.radio_operation = _noop_radio_operation + mock_rm.radio_operation = _noop_radio_operation() with patch("app.routers.radio.radio_manager", mock_rm): with pytest.raises(HTTPException) as exc: diff --git a/tests/test_radio_sync.py b/tests/test_radio_sync.py index 0ea602f..1752c5b 100644 --- a/tests/test_radio_sync.py +++ b/tests/test_radio_sync.py @@ -11,6 +11,7 @@ from meshcore import EventType from app.database import Database from app.models import Favorite +from app.radio import radio_manager from app.radio_sync import ( is_polling_paused, pause_polling, @@ -45,14 +46,19 @@ async def test_db(): @pytest.fixture(autouse=True) def reset_sync_state(): - """Reset polling pause state and sync timestamp before and after each test.""" + """Reset polling pause state, sync timestamp, and radio_manager before/after each test.""" import app.radio_sync as radio_sync + prev_mc = radio_manager._meshcore + prev_lock = radio_manager._operation_lock + radio_sync._polling_pause_count = 0 radio_sync._last_contact_sync = 0.0 yield radio_sync._polling_pause_count = 0 radio_sync._last_contact_sync = 0.0 + radio_manager._meshcore = prev_mc + radio_manager._operation_lock = prev_lock KEY_A = "aa" * 32 @@ -236,11 +242,8 @@ class TestSyncRecentContactsToRadio: mock_result.type = EventType.OK mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - with patch("app.radio_sync.radio_manager") as mock_rm: - mock_rm.is_connected = True - mock_rm.meshcore = mock_mc - - result = await sync_recent_contacts_to_radio() + radio_manager._meshcore = mock_mc + result = await sync_recent_contacts_to_radio() assert result["loaded"] == 2 # Verify contacts are now marked as on_radio in DB @@ -268,11 +271,8 @@ class TestSyncRecentContactsToRadio: mock_result.type = EventType.OK mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - with patch("app.radio_sync.radio_manager") as mock_rm: - mock_rm.is_connected = True - mock_rm.meshcore = mock_mc - - result = await sync_recent_contacts_to_radio() + radio_manager._meshcore = mock_mc + result = await sync_recent_contacts_to_radio() assert result["loaded"] == 2 # KEY_A (favorite) should be loaded first, then KEY_B (most recent) @@ -298,11 +298,8 @@ class TestSyncRecentContactsToRadio: mock_result.type = EventType.OK mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - with patch("app.radio_sync.radio_manager") as mock_rm: - mock_rm.is_connected = True - mock_rm.meshcore = mock_mc - - result = await sync_recent_contacts_to_radio() + radio_manager._meshcore = mock_mc + result = await sync_recent_contacts_to_radio() assert result["loaded"] == 2 loaded_keys = [ @@ -319,11 +316,8 @@ class TestSyncRecentContactsToRadio: mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # Found mock_mc.commands.add_contact = AsyncMock() - with patch("app.radio_sync.radio_manager") as mock_rm: - mock_rm.is_connected = True - mock_rm.meshcore = mock_mc - - result = await sync_recent_contacts_to_radio() + radio_manager._meshcore = mock_mc + result = await sync_recent_contacts_to_radio() assert result["loaded"] == 0 assert result["already_on_radio"] == 1 @@ -335,34 +329,30 @@ class TestSyncRecentContactsToRadio: mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=None) - with patch("app.radio_sync.radio_manager") as mock_rm: - mock_rm.is_connected = True - mock_rm.meshcore = mock_mc + radio_manager._meshcore = mock_mc - # First call succeeds - result1 = await sync_recent_contacts_to_radio() - assert "throttled" not in result1 + # First call succeeds + result1 = await sync_recent_contacts_to_radio() + assert "throttled" not in result1 - # Second call is throttled - result2 = await sync_recent_contacts_to_radio() - assert result2["throttled"] is True - assert result2["loaded"] == 0 + # Second call is throttled + result2 = await sync_recent_contacts_to_radio() + assert result2["throttled"] is True + assert result2["loaded"] == 0 @pytest.mark.asyncio async def test_force_bypasses_throttle(self, test_db): """force=True bypasses the throttle window.""" mock_mc = MagicMock() - with patch("app.radio_sync.radio_manager") as mock_rm: - mock_rm.is_connected = True - mock_rm.meshcore = mock_mc + radio_manager._meshcore = mock_mc - # First call - await sync_recent_contacts_to_radio() + # First call + await sync_recent_contacts_to_radio() - # Forced second call is not throttled - result = await sync_recent_contacts_to_radio(force=True) - assert "throttled" not in result + # Forced second call is not throttled + result = await sync_recent_contacts_to_radio(force=True) + assert "throttled" not in result @pytest.mark.asyncio async def test_not_connected_returns_error(self): @@ -384,11 +374,8 @@ class TestSyncRecentContactsToRadio: mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # Found - with patch("app.radio_sync.radio_manager") as mock_rm: - mock_rm.is_connected = True - mock_rm.meshcore = mock_mc - - result = await sync_recent_contacts_to_radio() + radio_manager._meshcore = mock_mc + result = await sync_recent_contacts_to_radio() assert result["already_on_radio"] == 1 # Should update the flag since contact.on_radio was False @@ -407,15 +394,37 @@ class TestSyncRecentContactsToRadio: mock_result.payload = {"error": "Radio full"} mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - with patch("app.radio_sync.radio_manager") as mock_rm: - mock_rm.is_connected = True - mock_rm.meshcore = mock_mc - - result = await sync_recent_contacts_to_radio() + radio_manager._meshcore = mock_mc + result = await sync_recent_contacts_to_radio() assert result["loaded"] == 0 assert result["failed"] == 1 + @pytest.mark.asyncio + async def test_uses_post_lock_meshcore_after_swap(self, test_db): + """If _meshcore is swapped between pre-check and lock acquisition, + the function uses the new (post-lock) instance, not the stale one.""" + await _insert_contact(KEY_A, "Alice", last_contacted=2000) + + old_mc = MagicMock(name="old_mc") + new_mc = MagicMock(name="new_mc") + new_mc.get_contact_by_key_prefix = MagicMock(return_value=None) + mock_result = MagicMock() + mock_result.type = EventType.OK + new_mc.commands.add_contact = AsyncMock(return_value=mock_result) + + # Pre-check sees old_mc (truthy, passes is_connected guard) + radio_manager._meshcore = old_mc + # Simulate reconnect swapping _meshcore before lock acquisition + radio_manager._meshcore = new_mc + + result = await sync_recent_contacts_to_radio() + + assert result["loaded"] == 1 + # new_mc was used, not old_mc + new_mc.commands.add_contact.assert_called_once() + old_mc.commands.add_contact.assert_not_called() + class TestSyncAndOffloadContacts: """Test sync_and_offload_contacts: pull contacts from radio, save to DB, remove from radio.""" diff --git a/tests/test_repeater_routes.py b/tests/test_repeater_routes.py index 5c83cd3..bbaa3c7 100644 --- a/tests/test_repeater_routes.py +++ b/tests/test_repeater_routes.py @@ -8,6 +8,7 @@ from meshcore import EventType from app.database import Database from app.models import CommandRequest, TelemetryRequest +from app.radio import radio_manager from app.repository import ContactRepository from app.routers.contacts import ( _fetch_repeater_response, @@ -23,6 +24,16 @@ KEY_A = "aa" * 32 _MONOTONIC = "app.routers.contacts._monotonic" +@pytest.fixture(autouse=True) +def _reset_radio_state(): + """Save/restore radio_manager state so tests don't leak.""" + prev = radio_manager._meshcore + prev_lock = radio_manager._operation_lock + yield + radio_manager._meshcore = prev + radio_manager._operation_lock = prev_lock + + @pytest.fixture async def test_db(): """Create an in-memory test database with schema + migrations.""" @@ -263,7 +274,10 @@ class TestTelemetryRoute: @pytest.mark.asyncio async def test_returns_404_when_contact_missing(self, test_db): mc = _mock_mc() - with patch("app.routers.contacts.require_connected", return_value=mc): + with ( + patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): with pytest.raises(HTTPException) as exc: await request_telemetry(KEY_A, TelemetryRequest(password="pw")) @@ -274,7 +288,10 @@ class TestTelemetryRoute: mc = _mock_mc() await _insert_contact(KEY_A, name="Client", contact_type=1) - with patch("app.routers.contacts.require_connected", return_value=mc): + with ( + patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): with pytest.raises(HTTPException) as exc: await request_telemetry(KEY_A, TelemetryRequest(password="pw")) @@ -289,6 +306,7 @@ class TestTelemetryRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch( "app.routers.contacts.prepare_repeater_connection", new_callable=AsyncMock, @@ -329,6 +347,7 @@ class TestTelemetryRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch( "app.routers.contacts.prepare_repeater_connection", new_callable=AsyncMock, @@ -406,6 +425,7 @@ class TestTelemetryRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch( "app.routers.contacts.prepare_repeater_connection", new_callable=AsyncMock, @@ -462,6 +482,7 @@ class TestTelemetryRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch( "app.routers.contacts.prepare_repeater_connection", new_callable=AsyncMock, @@ -485,7 +506,10 @@ class TestRepeaterCommandRoute: return_value=_radio_result(EventType.ERROR, {"err": "bad"}) ) - with patch("app.routers.contacts.require_connected", return_value=mc): + with ( + patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): with pytest.raises(HTTPException) as exc: await send_repeater_command(KEY_A, CommandRequest(command="ver")) @@ -502,6 +526,7 @@ class TestRepeaterCommandRoute: # Expire the deadline after a couple of ticks with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=[0.0, 5.0, 25.0]), patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), ): @@ -530,6 +555,7 @@ class TestRepeaterCommandRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): response = await send_repeater_command(KEY_A, CommandRequest(command="ver")) @@ -557,6 +583,7 @@ class TestRepeaterCommandRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): response = await send_repeater_command(KEY_A, CommandRequest(command="ver")) @@ -584,6 +611,7 @@ class TestRepeaterCommandRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): response = await send_repeater_command(KEY_A, CommandRequest(command="ver")) @@ -609,6 +637,7 @@ class TestRepeaterCommandRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): response = await send_repeater_command(KEY_A, CommandRequest(command="ver")) @@ -631,6 +660,7 @@ class TestRepeaterCommandRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), ): @@ -651,6 +681,7 @@ class TestTraceRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.routers.contacts.random.randint", return_value=1234), ): with pytest.raises(HTTPException) as exc: @@ -667,6 +698,7 @@ class TestTraceRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.routers.contacts.random.randint", return_value=1234), ): with pytest.raises(HTTPException) as exc: @@ -685,6 +717,7 @@ class TestTraceRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.routers.contacts.random.randint", return_value=1234), ): response = await request_trace(KEY_A) diff --git a/tests/test_send_messages.py b/tests/test_send_messages.py index 44ed3b1..1d55447 100644 --- a/tests/test_send_messages.py +++ b/tests/test_send_messages.py @@ -13,6 +13,7 @@ from app.models import ( SendChannelMessageRequest, SendDirectMessageRequest, ) +from app.radio import radio_manager from app.repository import ( ChannelRepository, ContactRepository, @@ -25,6 +26,16 @@ from app.routers.messages import ( ) +@pytest.fixture(autouse=True) +def _reset_radio_state(): + """Save/restore radio_manager state so tests don't leak.""" + prev = radio_manager._meshcore + prev_lock = radio_manager._operation_lock + yield + radio_manager._meshcore = prev + radio_manager._operation_lock = prev_lock + + @pytest.fixture async def test_db(): """Create an in-memory test database with schema + migrations.""" @@ -96,6 +107,7 @@ class TestOutgoingDMBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot, ): request = SendDirectMessageRequest(destination=pub_key, text="!lasttime Alice") @@ -127,6 +139,7 @@ class TestOutgoingDMBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.bot.run_bot_for_message", new=slow_bot), ): request = SendDirectMessageRequest(destination=pub_key, text="Hello") @@ -144,6 +157,7 @@ class TestOutgoingDMBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot, ): request = SendDirectMessageRequest(destination=pub_key, text="test") @@ -184,6 +198,7 @@ class TestOutgoingChannelBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot, ): @@ -210,6 +225,7 @@ class TestOutgoingChannelBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot, ): @@ -234,6 +250,7 @@ class TestOutgoingChannelBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=slow_bot), ): @@ -250,6 +267,7 @@ class TestOutgoingChannelBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()), ): @@ -282,7 +300,10 @@ class TestResendChannelMessage: ) assert msg_id is not None - with patch("app.routers.messages.require_connected", return_value=mc): + with ( + patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): result = await resend_channel_message(msg_id, new_timestamp=False) assert result["status"] == "ok" @@ -341,6 +362,7 @@ class TestResendChannelMessage: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.routers.messages.broadcast_event"), patch("app.routers.messages.time") as mock_time, ): @@ -436,7 +458,10 @@ class TestResendChannelMessage: ) assert msg_id is not None - with patch("app.routers.messages.require_connected", return_value=mc): + with ( + patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): await resend_channel_message(msg_id, new_timestamp=False) call_kwargs = mc.commands.send_chan_msg.await_args.kwargs @@ -462,6 +487,7 @@ class TestResendChannelMessage: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.routers.messages.broadcast_event"), ): result = await resend_channel_message(msg_id, new_timestamp=True) @@ -490,6 +516,7 @@ class TestResendChannelMessage: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.routers.messages.broadcast_event"), ): result = await resend_channel_message(msg_id, new_timestamp=True) @@ -524,6 +551,7 @@ class TestResendChannelMessage: with ( patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), patch("app.routers.messages.broadcast_event") as mock_broadcast, ): result = await resend_channel_message(msg_id, new_timestamp=True)