From e8ddba0131a5275d6fd0267eab62bfa853154abf Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Mon, 16 Feb 2026 19:10:20 -0800 Subject: [PATCH] Add radio lock acquire around missing spots, and validate --- app/routers/contacts.py | 50 +++++++++++++++++++---------------- app/routers/radio.py | 12 ++++++--- tests/test_contacts_router.py | 8 ++++++ tests/test_radio_operation.py | 34 ++++++++++++++++++++++++ tests/test_radio_router.py | 41 ++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 26 deletions(-) diff --git a/app/routers/contacts.py b/app/routers/contacts.py index 3a0441a..c6df141 100644 --- a/app/routers/contacts.py +++ b/app/routers/contacts.py @@ -183,7 +183,8 @@ async def sync_contacts_from_radio() -> dict: logger.info("Syncing contacts from radio") - result = await mc.commands.get_contacts() + async with radio_manager.radio_operation("sync_contacts_from_radio", meshcore=mc): + result = await mc.commands.get_contacts() if result.type == EventType.ERROR: raise HTTPException(status_code=500, detail=f"Failed to get contacts: {result.payload}") @@ -208,19 +209,20 @@ async def remove_contact_from_radio(public_key: str) -> dict: contact = await _resolve_contact_or_404(public_key) - # Get the contact from radio - radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) - if not radio_contact: - # Already not on radio - await ContactRepository.set_on_radio(contact.public_key, False) - return {"status": "ok", "message": "Contact was not on radio"} + async with radio_manager.radio_operation("remove_contact_from_radio", meshcore=mc): + # Get the contact from radio + radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) + if not radio_contact: + # Already not on radio + await ContactRepository.set_on_radio(contact.public_key, False) + return {"status": "ok", "message": "Contact was not on radio"} - logger.info("Removing contact %s from radio", contact.public_key[:12]) + logger.info("Removing contact %s from radio", contact.public_key[:12]) - result = await mc.commands.remove_contact(radio_contact) + result = await mc.commands.remove_contact(radio_contact) - if result.type == EventType.ERROR: - raise HTTPException(status_code=500, detail=f"Failed to remove contact: {result.payload}") + if result.type == EventType.ERROR: + raise HTTPException(status_code=500, detail=f"Failed to remove contact: {result.payload}") await ContactRepository.set_on_radio(contact.public_key, False) return {"status": "ok"} @@ -233,17 +235,18 @@ async def add_contact_to_radio(public_key: str) -> dict: contact = await _resolve_contact_or_404(public_key, "Contact not found in database") - # Check if already on radio - radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) - if radio_contact: - return {"status": "ok", "message": "Contact already on radio"} + async with radio_manager.radio_operation("add_contact_to_radio", meshcore=mc): + # Check if already on radio + radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) + if radio_contact: + return {"status": "ok", "message": "Contact already on radio"} - logger.info("Adding contact %s to radio", contact.public_key[:12]) + logger.info("Adding contact %s to radio", contact.public_key[:12]) - result = await mc.commands.add_contact(contact.to_radio_dict()) + result = await mc.commands.add_contact(contact.to_radio_dict()) - if result.type == EventType.ERROR: - raise HTTPException(status_code=500, detail=f"Failed to add contact: {result.payload}") + if result.type == EventType.ERROR: + raise HTTPException(status_code=500, detail=f"Failed to add contact: {result.payload}") await ContactRepository.set_on_radio(contact.public_key, True) return {"status": "ok"} @@ -269,10 +272,11 @@ async def delete_contact(public_key: str) -> dict: # Remove from radio if connected and contact is on radio if radio_manager.is_connected and radio_manager.meshcore: mc = radio_manager.meshcore - radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) - if radio_contact: - logger.info("Removing contact %s from radio before deletion", contact.public_key[:12]) - await mc.commands.remove_contact(radio_contact) + async with radio_manager.radio_operation("delete_contact_from_radio", meshcore=mc): + radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) + if radio_contact: + logger.info("Removing contact %s from radio before deletion", contact.public_key[:12]) + await mc.commands.remove_contact(radio_contact) # Delete from database await ContactRepository.delete(contact.public_key) diff --git a/app/routers/radio.py b/app/routers/radio.py index a6d9dd8..de6fdb7 100644 --- a/app/routers/radio.py +++ b/app/routers/radio.py @@ -125,7 +125,8 @@ async def set_private_key(update: PrivateKeyUpdate) -> dict: raise HTTPException(status_code=400, detail="Invalid hex string for private key") from None logger.info("Importing private key") - result = await mc.commands.import_private_key(key_bytes) + async with radio_manager.radio_operation("import_private_key", meshcore=mc): + result = await mc.commands.import_private_key(key_bytes) if result.type == EventType.ERROR: raise HTTPException( @@ -149,7 +150,8 @@ async def send_advertisement() -> dict: require_connected() logger.info("Sending flood advertisement") - success = await do_send_advertisement(force=True) + async with radio_manager.radio_operation("manual_advertisement"): + success = await do_send_advertisement(force=True) if not success: raise HTTPException(status_code=500, detail="Failed to send advertisement") @@ -167,7 +169,11 @@ async def reboot_radio() -> dict: # If connected, send reboot command if radio_manager.is_connected and radio_manager.meshcore: logger.info("Rebooting radio") - await radio_manager.meshcore.commands.reboot() + async with radio_manager.radio_operation( + "reboot_radio", + meshcore=radio_manager.meshcore, + ): + await radio_manager.meshcore.commands.reboot() return { "status": "ok", "message": "Reboot command sent. Radio will reconnect automatically.", diff --git a/tests/test_contacts_router.py b/tests/test_contacts_router.py index a06f16e..62f1496 100644 --- a/tests/test_contacts_router.py +++ b/tests/test_contacts_router.py @@ -6,6 +6,7 @@ and add/remove from radio operations. Uses httpx.AsyncClient with real in-memory SQLite database. """ +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -21,6 +22,11 @@ KEY_B = "bb" * 32 # bbbb...bb KEY_C = "cc" * 32 # cccc...cc +@asynccontextmanager +async def _noop_radio_operation(*_args, **_kwargs): + yield + + @pytest.fixture async def test_db(): """Create an in-memory test database with schema + migrations.""" @@ -219,6 +225,7 @@ class TestDeleteContact: with patch("app.routers.contacts.radio_manager") as mock_rm: mock_rm.is_connected = False mock_rm.meshcore = None + mock_rm.radio_operation = _noop_radio_operation response = await client.delete(f"/api/contacts/{KEY_A}") @@ -248,6 +255,7 @@ class TestDeleteContact: with patch("app.routers.contacts.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc + mock_rm.radio_operation = _noop_radio_operation response = await client.delete(f"/api/contacts/{KEY_A}") diff --git a/tests/test_radio_operation.py b/tests/test_radio_operation.py index e61d5d2..7729267 100644 --- a/tests/test_radio_operation.py +++ b/tests/test_radio_operation.py @@ -48,6 +48,40 @@ class TestRadioOperationLock: release.set() await holder_task + @pytest.mark.asyncio + async def test_blocking_waits_and_acquires_after_release(self): + holder_entered = asyncio.Event() + holder_release = asyncio.Event() + contender_entered = asyncio.Event() + order: list[str] = [] + + async def holder(): + async with radio_manager.radio_operation("holder"): + order.append("holder_enter") + holder_entered.set() + await holder_release.wait() + order.append("holder_exit") + + async def contender(): + await holder_entered.wait() + async with radio_manager.radio_operation("contender"): + order.append("contender_enter") + contender_entered.set() + + holder_task = asyncio.create_task(holder()) + contender_task = asyncio.create_task(contender()) + + await holder_entered.wait() + await asyncio.sleep(0.02) + assert not contender_entered.is_set() + + holder_release.set() + await asyncio.wait_for(contender_entered.wait(), timeout=1.0) + + await holder_task + await contender_task + assert order == ["holder_enter", "holder_exit", "contender_enter"] + @pytest.mark.asyncio async def test_suspend_auto_fetch_stops_and_restarts(self): mc = MagicMock() diff --git a/tests/test_radio_router.py b/tests/test_radio_router.py index a9758e0..029f0af 100644 --- a/tests/test_radio_router.py +++ b/tests/test_radio_router.py @@ -1,5 +1,7 @@ """Tests for radio router endpoint logic.""" +import asyncio +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -18,6 +20,7 @@ from app.routers.radio import ( set_private_key, update_radio_config, ) +from app.radio import RadioManager def _radio_result(event_type=EventType.OK, payload=None): @@ -27,6 +30,11 @@ def _radio_result(event_type=EventType.OK, payload=None): return result +@asynccontextmanager +async def _noop_radio_operation(*_args, **_kwargs): + yield + + def _mock_meshcore_with_info(): mc = MagicMock() mc.self_info = { @@ -147,6 +155,34 @@ class TestAdvertise: assert exc.value.status_code == 500 + @pytest.mark.asyncio + async def test_concurrent_advertise_calls_are_serialized(self): + active = 0 + max_active = 0 + + async def fake_send(*, force: bool): + nonlocal active, max_active + assert force is True + active += 1 + max_active = max(max_active, active) + await asyncio.sleep(0.05) + active -= 1 + return True + + isolated_manager = RadioManager() + with ( + patch("app.routers.radio.require_connected"), + patch("app.routers.radio.radio_manager", isolated_manager), + patch( + "app.routers.radio.do_send_advertisement", + new_callable=AsyncMock, + side_effect=fake_send, + ), + ): + await asyncio.gather(send_advertisement(), send_advertisement()) + + assert max_active == 1 + class TestRebootAndReconnect: @pytest.mark.asyncio @@ -155,6 +191,7 @@ class TestRebootAndReconnect: mock_rm.is_connected = True mock_rm.meshcore = MagicMock() mock_rm.meshcore.commands.reboot = AsyncMock() + mock_rm.radio_operation = _noop_radio_operation with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() @@ -168,6 +205,7 @@ class TestRebootAndReconnect: mock_rm.is_connected = False mock_rm.meshcore = None mock_rm.is_reconnecting = True + mock_rm.radio_operation = _noop_radio_operation with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() @@ -183,6 +221,7 @@ class TestRebootAndReconnect: mock_rm.is_reconnecting = False mock_rm.reconnect = AsyncMock(return_value=True) mock_rm.post_connect_setup = AsyncMock() + mock_rm.radio_operation = _noop_radio_operation with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() @@ -196,6 +235,7 @@ class TestRebootAndReconnect: async def test_reconnect_returns_already_connected(self): mock_rm = MagicMock() mock_rm.is_connected = True + mock_rm.radio_operation = _noop_radio_operation with patch("app.routers.radio.radio_manager", mock_rm): result = await reconnect_radio() @@ -209,6 +249,7 @@ class TestRebootAndReconnect: mock_rm.is_connected = False mock_rm.is_reconnecting = False mock_rm.reconnect = AsyncMock(return_value=False) + mock_rm.radio_operation = _noop_radio_operation with patch("app.routers.radio.radio_manager", mock_rm): with pytest.raises(HTTPException) as exc: