From 47867c50b80ec375460411fc4c6ee15c1d470efc Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Mon, 23 Feb 2026 20:37:32 -0800 Subject: [PATCH] Fix TOCTOU around radio reconnect --- app/radio.py | 11 ++- frontend/src/components/MessageList.tsx | 15 ++- frontend/src/components/RawPacketList.tsx | 9 +- tests/test_radio.py | 108 ++++++++++++++++++++++ 4 files changed, 132 insertions(+), 11 deletions(-) diff --git a/app/radio.py b/app/radio.py index 116089c..1165dd7 100644 --- a/app/radio.py +++ b/app/radio.py @@ -399,12 +399,13 @@ class RadioManager: if self._reconnect_lock is None: self._reconnect_lock = asyncio.Lock() - # Try to acquire lock without blocking to check if reconnect is in progress - if self._reconnect_lock.locked(): - logger.debug("Reconnection already in progress") - return False - async with self._reconnect_lock: + # If we became connected while waiting for the lock (another + # reconnect succeeded ahead of us), skip the redundant attempt. + if self.is_connected: + logger.debug("Already connected after acquiring lock, skipping reconnect") + return True + logger.info("Attempting to reconnect to radio...") try: diff --git a/frontend/src/components/MessageList.tsx b/frontend/src/components/MessageList.tsx index 7555ae4..03794a8 100644 --- a/frontend/src/components/MessageList.tsx +++ b/frontend/src/components/MessageList.tsx @@ -261,7 +261,16 @@ export function MessageList({ }; }, [messages, onResendChannelMessage]); + // Refs for scroll handler to read without causing callback recreation + const onLoadOlderRef = useRef(onLoadOlder); + const loadingOlderRef = useRef(loadingOlder); + const hasOlderMessagesRef = useRef(hasOlderMessages); + onLoadOlderRef.current = onLoadOlder; + loadingOlderRef.current = loadingOlder; + hasOlderMessagesRef.current = hasOlderMessages; + // Handle scroll - capture state and detect when user is near top/bottom + // Stable callback: reads changing values from refs, never recreated. const handleScroll = useCallback(() => { if (!listRef.current) return; @@ -280,13 +289,13 @@ export function MessageList({ // Show scroll-to-bottom button when not near the bottom (more than 100px away) setShowScrollToBottom(distanceFromBottom > 100); - if (!onLoadOlder || loadingOlder || !hasOlderMessages) return; + if (!onLoadOlderRef.current || loadingOlderRef.current || !hasOlderMessagesRef.current) return; // Trigger load when within 100px of top if (scrollTop < 100) { - onLoadOlder(); + onLoadOlderRef.current(); } - }, [onLoadOlder, loadingOlder, hasOlderMessages]); + }, []); // Scroll to bottom handler const scrollToBottom = useCallback(() => { diff --git a/frontend/src/components/RawPacketList.tsx b/frontend/src/components/RawPacketList.tsx index 3f08e3e..35b7a8e 100644 --- a/frontend/src/components/RawPacketList.tsx +++ b/frontend/src/components/RawPacketList.tsx @@ -188,6 +188,12 @@ export function RawPacketList({ packets }: RawPacketListProps) { })); }, [packets]); + // Sort packets by timestamp ascending (oldest first) + const sortedPackets = useMemo( + () => [...decodedPackets].sort((a, b) => a.packet.timestamp - b.packet.timestamp), + [decodedPackets] + ); + useEffect(() => { if (listRef.current) { listRef.current.scrollTop = listRef.current.scrollHeight; @@ -202,9 +208,6 @@ export function RawPacketList({ packets }: RawPacketListProps) { ); } - // Sort packets by timestamp ascending (oldest first) - const sortedPackets = [...decodedPackets].sort((a, b) => a.packet.timestamp - b.packet.timestamp); - return (
{sortedPackets.map(({ packet, decoded }) => ( diff --git a/tests/test_radio.py b/tests/test_radio.py index ed04db8..18997a2 100644 --- a/tests/test_radio.py +++ b/tests/test_radio.py @@ -222,6 +222,114 @@ class TestConnectionMonitor: assert rm._last_connected is False +class TestReconnectLock: + """Tests for reconnect() lock serialization — no duplicate reconnections.""" + + @pytest.mark.asyncio + async def test_concurrent_reconnects_only_connect_once(self): + """Two concurrent reconnect() calls should only call connect() once.""" + from app.radio import RadioManager + + rm = RadioManager() + rm._meshcore = None + + connect_count = 0 + + async def mock_connect(): + nonlocal connect_count + connect_count += 1 + # Simulate connect taking some time + await asyncio.sleep(0.05) + mock_mc = MagicMock() + mock_mc.is_connected = True + rm._meshcore = mock_mc + rm._connection_info = "TCP: test:4000" + + rm.connect = AsyncMock(side_effect=mock_connect) + + with ( + patch("app.websocket.broadcast_health"), + patch("app.websocket.broadcast_error"), + ): + result_a, result_b = await asyncio.gather( + rm.reconnect(broadcast_on_success=False), + rm.reconnect(broadcast_on_success=False), + ) + + # First caller does the real connect, second sees is_connected=True + assert connect_count == 1 + assert result_a is True + assert result_b is True + + @pytest.mark.asyncio + async def test_second_reconnect_skips_when_first_succeeds(self): + """Second caller returns True without connecting when first already succeeded.""" + from app.radio import RadioManager + + rm = RadioManager() + rm._meshcore = None + + call_order: list[str] = [] + + async def mock_connect(): + call_order.append("connect") + await asyncio.sleep(0.05) + mock_mc = MagicMock() + mock_mc.is_connected = True + rm._meshcore = mock_mc + rm._connection_info = "TCP: test:4000" + + rm.connect = AsyncMock(side_effect=mock_connect) + + with ( + patch("app.websocket.broadcast_health"), + patch("app.websocket.broadcast_error"), + ): + await asyncio.gather( + rm.reconnect(broadcast_on_success=False), + rm.reconnect(broadcast_on_success=False), + ) + + # connect should appear exactly once + assert call_order == ["connect"] + + @pytest.mark.asyncio + async def test_reconnect_retries_after_first_failure(self): + """If first reconnect fails, a subsequent call should attempt connect again.""" + from app.radio import RadioManager + + rm = RadioManager() + rm._meshcore = None + + attempt = 0 + + async def mock_connect(): + nonlocal attempt + attempt += 1 + if attempt == 1: + # First attempt fails + return + # Second attempt succeeds + mock_mc = MagicMock() + mock_mc.is_connected = True + rm._meshcore = mock_mc + rm._connection_info = "TCP: test:4000" + + rm.connect = AsyncMock(side_effect=mock_connect) + + with ( + patch("app.websocket.broadcast_health"), + patch("app.websocket.broadcast_error"), + ): + result1 = await rm.reconnect(broadcast_on_success=False) + assert result1 is False + assert attempt == 1 + + result2 = await rm.reconnect(broadcast_on_success=False) + assert result2 is True + assert attempt == 2 + + class TestSerialDeviceProbe: """Tests for test_serial_device() — verifies cleanup on all exit paths."""