More stable MC object reference and proper radio disconnection detection

This commit is contained in:
Jack Kingsman
2026-02-23 19:11:58 -08:00
parent cba9e20698
commit 152eab99db
14 changed files with 342 additions and 126 deletions

View File

@@ -2,13 +2,14 @@ import logging
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.config import setup_logging
from app.database import db
from app.frontend_static import register_frontend_static_routes
from app.radio import radio_manager
from app.radio import RadioDisconnectedError, radio_manager
from app.radio_sync import (
stop_message_polling,
stop_periodic_advert,
@@ -75,6 +76,13 @@ app.add_middleware(
allow_headers=["*"],
)
@app.exception_handler(RadioDisconnectedError)
async def radio_disconnected_handler(request: Request, exc: RadioDisconnectedError):
"""Return 503 when a radio disconnect race occurs during an operation."""
return JSONResponse(status_code=503, content={"detail": "Radio not connected"})
# API routes - all prefixed with /api for production compatibility
app.include_router(health.router, prefix="/api")
app.include_router(radio.router, prefix="/api")

View File

@@ -20,6 +20,10 @@ class RadioOperationBusyError(RadioOperationError):
"""Raised when a non-blocking radio operation cannot acquire the lock."""
class RadioDisconnectedError(RadioOperationError):
"""Raised when the radio disconnects between pre-check and lock acquisition."""
@asynccontextmanager
async def _noop_context():
"""No-op async context manager for optional nesting."""
@@ -166,37 +170,48 @@ class RadioManager:
pause_polling: bool = False,
suspend_auto_fetch: bool = False,
blocking: bool = True,
meshcore: MeshCore | None = None,
):
"""Acquire shared radio lock and optionally pause polling / auto-fetch.
After acquiring the lock, resolves the current MeshCore instance and
yields it. Callers get a fresh reference via ``async with ... as mc:``,
avoiding stale-reference bugs when a reconnect swaps ``_meshcore``
between the pre-check and the lock acquisition.
Args:
name: Human-readable operation name for logs/errors.
pause_polling: Pause fallback message polling while held.
suspend_auto_fetch: Stop MeshCore auto message fetching while held.
blocking: If False, fail immediately when lock is held.
meshcore: Optional explicit MeshCore instance for auto-fetch control.
Raises:
RadioDisconnectedError: If the radio disconnected before the lock
was acquired (``_meshcore`` is ``None``).
"""
await self._acquire_operation_lock(name, blocking=blocking)
mc = self._meshcore
if mc is None:
self._release_operation_lock(name)
raise RadioDisconnectedError("Radio disconnected")
poll_context = _noop_context()
if pause_polling:
from app.radio_sync import pause_polling as pause_polling_context
poll_context = pause_polling_context()
mc = meshcore or self._meshcore
auto_fetch_paused = False
try:
async with poll_context:
if suspend_auto_fetch and mc is not None:
if suspend_auto_fetch:
await mc.stop_auto_message_fetching()
auto_fetch_paused = True
yield
yield mc
finally:
try:
if auto_fetch_paused and mc is not None:
if auto_fetch_paused:
try:
await mc.start_auto_message_fetching()
except Exception as e:

View File

@@ -573,13 +573,11 @@ async def sync_recent_contacts_to_radio(force: bool = False) -> dict:
logger.debug("Cannot sync contacts to radio: not connected")
return {"loaded": 0, "error": "Radio not connected"}
mc = radio_manager.meshcore
try:
async with radio_manager.radio_operation(
"sync_recent_contacts_to_radio",
blocking=False,
):
) as mc:
_last_contact_sync = now
# Build prioritized contact list:

View File

@@ -85,12 +85,12 @@ async def create_channel(request: CreateChannelRequest) -> Channel:
@router.post("/sync")
async def sync_channels_from_radio(max_channels: int = Query(default=40, ge=1, le=40)) -> dict:
"""Sync channels from the radio to the database."""
mc = require_connected()
require_connected()
logger.info("Syncing channels from radio (checking %d slots)", max_channels)
count = 0
async with radio_manager.radio_operation("sync_channels_from_radio"):
async with radio_manager.radio_operation("sync_channels_from_radio") as mc:
for idx in range(max_channels):
result = await mc.commands.get_channel(idx)

View File

