From b8ea31666f1aa2898bf44432ca9b9d281a8fbd75 Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Tue, 10 Feb 2026 18:00:41 -0800 Subject: [PATCH] Rework radio lock handling --- app/dependencies.py | 2 + app/radio.py | 156 ++++++++++++++--- app/radio_sync.py | 193 ++++++++++++--------- app/routers/channels.py | 42 ++--- app/routers/contacts.py | 311 ++++++++++++++++------------------ app/routers/messages.py | 43 +++-- app/routers/radio.py | 70 ++++---- tests/test_radio_operation.py | 91 ++++++++++ tests/test_radio_router.py | 10 +- 9 files changed, 569 insertions(+), 349 deletions(-) create mode 100644 tests/test_radio_operation.py diff --git a/app/dependencies.py b/app/dependencies.py index c377d27..b6af89d 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -10,6 +10,8 @@ def require_connected(): Raises HTTPException 503 if radio is not connected. """ + if getattr(radio_manager, "is_setup_in_progress", False) is True: + raise HTTPException(status_code=503, detail="Radio is initializing") if not radio_manager.is_connected or radio_manager.meshcore is None: raise HTTPException(status_code=503, detail="Radio not connected") return radio_manager.meshcore diff --git a/app/radio.py b/app/radio.py index 67c0ac2..25b04ab 100644 --- a/app/radio.py +++ b/app/radio.py @@ -2,6 +2,7 @@ import asyncio import glob import logging import platform +from contextlib import asynccontextmanager from pathlib import Path from meshcore import MeshCore @@ -11,6 +12,20 @@ from app.config import settings logger = logging.getLogger(__name__) +class RadioOperationError(RuntimeError): + """Base class for shared radio operation lock errors.""" + + +class RadioOperationBusyError(RadioOperationError): + """Raised when a non-blocking radio operation cannot acquire the lock.""" + + +@asynccontextmanager +async def _noop_context(): + """No-op async context manager for optional nesting.""" + yield + + def detect_serial_devices() -> list[str]: """Detect available serial devices based on platform.""" devices: list[str] = [] @@ -106,6 +121,83 @@ class RadioManager: self._reconnect_task: asyncio.Task | None = None self._last_connected: bool = False self._reconnect_lock: asyncio.Lock | None = None + self._operation_lock: asyncio.Lock | None = None + self._setup_lock: asyncio.Lock | None = None + self._setup_in_progress: bool = False + + async def _acquire_operation_lock( + self, + name: str, + *, + blocking: bool, + ) -> None: + """Acquire the shared radio operation lock.""" + + if self._operation_lock is None: + self._operation_lock = asyncio.Lock() + + if not blocking: + if self._operation_lock.locked(): + raise RadioOperationBusyError(f"Radio is busy (operation: {name})") + await self._operation_lock.acquire() + else: + await self._operation_lock.acquire() + + logger.debug("Acquired radio operation lock (%s)", name) + + def _release_operation_lock(self, name: str) -> None: + """Release the shared radio operation lock.""" + if self._operation_lock and self._operation_lock.locked(): + self._operation_lock.release() + logger.debug("Released radio operation lock (%s)", name) + else: + logger.error("Attempted to release unlocked radio operation lock (%s)", name) + + @asynccontextmanager + async def radio_operation( + self, + name: str, + *, + pause_polling: bool = False, + suspend_auto_fetch: bool = False, + blocking: bool = True, + meshcore: MeshCore | None = None, + ): + """Acquire shared radio lock and optionally pause polling / auto-fetch. + + Args: + name: Human-readable operation name for logs/errors. + pause_polling: Pause fallback message polling while held. + suspend_auto_fetch: Stop MeshCore auto message fetching while held. + blocking: If False, fail immediately when lock is held. + meshcore: Optional explicit MeshCore instance for auto-fetch control. + """ + await self._acquire_operation_lock(name, blocking=blocking) + + poll_context = _noop_context() + if pause_polling: + from app.radio_sync import pause_polling as pause_polling_context + + poll_context = pause_polling_context() + + mc = meshcore or self._meshcore + auto_fetch_paused = False + + try: + async with poll_context: + if suspend_auto_fetch and mc is not None: + await mc.stop_auto_message_fetching() + auto_fetch_paused = True + yield + finally: + try: + if auto_fetch_paused and mc is not None: + try: + await mc.start_auto_message_fetching() + except Exception as e: + logger.warning("Failed to restart auto message fetching (%s): %s", name, e) + finally: + self._release_operation_lock(name) async def post_connect_setup(self) -> None: """Full post-connection setup: handlers, key export, sync, advertisements, polling. @@ -128,39 +220,49 @@ class RadioManager: if not self._meshcore: return - register_event_handlers(self._meshcore) - await export_and_store_private_key(self._meshcore) + if self._setup_lock is None: + self._setup_lock = asyncio.Lock() - # Sync radio clock with system time - await sync_radio_time() + async with self._setup_lock: + if not self._meshcore: + return + self._setup_in_progress = True + try: + register_event_handlers(self._meshcore) + await export_and_store_private_key(self._meshcore) - # Sync contacts/channels from radio to DB and clear radio - logger.info("Syncing and offloading radio data...") - result = await sync_and_offload_all() - logger.info("Sync complete: %s", result) + # Sync radio clock with system time + await sync_radio_time() - # Start periodic sync (idempotent) - start_periodic_sync() + # Sync contacts/channels from radio to DB and clear radio + logger.info("Syncing and offloading radio data...") + result = await sync_and_offload_all() + logger.info("Sync complete: %s", result) - # Send advertisement to announce our presence (if enabled and not throttled) - if await send_advertisement(): - logger.info("Advertisement sent") - else: - logger.debug("Advertisement skipped (disabled or throttled)") + # Start periodic sync (idempotent) + start_periodic_sync() - # Start periodic advertisement (idempotent) - start_periodic_advert() + # Send advertisement to announce our presence (if enabled and not throttled) + if await send_advertisement(): + logger.info("Advertisement sent") + else: + logger.debug("Advertisement skipped (disabled or throttled)") - await self._meshcore.start_auto_message_fetching() - logger.info("Auto message fetching started") + # Start periodic advertisement (idempotent) + start_periodic_advert() - # Drain any messages that were queued before we connected - drained = await drain_pending_messages() - if drained > 0: - logger.info("Drained %d pending message(s)", drained) + await self._meshcore.start_auto_message_fetching() + logger.info("Auto message fetching started") - # Start periodic message polling as fallback (idempotent) - start_message_polling() + # Drain any messages that were queued before we connected + drained = await drain_pending_messages() + if drained > 0: + logger.info("Drained %d pending message(s)", drained) + + # Start periodic message polling as fallback (idempotent) + start_message_polling() + finally: + self._setup_in_progress = False logger.info("Post-connect setup complete") @@ -180,6 +282,10 @@ class RadioManager: def is_reconnecting(self) -> bool: return self._reconnect_lock is not None and self._reconnect_lock.locked() + @property + def is_setup_in_progress(self) -> bool: + return self._setup_in_progress + async def connect(self) -> None: """Connect to the radio using the configured transport.""" if self._meshcore is not None: diff --git a/app/radio_sync.py b/app/radio_sync.py index 2ee10ad..91445ae 100644 --- a/app/radio_sync.py +++ b/app/radio_sync.py @@ -17,7 +17,7 @@ from contextlib import asynccontextmanager from meshcore import EventType from app.models import Contact -from app.radio import radio_manager +from app.radio import RadioOperationBusyError, radio_manager from app.repository import ( AppSettingsRepository, ChannelRepository, @@ -318,7 +318,16 @@ async def _message_poll_loop(): await asyncio.sleep(MESSAGE_POLL_INTERVAL) if radio_manager.is_connected and not is_polling_paused(): - await poll_for_messages() + mc = radio_manager.meshcore + if mc is not None: + try: + async with radio_manager.radio_operation( + "message_poll_loop", + blocking=False, + ): + await poll_for_messages() + except RadioOperationBusyError: + logger.debug("Skipping message poll: radio busy") except asyncio.CancelledError: break @@ -414,7 +423,16 @@ async def _periodic_advert_loop(): # Try to send - send_advertisement() handles all checks # (disabled, throttled, not connected) if radio_manager.is_connected: - await send_advertisement() + mc = radio_manager.meshcore + if mc is not None: + try: + async with radio_manager.radio_operation( + "periodic_advertisement", + blocking=False, + ): + await send_advertisement() + except RadioOperationBusyError: + logger.debug("Skipping periodic advertisement: radio busy") # Sleep before next check await asyncio.sleep(ADVERT_CHECK_INTERVAL) @@ -477,9 +495,20 @@ async def _periodic_sync_loop(): while True: try: await asyncio.sleep(SYNC_INTERVAL) - logger.debug("Running periodic radio sync") - await sync_and_offload_all() - await sync_radio_time() + mc = radio_manager.meshcore + if mc is None: + continue + + try: + async with radio_manager.radio_operation( + "periodic_sync", + blocking=False, + ): + logger.debug("Running periodic radio sync") + await sync_and_offload_all() + await sync_radio_time() + except RadioOperationBusyError: + logger.debug("Skipping periodic sync: radio busy") except asyncio.CancelledError: logger.info("Periodic sync task cancelled") break @@ -536,93 +565,101 @@ async def sync_recent_contacts_to_radio(force: bool = False) -> dict: return {"loaded": 0, "error": "Radio not connected"} mc = radio_manager.meshcore - _last_contact_sync = now try: - # Build prioritized contact list: - # 1) favorite contacts, in favorite order - # 2) most recent non-repeater contacts (excluding already-selected favorites) - app_settings = await AppSettingsRepository.get() - max_contacts = app_settings.max_radio_contacts - selected_contacts: list[Contact] = [] - selected_keys: set[str] = set() + async with radio_manager.radio_operation( + "sync_recent_contacts_to_radio", + blocking=False, + ): + _last_contact_sync = now - favorite_contacts_loaded = 0 - for favorite in app_settings.favorites: - if favorite.type != "contact": - continue - contact = await ContactRepository.get_by_key_or_prefix(favorite.id) - if not contact: - continue - key = contact.public_key.lower() - if key in selected_keys: - continue - selected_keys.add(key) - selected_contacts.append(contact) - favorite_contacts_loaded += 1 - if len(selected_contacts) >= max_contacts: - break + # Build prioritized contact list: + # 1) favorite contacts, in favorite order + # 2) most recent non-repeater contacts (excluding already-selected favorites) + app_settings = await AppSettingsRepository.get() + max_contacts = app_settings.max_radio_contacts + selected_contacts: list[Contact] = [] + selected_keys: set[str] = set() - if len(selected_contacts) < max_contacts: - recent_contacts = await ContactRepository.get_recent_non_repeaters(limit=max_contacts) - for contact in recent_contacts: + favorite_contacts_loaded = 0 + for favorite in app_settings.favorites: + if favorite.type != "contact": + continue + contact = await ContactRepository.get_by_key_or_prefix(favorite.id) + if not contact: + continue key = contact.public_key.lower() if key in selected_keys: continue selected_keys.add(key) selected_contacts.append(contact) + favorite_contacts_loaded += 1 if len(selected_contacts) >= max_contacts: break - logger.debug( - "Selected %d contacts to sync (%d favorite contacts first, limit=%d)", - len(selected_contacts), - favorite_contacts_loaded, - max_contacts, - ) + if len(selected_contacts) < max_contacts: + recent_contacts = await ContactRepository.get_recent_non_repeaters(limit=max_contacts) + for contact in recent_contacts: + key = contact.public_key.lower() + if key in selected_keys: + continue + selected_keys.add(key) + selected_contacts.append(contact) + if len(selected_contacts) >= max_contacts: + break - loaded = 0 - already_on_radio = 0 - failed = 0 - - for contact in selected_contacts: - # Check if already on radio - radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) - if radio_contact: - already_on_radio += 1 - # Update DB if not marked as on_radio - if not contact.on_radio: - await ContactRepository.set_on_radio(contact.public_key, True) - continue - - try: - result = await mc.commands.add_contact(contact.to_radio_dict()) - if result.type == EventType.OK: - loaded += 1 - await ContactRepository.set_on_radio(contact.public_key, True) - logger.debug("Loaded contact %s to radio", contact.public_key[:12]) - else: - failed += 1 - logger.warning( - "Failed to load contact %s: %s", contact.public_key[:12], result.payload - ) - except Exception as e: - failed += 1 - logger.warning("Error loading contact %s: %s", contact.public_key[:12], e) - - if loaded > 0 or failed > 0: - logger.info( - "Contact sync: loaded %d, already on radio %d, failed %d", - loaded, - already_on_radio, - failed, + logger.debug( + "Selected %d contacts to sync (%d favorite contacts first, limit=%d)", + len(selected_contacts), + favorite_contacts_loaded, + max_contacts, ) - return { - "loaded": loaded, - "already_on_radio": already_on_radio, - "failed": failed, - } + loaded = 0 + already_on_radio = 0 + failed = 0 + + for contact in selected_contacts: + # Check if already on radio + radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) + if radio_contact: + already_on_radio += 1 + # Update DB if not marked as on_radio + if not contact.on_radio: + await ContactRepository.set_on_radio(contact.public_key, True) + continue + + try: + result = await mc.commands.add_contact(contact.to_radio_dict()) + if result.type == EventType.OK: + loaded += 1 + await ContactRepository.set_on_radio(contact.public_key, True) + logger.debug("Loaded contact %s to radio", contact.public_key[:12]) + else: + failed += 1 + logger.warning( + "Failed to load contact %s: %s", contact.public_key[:12], result.payload + ) + except Exception as e: + failed += 1 + logger.warning("Error loading contact %s: %s", contact.public_key[:12], e) + + if loaded > 0 or failed > 0: + logger.info( + "Contact sync: loaded %d, already on radio %d, failed %d", + loaded, + already_on_radio, + failed, + ) + + return { + "loaded": loaded, + "already_on_radio": already_on_radio, + "failed": failed, + } + except RadioOperationBusyError: + logger.debug("Skipping contact sync to radio: radio busy") + return {"loaded": 0, "busy": True} except Exception as e: logger.error("Error syncing contacts to radio: %s", e) diff --git a/app/routers/channels.py b/app/routers/channels.py index 177d018..b2f4c2a 100644 --- a/app/routers/channels.py +++ b/app/routers/channels.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field from app.dependencies import require_connected from app.models import Channel +from app.radio import radio_manager from app.radio_sync import ensure_default_channels from app.repository import ChannelRepository @@ -89,30 +90,31 @@ async def sync_channels_from_radio(max_channels: int = Query(default=40, ge=1, l logger.info("Syncing channels from radio (checking %d slots)", max_channels) count = 0 - for idx in range(max_channels): - result = await mc.commands.get_channel(idx) + async with radio_manager.radio_operation("sync_channels_from_radio"): + for idx in range(max_channels): + result = await mc.commands.get_channel(idx) - if result.type == EventType.CHANNEL_INFO: - payload = result.payload - name = payload.get("channel_name", "") - secret = payload.get("channel_secret", b"") + if result.type == EventType.CHANNEL_INFO: + payload = result.payload + name = payload.get("channel_name", "") + secret = payload.get("channel_secret", b"") - # Skip empty channels - if not name or name == "\x00" * len(name): - continue + # Skip empty channels + if not name or name == "\x00" * len(name): + continue - is_hashtag = name.startswith("#") - key_bytes = secret if isinstance(secret, bytes) else bytes(secret) - key_hex = key_bytes.hex().upper() + is_hashtag = name.startswith("#") + key_bytes = secret if isinstance(secret, bytes) else bytes(secret) + key_hex = key_bytes.hex().upper() - await ChannelRepository.upsert( - key=key_hex, - name=name, - is_hashtag=is_hashtag, - on_radio=True, - ) - count += 1 - logger.debug("Synced channel %s: %s", key_hex, name) + await ChannelRepository.upsert( + key=key_hex, + name=name, + is_hashtag=is_hashtag, + on_radio=True, + ) + count += 1 + logger.debug("Synced channel %s: %s", key_hex, name) logger.info("Synced %d channels from radio", count) return {"synced": count} diff --git a/app/routers/contacts.py b/app/routers/contacts.py index 6a02c2e..16b1ba3 100644 --- a/app/routers/contacts.py +++ b/app/routers/contacts.py @@ -20,7 +20,6 @@ from app.models import ( ) from app.packet_processor import start_historical_dm_decryption from app.radio import radio_manager -from app.radio_sync import pause_polling from app.repository import ContactRepository, MessageRepository logger = logging.getLogger(__name__) @@ -294,117 +293,115 @@ async def request_telemetry(public_key: str, request: TelemetryRequest) -> Telem detail=f"Contact is not a repeater (type={contact.type}, expected {CONTACT_TYPE_REPEATER})", ) - # Prepare connection (add/remove dance + login) - await prepare_repeater_connection(mc, contact, request.password) + async with radio_manager.radio_operation( + "request_telemetry", + meshcore=mc, + pause_polling=True, + suspend_auto_fetch=True, + ): + # Prepare connection (add/remove dance + login) + await prepare_repeater_connection(mc, contact, request.password) - # Request status with retries - logger.info("Requesting status from repeater %s", contact.public_key[:12]) - status = None - for attempt in range(1, 4): - logger.debug("Status request attempt %d/3", attempt) - status = await mc.commands.req_status_sync(contact.public_key, timeout=10, min_timeout=5) - if status: - break - logger.debug("Status request timeout, retrying...") + # Request status with retries + logger.info("Requesting status from repeater %s", contact.public_key[:12]) + status = None + for attempt in range(1, 4): + logger.debug("Status request attempt %d/3", attempt) + status = await mc.commands.req_status_sync(contact.public_key, timeout=10, min_timeout=5) + if status: + break + logger.debug("Status request timeout, retrying...") - if not status: - raise HTTPException(status_code=504, detail="No response from repeater after 3 attempts") + if not status: + raise HTTPException(status_code=504, detail="No response from repeater after 3 attempts") - logger.info("Received telemetry from %s: %s", contact.public_key[:12], status) + logger.info("Received telemetry from %s: %s", contact.public_key[:12], status) - # Fetch neighbors (fetch_all_neighbours handles pagination) - logger.info("Fetching neighbors from repeater %s", contact.public_key[:12]) - neighbors_data = None - for attempt in range(1, 4): - logger.debug("Neighbors request attempt %d/3", attempt) - neighbors_data = await mc.commands.fetch_all_neighbours( - contact.public_key, timeout=10, min_timeout=5 - ) - if neighbors_data: - break - logger.debug("Neighbors request timeout, retrying...") - - # Process neighbors - resolve pubkey prefixes to contact names - neighbors: list[NeighborInfo] = [] - if neighbors_data and "neighbours" in neighbors_data: - logger.info("Received %d neighbors", len(neighbors_data["neighbours"])) - for n in neighbors_data["neighbours"]: - pubkey_prefix = n.get("pubkey", "") - # Try to resolve to a contact name from our database - resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix) - neighbors.append( - NeighborInfo( - pubkey_prefix=pubkey_prefix, - name=resolved_contact.name if resolved_contact else None, - snr=n.get("snr", 0.0), - last_heard_seconds=n.get("secs_ago", 0), - ) + # Fetch neighbors (fetch_all_neighbours handles pagination) + logger.info("Fetching neighbors from repeater %s", contact.public_key[:12]) + neighbors_data = None + for attempt in range(1, 4): + logger.debug("Neighbors request attempt %d/3", attempt) + neighbors_data = await mc.commands.fetch_all_neighbours( + contact.public_key, timeout=10, min_timeout=5 ) + if neighbors_data: + break + logger.debug("Neighbors request timeout, retrying...") - # Fetch ACL - logger.info("Fetching ACL from repeater %s", contact.public_key[:12]) - acl_data = None - for attempt in range(1, 4): - logger.debug("ACL request attempt %d/3", attempt) - acl_data = await mc.commands.req_acl_sync(contact.public_key, timeout=10, min_timeout=5) - if acl_data: - break - logger.debug("ACL request timeout, retrying...") - - # Process ACL - resolve pubkey prefixes to contact names - acl_entries: list[AclEntry] = [] - if acl_data and isinstance(acl_data, list): - logger.info("Received %d ACL entries", len(acl_data)) - for entry in acl_data: - pubkey_prefix = entry.get("key", "") - perm = entry.get("perm", 0) - # Try to resolve to a contact name from our database - resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix) - acl_entries.append( - AclEntry( - pubkey_prefix=pubkey_prefix, - name=resolved_contact.name if resolved_contact else None, - permission=perm, - permission_name=ACL_PERMISSION_NAMES.get(perm, f"Unknown({perm})"), + # Process neighbors - resolve pubkey prefixes to contact names + neighbors: list[NeighborInfo] = [] + if neighbors_data and "neighbours" in neighbors_data: + logger.info("Received %d neighbors", len(neighbors_data["neighbours"])) + for n in neighbors_data["neighbours"]: + pubkey_prefix = n.get("pubkey", "") + # Try to resolve to a contact name from our database + resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix) + neighbors.append( + NeighborInfo( + pubkey_prefix=pubkey_prefix, + name=resolved_contact.name if resolved_contact else None, + snr=n.get("snr", 0.0), + last_heard_seconds=n.get("secs_ago", 0), + ) ) - ) - # Fetch clock output (up to 2 attempts) - # Must pause polling and stop auto-fetch to prevent race condition where - # the CLI response is consumed before we can call get_msg() - logger.info("Fetching clock from repeater %s", contact.public_key[:12]) - clock_output: str | None = None + # Fetch ACL + logger.info("Fetching ACL from repeater %s", contact.public_key[:12]) + acl_data = None + for attempt in range(1, 4): + logger.debug("ACL request attempt %d/3", attempt) + acl_data = await mc.commands.req_acl_sync(contact.public_key, timeout=10, min_timeout=5) + if acl_data: + break + logger.debug("ACL request timeout, retrying...") - async with pause_polling(): - await mc.stop_auto_message_fetching() - try: - for attempt in range(1, 3): - logger.debug("Clock request attempt %d/2", attempt) - try: - send_result = await mc.commands.send_cmd(contact.public_key, "clock") - if send_result.type == EventType.ERROR: - logger.debug("Clock command send error: %s", send_result.payload) - continue + # Process ACL - resolve pubkey prefixes to contact names + acl_entries: list[AclEntry] = [] + if acl_data and isinstance(acl_data, list): + logger.info("Received %d ACL entries", len(acl_data)) + for entry in acl_data: + pubkey_prefix = entry.get("key", "") + perm = entry.get("perm", 0) + # Try to resolve to a contact name from our database + resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix) + acl_entries.append( + AclEntry( + pubkey_prefix=pubkey_prefix, + name=resolved_contact.name if resolved_contact else None, + permission=perm, + permission_name=ACL_PERMISSION_NAMES.get(perm, f"Unknown({perm})"), + ) + ) - # Wait for response - wait_result = await mc.wait_for_event(EventType.MESSAGES_WAITING, timeout=5.0) - if wait_result is None: - logger.debug("Clock request timeout, retrying...") - continue - - response_event = await mc.commands.get_msg() - if response_event.type == EventType.ERROR: - logger.debug("Clock get_msg error: %s", response_event.payload) - continue - - clock_output = response_event.payload.get("text", "") - logger.info("Received clock output: %s", clock_output) - break - except Exception as e: - logger.debug("Clock request exception: %s", e) + # Fetch clock output (up to 2 attempts) + logger.info("Fetching clock from repeater %s", contact.public_key[:12]) + clock_output: str | None = None + for attempt in range(1, 3): + logger.debug("Clock request attempt %d/2", attempt) + try: + send_result = await mc.commands.send_cmd(contact.public_key, "clock") + if send_result.type == EventType.ERROR: + logger.debug("Clock command send error: %s", send_result.payload) continue - finally: - await mc.start_auto_message_fetching() + + # Wait for response + wait_result = await mc.wait_for_event(EventType.MESSAGES_WAITING, timeout=5.0) + if wait_result is None: + logger.debug("Clock request timeout, retrying...") + continue + + response_event = await mc.commands.get_msg() + if response_event.type == EventType.ERROR: + logger.debug("Clock get_msg error: %s", response_event.payload) + continue + + clock_output = response_event.payload.get("text", "") + logger.info("Received clock output: %s", clock_output) + break + except Exception as e: + logger.debug("Clock request exception: %s", e) + continue if clock_output is None: clock_output = "Unable to fetch `clock` output (repeater did not respond)" @@ -469,71 +466,60 @@ async def send_repeater_command(public_key: str, request: CommandRequest) -> Com detail=f"Contact is not a repeater (type={contact.type}, expected {CONTACT_TYPE_REPEATER})", ) - # Pause message polling to prevent it from stealing our response - async with pause_polling(): - # Stop auto-fetch to prevent race condition where it consumes our CLI response - # before we can call get_msg(). This was causing every other command to fail - # with {'messages_available': False}. - await mc.stop_auto_message_fetching() + async with radio_manager.radio_operation( + "send_repeater_command", + meshcore=mc, + pause_polling=True, + suspend_auto_fetch=True, + ): + # Add contact to radio with path from DB + logger.info("Adding repeater %s to radio", contact.public_key[:12]) + await mc.commands.add_contact(contact.to_radio_dict()) + # Send the command + logger.info("Sending command to repeater %s: %s", contact.public_key[:12], request.command) + + send_result = await mc.commands.send_cmd(contact.public_key, request.command) + + if send_result.type == EventType.ERROR: + raise HTTPException(status_code=500, detail=f"Failed to send command: {send_result.payload}") + + # Wait for response (MESSAGES_WAITING event, then get_msg) try: - # Add contact to radio with path from DB - logger.info("Adding repeater %s to radio", contact.public_key[:12]) - await mc.commands.add_contact(contact.to_radio_dict()) + wait_result = await mc.wait_for_event(EventType.MESSAGES_WAITING, timeout=10.0) - # Send the command - logger.info( - "Sending command to repeater %s: %s", contact.public_key[:12], request.command - ) - - send_result = await mc.commands.send_cmd(contact.public_key, request.command) - - if send_result.type == EventType.ERROR: - raise HTTPException( - status_code=500, detail=f"Failed to send command: {send_result.payload}" + if wait_result is None: + # Timeout - no response received + logger.warning( + "No response from repeater %s for command: %s", + contact.public_key[:12], + request.command, ) - - # Wait for response (MESSAGES_WAITING event, then get_msg) - try: - wait_result = await mc.wait_for_event(EventType.MESSAGES_WAITING, timeout=10.0) - - if wait_result is None: - # Timeout - no response received - logger.warning( - "No response from repeater %s for command: %s", - contact.public_key[:12], - request.command, - ) - return CommandResponse( - command=request.command, - response="(no response - command may have been processed)", - ) - - response_event = await mc.commands.get_msg() - - if response_event.type == EventType.ERROR: - return CommandResponse( - command=request.command, response=f"(error: {response_event.payload})" - ) - - # Extract the response text and timestamp from the payload - response_text = response_event.payload.get("text", str(response_event.payload)) - sender_timestamp = response_event.payload.get("timestamp") - logger.info("Received response from %s: %s", contact.public_key[:12], response_text) - return CommandResponse( command=request.command, - response=response_text, - sender_timestamp=sender_timestamp, + response="(no response - command may have been processed)", ) - except Exception as e: - logger.error("Error waiting for response: %s", e) - return CommandResponse( - command=request.command, response=f"(error waiting for response: {e})" - ) - finally: - # Always restart auto-fetch, even if an error occurred - await mc.start_auto_message_fetching() + + response_event = await mc.commands.get_msg() + + if response_event.type == EventType.ERROR: + return CommandResponse(command=request.command, response=f"(error: {response_event.payload})") + + # Extract the response text and timestamp from the payload + response_text = response_event.payload.get("text", str(response_event.payload)) + sender_timestamp = response_event.payload.get("timestamp") + logger.info("Received response from %s: %s", contact.public_key[:12], response_text) + + return CommandResponse( + command=request.command, + response=response_text, + sender_timestamp=sender_timestamp, + ) + except Exception as e: + logger.error("Error waiting for response: %s", e) + return CommandResponse( + command=request.command, response=f"(error waiting for response: {e})" + ) @router.post("/{public_key}/trace", response_model=TraceResponse) @@ -554,10 +540,9 @@ async def request_trace(public_key: str) -> TraceResponse: # First 2 hex chars of pubkey = 1-byte hash used by the trace protocol contact_hash = contact.public_key[:2] - # Note: unlike command/telemetry endpoints, trace does NOT need - # stop/start_auto_message_fetching because the response arrives as a - # TRACE_DATA event through the reader loop, not via get_msg(). - async with pause_polling(): + # Trace does not need auto-fetch suspension: response arrives as TRACE_DATA + # from the reader loop, not via get_msg(). + async with radio_manager.radio_operation("request_trace", pause_polling=True): # Ensure contact is on radio so the trace can reach them await mc.commands.add_contact(contact.to_radio_dict()) diff --git a/app/routers/messages.py b/app/routers/messages.py index 3248abe..0bc9438 100644 --- a/app/routers/messages.py +++ b/app/routers/messages.py @@ -8,14 +8,12 @@ from meshcore import EventType from app.dependencies import require_connected from app.event_handlers import track_pending_ack from app.models import Message, SendChannelMessageRequest, SendDirectMessageRequest +from app.radio import radio_manager from app.repository import MessageRepository logger = logging.getLogger(__name__) router = APIRouter(prefix="/messages", tags=["messages"]) -# Serialize channel sends that reuse a temporary radio slot. -_channel_send_lock = asyncio.Lock() - @router.get("", response_model=list[Message]) async def list_messages( @@ -60,28 +58,29 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message: # so we can't rely on it to know if the firmware has the contact. # add_contact is idempotent - updates if exists, adds if not. contact_data = db_contact.to_radio_dict() - logger.debug("Ensuring contact %s is on radio before sending", db_contact.public_key[:12]) - add_result = await mc.commands.add_contact(contact_data) - if add_result.type == EventType.ERROR: - logger.warning("Failed to add contact to radio: %s", add_result.payload) - # Continue anyway - might still work if contact exists + async with radio_manager.radio_operation("send_direct_message"): + logger.debug("Ensuring contact %s is on radio before sending", db_contact.public_key[:12]) + add_result = await mc.commands.add_contact(contact_data) + if add_result.type == EventType.ERROR: + logger.warning("Failed to add contact to radio: %s", add_result.payload) + # Continue anyway - might still work if contact exists - # Get the contact from the library cache (may have updated info like path) - contact = mc.get_contact_by_key_prefix(db_contact.public_key[:12]) - if not contact: - contact = contact_data + # Get the contact from the library cache (may have updated info like path) + contact = mc.get_contact_by_key_prefix(db_contact.public_key[:12]) + if not contact: + contact = contact_data - logger.info("Sending direct message to %s", db_contact.public_key[:12]) + logger.info("Sending direct message to %s", db_contact.public_key[:12]) - # Capture timestamp BEFORE sending so we can pass the same value to both the radio - # and the database. This ensures consistency for deduplication. - now = int(time.time()) + # Capture timestamp BEFORE sending so we can pass the same value to both the radio + # and the database. This ensures consistency for deduplication. + now = int(time.time()) - result = await mc.commands.send_msg( - dst=contact, - msg=request.text, - timestamp=now, - ) + result = await mc.commands.send_msg( + dst=contact, + msg=request.text, + timestamp=now, + ) if result.type == EventType.ERROR: raise HTTPException(status_code=500, detail=f"Failed to send message: {result.payload}") @@ -179,7 +178,7 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message: expected_hash, ) - async with _channel_send_lock: + async with radio_manager.radio_operation("send_channel_message"): # Load the channel to a temporary radio slot before sending set_result = await mc.commands.set_channel( channel_idx=TEMP_RADIO_SLOT, diff --git a/app/routers/radio.py b/app/routers/radio.py index cbbcac1..a6d9dd8 100644 --- a/app/routers/radio.py +++ b/app/routers/radio.py @@ -5,6 +5,7 @@ from meshcore import EventType from pydantic import BaseModel, Field from app.dependencies import require_connected +from app.radio import radio_manager from app.radio_sync import send_advertisement as do_send_advertisement from app.radio_sync import sync_radio_time @@ -71,43 +72,44 @@ async def update_radio_config(update: RadioConfigUpdate) -> RadioConfigResponse: """Update radio configuration. Only provided fields will be updated.""" mc = require_connected() - if update.name is not None: - logger.info("Setting radio name to %s", update.name) - await mc.commands.set_name(update.name) + async with radio_manager.radio_operation("update_radio_config"): + if update.name is not None: + logger.info("Setting radio name to %s", update.name) + await mc.commands.set_name(update.name) - if update.lat is not None or update.lon is not None: - current_info = mc.self_info - lat = update.lat if update.lat is not None else current_info.get("adv_lat", 0.0) - lon = update.lon if update.lon is not None else current_info.get("adv_lon", 0.0) - logger.info("Setting radio coordinates to %f, %f", lat, lon) - await mc.commands.set_coords(lat=lat, lon=lon) + if update.lat is not None or update.lon is not None: + current_info = mc.self_info + lat = update.lat if update.lat is not None else current_info.get("adv_lat", 0.0) + lon = update.lon if update.lon is not None else current_info.get("adv_lon", 0.0) + logger.info("Setting radio coordinates to %f, %f", lat, lon) + await mc.commands.set_coords(lat=lat, lon=lon) - if update.tx_power is not None: - logger.info("Setting TX power to %d dBm", update.tx_power) - await mc.commands.set_tx_power(val=update.tx_power) + if update.tx_power is not None: + logger.info("Setting TX power to %d dBm", update.tx_power) + await mc.commands.set_tx_power(val=update.tx_power) - if update.radio is not None: - logger.info( - "Setting radio params: freq=%f MHz, bw=%f kHz, sf=%d, cr=%d", - update.radio.freq, - update.radio.bw, - update.radio.sf, - update.radio.cr, - ) - await mc.commands.set_radio( - freq=update.radio.freq, - bw=update.radio.bw, - sf=update.radio.sf, - cr=update.radio.cr, - ) + if update.radio is not None: + logger.info( + "Setting radio params: freq=%f MHz, bw=%f kHz, sf=%d, cr=%d", + update.radio.freq, + update.radio.bw, + update.radio.sf, + update.radio.cr, + ) + await mc.commands.set_radio( + freq=update.radio.freq, + bw=update.radio.bw, + sf=update.radio.sf, + cr=update.radio.cr, + ) - # Sync time with system clock - await sync_radio_time() + # Sync time with system clock + await sync_radio_time() - # Re-fetch self_info so the response reflects the changes we just made. - # Commands like set_name() write to flash but don't update the cached - # self_info — send_appstart() triggers a fresh SELF_INFO from the radio. - await mc.commands.send_appstart() + # Re-fetch self_info so the response reflects the changes we just made. + # Commands like set_name() write to flash but don't update the cached + # self_info — send_appstart() triggers a fresh SELF_INFO from the radio. + await mc.commands.send_appstart() return await get_radio_config() @@ -162,8 +164,6 @@ async def reboot_radio() -> dict: If connected: sends reboot command, connection will temporarily drop and auto-reconnect. If not connected: attempts to reconnect (same as /reconnect endpoint). """ - from app.radio import radio_manager - # If connected, send reboot command if radio_manager.is_connected and radio_manager.meshcore: logger.info("Rebooting radio") @@ -202,8 +202,6 @@ async def reconnect_radio() -> dict: if no specific port is configured. Useful when the radio has been disconnected or power-cycled. """ - from app.radio import radio_manager - if radio_manager.is_connected: return {"status": "ok", "message": "Already connected", "connected": True} diff --git a/tests/test_radio_operation.py b/tests/test_radio_operation.py new file mode 100644 index 0000000..e61d5d2 --- /dev/null +++ b/tests/test_radio_operation.py @@ -0,0 +1,91 @@ +"""Tests for shared radio operation locking behavior.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.radio import RadioOperationBusyError, radio_manager +from app.radio_sync import is_polling_paused + + +@pytest.fixture(autouse=True) +def reset_radio_operation_state(): + """Reset shared radio operation lock state before/after each test.""" + prev_meshcore = radio_manager._meshcore + radio_manager._operation_lock = None + radio_manager._meshcore = None + + import app.radio_sync as radio_sync + + radio_sync._polling_pause_count = 0 + yield + radio_manager._operation_lock = None + radio_manager._meshcore = prev_meshcore + radio_sync._polling_pause_count = 0 + + +class TestRadioOperationLock: + """Validate shared radio operation lock semantics.""" + + @pytest.mark.asyncio + async def test_non_blocking_fails_when_lock_held_by_other_task(self): + started = asyncio.Event() + release = asyncio.Event() + + async def holder(): + async with radio_manager.radio_operation("holder"): + started.set() + await release.wait() + + holder_task = asyncio.create_task(holder()) + await started.wait() + + with pytest.raises(RadioOperationBusyError): + async with radio_manager.radio_operation("contender", blocking=False): + pass + + release.set() + await holder_task + + @pytest.mark.asyncio + async def test_suspend_auto_fetch_stops_and_restarts(self): + mc = MagicMock() + mc.stop_auto_message_fetching = AsyncMock() + mc.start_auto_message_fetching = AsyncMock() + radio_manager._meshcore = mc + + async with radio_manager.radio_operation( + "auto_fetch_toggle", + suspend_auto_fetch=True, + ): + pass + + mc.stop_auto_message_fetching.assert_awaited_once() + mc.start_auto_message_fetching.assert_awaited_once() + + @pytest.mark.asyncio + async def test_lock_released_when_auto_fetch_restart_is_cancelled(self): + mc = MagicMock() + mc.stop_auto_message_fetching = AsyncMock() + mc.start_auto_message_fetching = AsyncMock(side_effect=asyncio.CancelledError()) + radio_manager._meshcore = mc + + with pytest.raises(asyncio.CancelledError): + async with radio_manager.radio_operation( + "cancelled_restart", + suspend_auto_fetch=True, + ): + pass + + async with radio_manager.radio_operation("after_cancel", blocking=False): + pass + + @pytest.mark.asyncio + async def test_pause_polling_toggles_state(self): + assert not is_polling_paused() + + async with radio_manager.radio_operation("pause_polling", pause_polling=True): + assert is_polling_paused() + + assert not is_polling_paused() diff --git a/tests/test_radio_router.py b/tests/test_radio_router.py index aad9f97..a9758e0 100644 --- a/tests/test_radio_router.py +++ b/tests/test_radio_router.py @@ -156,7 +156,7 @@ class TestRebootAndReconnect: mock_rm.meshcore = MagicMock() mock_rm.meshcore.commands.reboot = AsyncMock() - with patch("app.radio.radio_manager", mock_rm): + with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() assert result["status"] == "ok" @@ -169,7 +169,7 @@ class TestRebootAndReconnect: mock_rm.meshcore = None mock_rm.is_reconnecting = True - with patch("app.radio.radio_manager", mock_rm): + with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() assert result["status"] == "pending" @@ -184,7 +184,7 @@ class TestRebootAndReconnect: mock_rm.reconnect = AsyncMock(return_value=True) mock_rm.post_connect_setup = AsyncMock() - with patch("app.radio.radio_manager", mock_rm): + with patch("app.routers.radio.radio_manager", mock_rm): result = await reboot_radio() assert result["status"] == "ok" @@ -197,7 +197,7 @@ class TestRebootAndReconnect: mock_rm = MagicMock() mock_rm.is_connected = True - with patch("app.radio.radio_manager", mock_rm): + with patch("app.routers.radio.radio_manager", mock_rm): result = await reconnect_radio() assert result["status"] == "ok" @@ -210,7 +210,7 @@ class TestRebootAndReconnect: mock_rm.is_reconnecting = False mock_rm.reconnect = AsyncMock(return_value=False) - with patch("app.radio.radio_manager", mock_rm): + with patch("app.routers.radio.radio_manager", mock_rm): with pytest.raises(HTTPException) as exc: await reconnect_radio()