@@ -264,11 +264,11 @@ async def get_contact(public_key: str) -> Contact:
@router.post("/sync")
async def sync_contacts_from_radio() -> dict:
"""Sync contacts from the radio to the database."""
mc = require_connected()
require_connected()
logger.info("Syncing contacts from radio")
async with radio_manager.radio_operation("sync_contacts_from_radio", meshcore=mc):
async with radio_manager.radio_operation("sync_contacts_from_radio") as mc:
result = await mc.commands.get_contacts()
if result.type == EventType.ERROR:
@@ -294,11 +294,11 @@ async def sync_contacts_from_radio() -> dict:
@router.post("/{public_key}/remove-from-radio")
async def remove_contact_from_radio(public_key: str) -> dict:
"""Remove a contact from the radio (keeps it in database)."""
mc = require_connected()
require_connected()
contact = await _resolve_contact_or_404(public_key)
async with radio_manager.radio_operation("remove_contact_from_radio", meshcore=mc):
async with radio_manager.radio_operation("remove_contact_from_radio") as mc:
# Get the contact from radio
radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12])
if not radio_contact:
@@ -322,11 +322,11 @@ async def remove_contact_from_radio(public_key: str) -> dict:
@router.post("/{public_key}/add-to-radio")
async def add_contact_to_radio(public_key: str) -> dict:
"""Add a contact from the database to the radio."""
mc = require_connected()
require_connected()
contact = await _resolve_contact_or_404(public_key, "Contact not found in database")
async with radio_manager.radio_operation("add_contact_to_radio", meshcore=mc):
async with radio_manager.radio_operation("add_contact_to_radio") as mc:
# Check if already on radio
radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12])
if radio_contact:
@@ -361,9 +361,8 @@ async def delete_contact(public_key: str) -> dict:
contact = await _resolve_contact_or_404(public_key)
# Remove from radio if connected and contact is on radio
if radio_manager.is_connected and radio_manager.meshcore:
mc = radio_manager.meshcore
async with radio_manager.radio_operation("delete_contact_from_radio", meshcore=mc):
if radio_manager.is_connected:
async with radio_manager.radio_operation("delete_contact_from_radio") as mc:
radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12])
if radio_contact:
logger.info(
@@ -385,7 +384,7 @@ async def request_telemetry(public_key: str, request: TelemetryRequest) -> Telem
The contact must be a repeater (type=2). If not on the radio, it will be added.
Uses login + status request with retry logic.
"""
mc = require_connected()
require_connected()
# Get contact from database
contact = await _resolve_contact_or_404(public_key)
@@ -399,10 +398,9 @@ async def request_telemetry(public_key: str, request: TelemetryRequest) -> Telem
async with radio_manager.radio_operation(
"request_telemetry",
meshcore=mc,
pause_polling=True,
suspend_auto_fetch=True,
):
) as mc:
# Prepare connection (add/remove dance + login)
await prepare_repeater_connection(mc, contact, request.password)
@@ -552,7 +550,7 @@ async def send_repeater_command(public_key: str, request: CommandRequest) -> Com
- reboot
- ver
"""
mc = require_connected()
require_connected()
# Get contact from database
contact = await _resolve_contact_or_404(public_key)
@@ -566,10 +564,9 @@ async def send_repeater_command(public_key: str, request: CommandRequest) -> Com
async with radio_manager.radio_operation(
"send_repeater_command",
meshcore=mc,
pause_polling=True,
suspend_auto_fetch=True,
):
) as mc:
# Add contact to radio with path from DB
logger.info("Adding repeater %s to radio", contact.public_key[:12])
await mc.commands.add_contact(contact.to_radio_dict())
@@ -621,7 +618,7 @@ async def request_trace(public_key: str) -> TraceResponse:
(no intermediate repeaters). The radio firmware requires at least one
node in the path.
"""
mc = require_connected()
require_connected()
contact = await _resolve_contact_or_404(public_key)
@@ -631,7 +628,7 @@ async def request_trace(public_key: str) -> TraceResponse:
# Trace does not need auto-fetch suspension: response arrives as TRACE_DATA
# from the reader loop, not via get_msg().
async with radio_manager.radio_operation("request_trace", pause_polling=True):
async with radio_manager.radio_operation("request_trace", pause_polling=True) as mc:
# Ensure contact is on radio so the trace can reach them
await mc.commands.add_contact(contact.to_radio_dict())

View File

@@ -43,7 +43,7 @@ async def list_messages(
@router.post("/direct", response_model=Message)
async def send_direct_message(request: SendDirectMessageRequest) -> Message:
"""Send a direct message to a contact."""
mc = require_connected()
require_connected()
# First check our database for the contact
from app.repository import ContactRepository
@@ -69,7 +69,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
# so we can't rely on it to know if the firmware has the contact.
# add_contact is idempotent - updates if exists, adds if not.
contact_data = db_contact.to_radio_dict()
async with radio_manager.radio_operation("send_direct_message"):
async with radio_manager.radio_operation("send_direct_message") as mc:
logger.debug("Ensuring contact %s is on radio before sending", db_contact.public_key[:12])
add_result = await mc.commands.add_contact(contact_data)
if add_result.type == EventType.ERROR:
@@ -163,7 +163,7 @@ TEMP_RADIO_SLOT = 0
@router.post("/channel", response_model=Message)
async def send_channel_message(request: SendChannelMessageRequest) -> Message:
"""Send a message to a channel."""
mc = require_connected()
require_connected()
# Get channel info from our database
from app.decoder import calculate_channel_hash
@@ -192,12 +192,14 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
expected_hash,
)
channel_key_upper = request.channel_key.upper()
radio_name = mc.self_info.get("name", "") if mc.self_info else ""
text_with_sender = f"{radio_name}: {request.text}" if radio_name else request.text
message_id: int | None = None
now: int | None = None
radio_name: str = ""
text_with_sender: str = request.text
async with radio_manager.radio_operation("send_channel_message"):
async with radio_manager.radio_operation("send_channel_message") as mc:
radio_name = mc.self_info.get("name", "") if mc.self_info else ""
text_with_sender = f"{radio_name}: {request.text}" if radio_name else request.text
# Load the channel to a temporary radio slot before sending
set_result = await mc.commands.set_channel(
channel_idx=TEMP_RADIO_SLOT,
@@ -318,7 +320,7 @@ async def resend_channel_message(
When new_timestamp=True: resend with a fresh timestamp so repeaters treat it as a
new packet. Creates a new message row in the database. No time window restriction.
"""
mc = require_connected()
require_connected()
from app.repository import ChannelRepository
@@ -352,12 +354,6 @@ async def resend_channel_message(
else:
timestamp_bytes = msg.sender_timestamp.to_bytes(4, "little")
# Strip sender prefix: DB stores "RadioName: message" but radio needs "message"
radio_name = mc.self_info.get("name", "") if mc.self_info else ""
text_to_send = msg.text
if radio_name and text_to_send.startswith(f"{radio_name}: "):
text_to_send = text_to_send[len(f"{radio_name}: ") :]
try:
key_bytes = bytes.fromhex(msg.conversation_key)
except ValueError:
@@ -365,7 +361,13 @@ async def resend_channel_message(
status_code=400, detail=f"Invalid channel key format: {msg.conversation_key}"
) from None
async with radio_manager.radio_operation("resend_channel_message"):
async with radio_manager.radio_operation("resend_channel_message") as mc:
# Strip sender prefix: DB stores "RadioName: message" but radio needs "message"
radio_name = mc.self_info.get("name", "") if mc.self_info else ""
text_to_send = msg.text
if radio_name and text_to_send.startswith(f"{radio_name}: "):
text_to_send = text_to_send[len(f"{radio_name}: ") :]
set_result = await mc.commands.set_channel(
channel_idx=TEMP_RADIO_SLOT,
channel_name=db_channel.name,

View File

@@ -70,9 +70,9 @@ async def get_radio_config() -> RadioConfigResponse:
@router.patch("/config", response_model=RadioConfigResponse)
async def update_radio_config(update: RadioConfigUpdate) -> RadioConfigResponse:
"""Update radio configuration. Only provided fields will be updated."""
mc = require_connected()
require_connected()
async with radio_manager.radio_operation("update_radio_config"):
async with radio_manager.radio_operation("update_radio_config") as mc:
if update.name is not None:
logger.info("Setting radio name to %s", update.name)
await mc.commands.set_name(update.name)
@@ -117,7 +117,7 @@ async def update_radio_config(update: RadioConfigUpdate) -> RadioConfigResponse:
@router.put("/private-key")
async def set_private_key(update: PrivateKeyUpdate) -> dict:
"""Set the radio's private key. This is write-only."""
mc = require_connected()
require_connected()
try:
key_bytes = bytes.fromhex(update.private_key)
@@ -125,7 +125,7 @@ async def set_private_key(update: PrivateKeyUpdate) -> dict:
raise HTTPException(status_code=400, detail="Invalid hex string for private key") from None
logger.info("Importing private key")
async with radio_manager.radio_operation("import_private_key", meshcore=mc):
async with radio_manager.radio_operation("import_private_key") as mc:
result = await mc.commands.import_private_key(key_bytes)
if result.type == EventType.ERROR:
@@ -167,13 +167,10 @@ async def reboot_radio() -> dict:
If not connected: attempts to reconnect (same as /reconnect endpoint).
"""
# If connected, send reboot command
if radio_manager.is_connected and radio_manager.meshcore:
if radio_manager.is_connected:
logger.info("Rebooting radio")
async with radio_manager.radio_operation(
"reboot_radio",
meshcore=radio_manager.meshcore,
):
await radio_manager.meshcore.commands.reboot()
async with radio_manager.radio_operation("reboot_radio") as mc:
await mc.commands.reboot()
return {
"status": "ok",
"message": "Reboot command sent. Radio will reconnect automatically.",

View File

@@ -12,6 +12,7 @@ import httpx
import pytest
from app.database import Database
from app.radio import radio_manager
from app.repository import (
ChannelRepository,
ContactRepository,
@@ -20,6 +21,16 @@ from app.repository import (
)
@pytest.fixture(autouse=True)
def _reset_radio_state():
"""Save/restore radio_manager state so tests don't leak."""
prev = radio_manager._meshcore
prev_lock = radio_manager._operation_lock
yield
radio_manager._meshcore = prev
radio_manager._operation_lock = prev_lock
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
@@ -109,6 +120,29 @@ class TestHealthEndpoint:
assert data["connection_info"] is None
class TestRadioDisconnectedHandler:
"""Test that RadioDisconnectedError maps to 503."""
@pytest.mark.asyncio
async def test_disconnect_race_returns_503(self, test_db, client):
"""If radio disconnects between require_connected() and lock acquisition, return 503."""
pub_key = "ab" * 32
await _insert_contact(pub_key, "Alice")
# require_connected() passes, but _meshcore is None when radio_operation() checks
radio_manager._meshcore = None
with patch("app.dependencies.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = MagicMock()
response = await client.post(
"/api/messages/direct", json={"destination": pub_key, "text": "Hi"}
)
assert response.status_code == 503
assert "not connected" in response.json()["detail"].lower()
class TestMessagesEndpoint:
"""Test message-related endpoints."""
@@ -161,6 +195,7 @@ class TestMessagesEndpoint:
coro.close()
return MagicMock()
radio_manager._meshcore = mock_mc
with (
patch("app.dependencies.radio_manager") as mock_rm,
patch("app.bot.run_bot_for_message", new=AsyncMock()),
@@ -204,6 +239,7 @@ class TestMessagesEndpoint:
coro.close()
return MagicMock()
radio_manager._meshcore = mock_mc
with (
patch("app.dependencies.radio_manager") as mock_rm,
patch("app.decoder.calculate_channel_hash", return_value="abcd"),
@@ -260,6 +296,7 @@ class TestMessagesEndpoint:
return_value=MagicMock(type=MagicMock(name="OK"), payload={"expected_ack": b"\x00\x01"})
)
radio_manager._meshcore = mock_mc
with (
patch("app.dependencies.radio_manager") as mock_rm,
patch("app.routers.messages.MessageRepository") as mock_msg_repo,
@@ -296,6 +333,7 @@ class TestMessagesEndpoint:
return_value=MagicMock(type=MagicMock(name="OK"), payload={})
)
radio_manager._meshcore = mock_mc
with (
patch("app.dependencies.radio_manager") as mock_rm,
patch("app.routers.messages.MessageRepository") as mock_msg_repo,
@@ -355,6 +393,7 @@ class TestMessagesEndpoint:
return_value=MagicMock(type=EventType.MSG_SENT, payload={})
)
radio_manager._meshcore = mock_mc
with patch("app.dependencies.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc

View File

@@ -14,6 +14,7 @@ import pytest
from meshcore import EventType
from app.database import Database
from app.radio import radio_manager
from app.repository import ContactRepository, MessageRepository
# Sample 64-char hex public keys for testing
@@ -22,9 +23,24 @@ KEY_B = "bb" * 32 # bbbb...bb
KEY_C = "cc" * 32 # cccc...cc
@asynccontextmanager
async def _noop_radio_operation(*_args, **_kwargs):
def _noop_radio_operation(mc=None):
"""Factory for a no-op radio_operation context manager that yields mc."""
@asynccontextmanager
async def _ctx(*_args, **_kwargs):
yield mc
return _ctx
@pytest.fixture(autouse=True)
def _reset_radio_state():
"""Save/restore radio_manager state so tests don't leak."""
prev = radio_manager._meshcore
prev_lock = radio_manager._operation_lock
yield
radio_manager._meshcore = prev
radio_manager._operation_lock = prev_lock
@pytest.fixture
@@ -225,7 +241,7 @@ class TestDeleteContact:
with patch("app.routers.contacts.radio_manager") as mock_rm:
mock_rm.is_connected = False
mock_rm.meshcore = None
mock_rm.radio_operation = _noop_radio_operation
mock_rm.radio_operation = _noop_radio_operation()
response = await client.delete(f"/api/contacts/{KEY_A}")
@@ -255,7 +271,7 @@ class TestDeleteContact:
with patch("app.routers.contacts.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
mock_rm.radio_operation = _noop_radio_operation
mock_rm.radio_operation = _noop_radio_operation(mock_mc)
response = await client.delete(f"/api/contacts/{KEY_A}")
@@ -277,6 +293,7 @@ class TestSyncContacts:
}
mock_mc.commands.get_contacts = AsyncMock(return_value=mock_result)
radio_manager._meshcore = mock_mc
with patch("app.dependencies.radio_manager") as mock_dep_rm:
mock_dep_rm.is_connected = True
mock_dep_rm.meshcore = mock_mc
@@ -318,6 +335,7 @@ class TestSyncContacts:
mock_result.payload = {KEY_A: {"adv_name": "Alice", "type": 1, "flags": 0}}
mock_mc.commands.get_contacts = AsyncMock(return_value=mock_result)
radio_manager._meshcore = mock_mc
with patch("app.dependencies.radio_manager") as mock_dep_rm:
mock_dep_rm.is_connected = True
mock_dep_rm.meshcore = mock_mc
@@ -345,6 +363,7 @@ class TestAddRemoveRadio:
mock_result.type = EventType.OK
mock_mc.commands.add_contact = AsyncMock(return_value=mock_result)
radio_manager._meshcore = mock_mc
with patch("app.dependencies.radio_manager") as mock_dep_rm:
mock_dep_rm.is_connected = True
mock_dep_rm.meshcore = mock_mc
@@ -366,6 +385,7 @@ class TestAddRemoveRadio:
mock_mc = MagicMock()
mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # On radio
radio_manager._meshcore = mock_mc
with patch("app.dependencies.radio_manager") as mock_dep_rm:
mock_dep_rm.is_connected = True
mock_dep_rm.meshcore = mock_mc
@@ -386,6 +406,7 @@ class TestAddRemoveRadio:
mock_result.type = EventType.OK
mock_mc.commands.remove_contact = AsyncMock(return_value=mock_result)
radio_manager._meshcore = mock_mc
with patch("app.dependencies.radio_manager") as mock_dep_rm:
mock_dep_rm.is_connected = True
mock_dep_rm.meshcore = mock_mc

View File

@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.radio import RadioOperationBusyError, radio_manager
from app.radio import RadioDisconnectedError, RadioOperationBusyError, radio_manager
from app.radio_sync import is_polling_paused
@@ -14,7 +14,9 @@ def reset_radio_operation_state():
"""Reset shared radio operation lock state before/after each test."""
prev_meshcore = radio_manager._meshcore
radio_manager._operation_lock = None
radio_manager._meshcore = None
# Default to a non-None MagicMock so radio_operation() doesn't raise
# RadioDisconnectedError for tests that only exercise locking.
radio_manager._meshcore = MagicMock()
import app.radio_sync as radio_sync
@@ -123,3 +125,47 @@ class TestRadioOperationLock:
assert is_polling_paused()
assert not is_polling_paused()
class TestRadioOperationYield:
"""Validate that radio_operation() yields the current meshcore instance."""
@pytest.mark.asyncio
async def test_radio_operation_yields_current_meshcore(self):
"""The yielded value is the current _meshcore at lock-acquisition time."""
mc = MagicMock()
radio_manager._meshcore = mc
async with radio_manager.radio_operation("test_yield") as yielded:
assert yielded is mc
@pytest.mark.asyncio
async def test_radio_operation_raises_when_disconnected_after_lock(self):
"""RadioDisconnectedError is raised when _meshcore is None after acquiring the lock."""
radio_manager._meshcore = None
with pytest.raises(RadioDisconnectedError):
async with radio_manager.radio_operation("test_disconnected"):
pass # pragma: no cover
# Lock must be released even after the error
radio_manager._meshcore = MagicMock()
async with radio_manager.radio_operation("after_error", blocking=False):
pass
@pytest.mark.asyncio
async def test_radio_operation_yields_fresh_reference_after_swap(self):
"""If _meshcore is swapped between pre-check and lock acquisition,
the yielded value is the new (current) instance, not the old one."""
old_mc = MagicMock(name="old")
new_mc = MagicMock(name="new")
# Start with old_mc
radio_manager._meshcore = old_mc
# Simulate a reconnect swapping _meshcore before the caller enters the block
radio_manager._meshcore = new_mc
async with radio_manager.radio_operation("test_swap") as yielded:
assert yielded is new_mc
assert yielded is not old_mc

View File

@@ -8,7 +8,7 @@ import pytest
from fastapi import HTTPException
from meshcore import EventType
from app.radio import RadioManager
from app.radio import RadioManager, radio_manager
from app.routers.radio import (
PrivateKeyUpdate,
RadioConfigResponse,
@@ -30,9 +30,24 @@ def _radio_result(event_type=EventType.OK, payload=None):
return result
@asynccontextmanager
async def _noop_radio_operation(*_args, **_kwargs):
def _noop_radio_operation(mc=None):
"""Factory for a no-op radio_operation context manager that yields mc."""
@asynccontextmanager
async def _ctx(*_args, **_kwargs):
yield mc
return _ctx
@pytest.fixture(autouse=True)
def _reset_radio_state():
"""Save/restore radio_manager state so tests don't leak."""
prev = radio_manager._meshcore
prev_lock = radio_manager._operation_lock
yield
radio_manager._meshcore = prev
radio_manager._operation_lock = prev_lock
def _mock_meshcore_with_info():
@@ -100,6 +115,7 @@ class TestUpdateRadioConfig:
with (
patch("app.routers.radio.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.radio.sync_radio_time", new_callable=AsyncMock) as mock_sync_time,
patch(
"app.routers.radio.get_radio_config", new_callable=AsyncMock, return_value=expected
@@ -132,7 +148,10 @@ class TestPrivateKeyImport:
mc.commands.import_private_key = AsyncMock(
return_value=_radio_result(EventType.ERROR, {"error": "failed"})
)
with patch("app.routers.radio.require_connected", return_value=mc):
with (
patch("app.routers.radio.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
await set_private_key(PrivateKeyUpdate(private_key="aa" * 64))
@@ -142,6 +161,7 @@ class TestPrivateKeyImport:
class TestAdvertise:
@pytest.mark.asyncio
async def test_raises_when_send_fails(self):
radio_manager._meshcore = MagicMock()
with (
patch("app.routers.radio.require_connected"),
patch(
@@ -170,6 +190,7 @@ class TestAdvertise:
return True
isolated_manager = RadioManager()
isolated_manager._meshcore = MagicMock()
with (
patch("app.routers.radio.require_connected"),
patch("app.routers.radio.radio_manager", isolated_manager),
@@ -187,17 +208,19 @@ class TestAdvertise:
class TestRebootAndReconnect:
@pytest.mark.asyncio
async def test_reboot_connected_sends_reboot_command(self):
mock_mc = MagicMock()
mock_mc.commands.reboot = AsyncMock()
mock_rm = MagicMock()
mock_rm.is_connected = True
mock_rm.meshcore = MagicMock()
mock_rm.meshcore.commands.reboot = AsyncMock()
mock_rm.radio_operation = _noop_radio_operation
mock_rm.meshcore = mock_mc
mock_rm.radio_operation = _noop_radio_operation(mock_mc)
with patch("app.routers.radio.radio_manager", mock_rm):
result = await reboot_radio()
assert result["status"] == "ok"
mock_rm.meshcore.commands.reboot.assert_awaited_once()
mock_mc.commands.reboot.assert_awaited_once()
@pytest.mark.asyncio
async def test_reboot_returns_pending_when_reconnect_in_progress(self):
@@ -205,7 +228,7 @@ class TestRebootAndReconnect:
mock_rm.is_connected = False
mock_rm.meshcore = None
mock_rm.is_reconnecting = True
mock_rm.radio_operation = _noop_radio_operation
mock_rm.radio_operation = _noop_radio_operation()
with patch("app.routers.radio.radio_manager", mock_rm):
result = await reboot_radio()
@@ -221,7 +244,7 @@ class TestRebootAndReconnect:
mock_rm.is_reconnecting = False
mock_rm.reconnect = AsyncMock(return_value=True)
mock_rm.post_connect_setup = AsyncMock()
mock_rm.radio_operation = _noop_radio_operation
mock_rm.radio_operation = _noop_radio_operation()
with patch("app.routers.radio.radio_manager", mock_rm):
result = await reboot_radio()
@@ -235,7 +258,7 @@ class TestRebootAndReconnect:
async def test_reconnect_returns_already_connected(self):
mock_rm = MagicMock()
mock_rm.is_connected = True
mock_rm.radio_operation = _noop_radio_operation
mock_rm.radio_operation = _noop_radio_operation()
with patch("app.routers.radio.radio_manager", mock_rm):
result = await reconnect_radio()
@@ -249,7 +272,7 @@ class TestRebootAndReconnect:
mock_rm.is_connected = False
mock_rm.is_reconnecting = False
mock_rm.reconnect = AsyncMock(return_value=False)
mock_rm.radio_operation = _noop_radio_operation
mock_rm.radio_operation = _noop_radio_operation()
with patch("app.routers.radio.radio_manager", mock_rm):
with pytest.raises(HTTPException) as exc:

View File

@@ -11,6 +11,7 @@ from meshcore import EventType
from app.database import Database
from app.models import Favorite
from app.radio import radio_manager
from app.radio_sync import (
is_polling_paused,
pause_polling,
@@ -45,14 +46,19 @@ async def test_db():
@pytest.fixture(autouse=True)
def reset_sync_state():
"""Reset polling pause state and sync timestamp before and after each test."""
"""Reset polling pause state, sync timestamp, and radio_manager before/after each test."""
import app.radio_sync as radio_sync
prev_mc = radio_manager._meshcore
prev_lock = radio_manager._operation_lock
radio_sync._polling_pause_count = 0
radio_sync._last_contact_sync = 0.0
yield
radio_sync._polling_pause_count = 0
radio_sync._last_contact_sync = 0.0
radio_manager._meshcore = prev_mc
radio_manager._operation_lock = prev_lock
KEY_A = "aa" * 32
@@ -236,11 +242,8 @@ class TestSyncRecentContactsToRadio:
mock_result.type = EventType.OK
mock_mc.commands.add_contact = AsyncMock(return_value=mock_result)
with patch("app.radio_sync.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
radio_manager._meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
assert result["loaded"] == 2
# Verify contacts are now marked as on_radio in DB
@@ -268,11 +271,8 @@ class TestSyncRecentContactsToRadio:
mock_result.type = EventType.OK
mock_mc.commands.add_contact = AsyncMock(return_value=mock_result)
with patch("app.radio_sync.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
radio_manager._meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
assert result["loaded"] == 2
# KEY_A (favorite) should be loaded first, then KEY_B (most recent)
@@ -298,11 +298,8 @@ class TestSyncRecentContactsToRadio:
mock_result.type = EventType.OK
mock_mc.commands.add_contact = AsyncMock(return_value=mock_result)
with patch("app.radio_sync.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
radio_manager._meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
assert result["loaded"] == 2
loaded_keys = [
@@ -319,11 +316,8 @@ class TestSyncRecentContactsToRadio:
mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # Found
mock_mc.commands.add_contact = AsyncMock()
with patch("app.radio_sync.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
radio_manager._meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
assert result["loaded"] == 0
assert result["already_on_radio"] == 1
@@ -335,34 +329,30 @@ class TestSyncRecentContactsToRadio:
mock_mc = MagicMock()
mock_mc.get_contact_by_key_prefix = MagicMock(return_value=None)
with patch("app.radio_sync.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
radio_manager._meshcore = mock_mc
# First call succeeds
result1 = await sync_recent_contacts_to_radio()
assert "throttled" not in result1
# First call succeeds
result1 = await sync_recent_contacts_to_radio()
assert "throttled" not in result1
# Second call is throttled
result2 = await sync_recent_contacts_to_radio()
assert result2["throttled"] is True
assert result2["loaded"] == 0
# Second call is throttled
result2 = await sync_recent_contacts_to_radio()
assert result2["throttled"] is True
assert result2["loaded"] == 0
@pytest.mark.asyncio
async def test_force_bypasses_throttle(self, test_db):
"""force=True bypasses the throttle window."""
mock_mc = MagicMock()
with patch("app.radio_sync.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
radio_manager._meshcore = mock_mc
# First call
await sync_recent_contacts_to_radio()
# First call
await sync_recent_contacts_to_radio()
# Forced second call is not throttled
result = await sync_recent_contacts_to_radio(force=True)
assert "throttled" not in result
# Forced second call is not throttled
result = await sync_recent_contacts_to_radio(force=True)
assert "throttled" not in result
@pytest.mark.asyncio
async def test_not_connected_returns_error(self):
@@ -384,11 +374,8 @@ class TestSyncRecentContactsToRadio:
mock_mc = MagicMock()
mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # Found
with patch("app.radio_sync.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
radio_manager._meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
assert result["already_on_radio"] == 1
# Should update the flag since contact.on_radio was False
@@ -407,15 +394,37 @@ class TestSyncRecentContactsToRadio:
mock_result.payload = {"error": "Radio full"}
mock_mc.commands.add_contact = AsyncMock(return_value=mock_result)
with patch("app.radio_sync.radio_manager") as mock_rm:
mock_rm.is_connected = True
mock_rm.meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
radio_manager._meshcore = mock_mc
result = await sync_recent_contacts_to_radio()
assert result["loaded"] == 0
assert result["failed"] == 1
@pytest.mark.asyncio
async def test_uses_post_lock_meshcore_after_swap(self, test_db):
"""If _meshcore is swapped between pre-check and lock acquisition,
the function uses the new (post-lock) instance, not the stale one."""
await _insert_contact(KEY_A, "Alice", last_contacted=2000)
old_mc = MagicMock(name="old_mc")
new_mc = MagicMock(name="new_mc")
new_mc.get_contact_by_key_prefix = MagicMock(return_value=None)
mock_result = MagicMock()
mock_result.type = EventType.OK
new_mc.commands.add_contact = AsyncMock(return_value=mock_result)
# Pre-check sees old_mc (truthy, passes is_connected guard)
radio_manager._meshcore = old_mc
# Simulate reconnect swapping _meshcore before lock acquisition
radio_manager._meshcore = new_mc
result = await sync_recent_contacts_to_radio()
assert result["loaded"] == 1
# new_mc was used, not old_mc
new_mc.commands.add_contact.assert_called_once()
old_mc.commands.add_contact.assert_not_called()
class TestSyncAndOffloadContacts:
"""Test sync_and_offload_contacts: pull contacts from radio, save to DB, remove from radio."""

View File

@@ -8,6 +8,7 @@ from meshcore import EventType
from app.database import Database
from app.models import CommandRequest, TelemetryRequest
from app.radio import radio_manager
from app.repository import ContactRepository
from app.routers.contacts import (
_fetch_repeater_response,
@@ -23,6 +24,16 @@ KEY_A = "aa" * 32
_MONOTONIC = "app.routers.contacts._monotonic"
@pytest.fixture(autouse=True)
def _reset_radio_state():
"""Save/restore radio_manager state so tests don't leak."""
prev = radio_manager._meshcore
prev_lock = radio_manager._operation_lock
yield
radio_manager._meshcore = prev
radio_manager._operation_lock = prev_lock
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
@@ -263,7 +274,10 @@ class TestTelemetryRoute:
@pytest.mark.asyncio
async def test_returns_404_when_contact_missing(self, test_db):
mc = _mock_mc()
with patch("app.routers.contacts.require_connected", return_value=mc):
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
await request_telemetry(KEY_A, TelemetryRequest(password="pw"))
@@ -274,7 +288,10 @@ class TestTelemetryRoute:
mc = _mock_mc()
await _insert_contact(KEY_A, name="Client", contact_type=1)
with patch("app.routers.contacts.require_connected", return_value=mc):
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
await request_telemetry(KEY_A, TelemetryRequest(password="pw"))
@@ -289,6 +306,7 @@ class TestTelemetryRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(
"app.routers.contacts.prepare_repeater_connection",
new_callable=AsyncMock,
@@ -329,6 +347,7 @@ class TestTelemetryRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(
"app.routers.contacts.prepare_repeater_connection",
new_callable=AsyncMock,
@@ -406,6 +425,7 @@ class TestTelemetryRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(
"app.routers.contacts.prepare_repeater_connection",
new_callable=AsyncMock,
@@ -462,6 +482,7 @@ class TestTelemetryRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(
"app.routers.contacts.prepare_repeater_connection",
new_callable=AsyncMock,
@@ -485,7 +506,10 @@ class TestRepeaterCommandRoute:
return_value=_radio_result(EventType.ERROR, {"err": "bad"})
)
with patch("app.routers.contacts.require_connected", return_value=mc):
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
await send_repeater_command(KEY_A, CommandRequest(command="ver"))
@@ -502,6 +526,7 @@ class TestRepeaterCommandRoute:
# Expire the deadline after a couple of ticks
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=[0.0, 5.0, 25.0]),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
):
@@ -530,6 +555,7 @@ class TestRepeaterCommandRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
response = await send_repeater_command(KEY_A, CommandRequest(command="ver"))
@@ -557,6 +583,7 @@ class TestRepeaterCommandRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
response = await send_repeater_command(KEY_A, CommandRequest(command="ver"))
@@ -584,6 +611,7 @@ class TestRepeaterCommandRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
response = await send_repeater_command(KEY_A, CommandRequest(command="ver"))
@@ -609,6 +637,7 @@ class TestRepeaterCommandRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
response = await send_repeater_command(KEY_A, CommandRequest(command="ver"))
@@ -631,6 +660,7 @@ class TestRepeaterCommandRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
):
@@ -651,6 +681,7 @@ class TestTraceRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.contacts.random.randint", return_value=1234),
):
with pytest.raises(HTTPException) as exc:
@@ -667,6 +698,7 @@ class TestTraceRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.contacts.random.randint", return_value=1234),
):
with pytest.raises(HTTPException) as exc:
@@ -685,6 +717,7 @@ class TestTraceRoute:
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.contacts.random.randint", return_value=1234),
):
response = await request_trace(KEY_A)

View File

@@ -13,6 +13,7 @@ from app.models import (
SendChannelMessageRequest,
SendDirectMessageRequest,
)
from app.radio import radio_manager
from app.repository import (
ChannelRepository,
ContactRepository,
@@ -25,6 +26,16 @@ from app.routers.messages import (
)
@pytest.fixture(autouse=True)
def _reset_radio_state():
"""Save/restore radio_manager state so tests don't leak."""
prev = radio_manager._meshcore
prev_lock = radio_manager._operation_lock
yield
radio_manager._meshcore = prev
radio_manager._operation_lock = prev_lock
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
@@ -96,6 +107,7 @@ class TestOutgoingDMBotTrigger:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot,
):
request = SendDirectMessageRequest(destination=pub_key, text="!lasttime Alice")
@@ -127,6 +139,7 @@ class TestOutgoingDMBotTrigger:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.bot.run_bot_for_message", new=slow_bot),
):
request = SendDirectMessageRequest(destination=pub_key, text="Hello")
@@ -144,6 +157,7 @@ class TestOutgoingDMBotTrigger:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot,
):
request = SendDirectMessageRequest(destination=pub_key, text="test")
@@ -184,6 +198,7 @@ class TestOutgoingChannelBotTrigger:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.decoder.calculate_channel_hash", return_value="abcd"),
patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot,
):
@@ -210,6 +225,7 @@ class TestOutgoingChannelBotTrigger:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.decoder.calculate_channel_hash", return_value="abcd"),
patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot,
):
@@ -234,6 +250,7 @@ class TestOutgoingChannelBotTrigger:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.decoder.calculate_channel_hash", return_value="abcd"),
patch("app.bot.run_bot_for_message", new=slow_bot),
):
@@ -250,6 +267,7 @@ class TestOutgoingChannelBotTrigger:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.decoder.calculate_channel_hash", return_value="abcd"),
patch("app.bot.run_bot_for_message", new=AsyncMock()),
):
@@ -282,7 +300,10 @@ class TestResendChannelMessage:
)
assert msg_id is not None
with patch("app.routers.messages.require_connected", return_value=mc):
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
result = await resend_channel_message(msg_id, new_timestamp=False)
assert result["status"] == "ok"
@@ -341,6 +362,7 @@ class TestResendChannelMessage:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.messages.broadcast_event"),
patch("app.routers.messages.time") as mock_time,
):
@@ -436,7 +458,10 @@ class TestResendChannelMessage:
)
assert msg_id is not None
with patch("app.routers.messages.require_connected", return_value=mc):
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
await resend_channel_message(msg_id, new_timestamp=False)
call_kwargs = mc.commands.send_chan_msg.await_args.kwargs
@@ -462,6 +487,7 @@ class TestResendChannelMessage:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.messages.broadcast_event"),
):
result = await resend_channel_message(msg_id, new_timestamp=True)
@@ -490,6 +516,7 @@ class TestResendChannelMessage:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.messages.broadcast_event"),
):
result = await resend_channel_message(msg_id, new_timestamp=True)
@@ -524,6 +551,7 @@ class TestResendChannelMessage:
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.messages.broadcast_event") as mock_broadcast,
):
result = await resend_channel_message(msg_id, new_timestamp=True)