# Mirror of https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
# Synced 2026-03-28 17:43:05 +01:00 (248 lines, 8.5 KiB, Python)
"""Tests for WebSocket manager functionality."""
|
|
|
|
import asyncio
import json
from unittest.mock import AsyncMock, patch

import pytest

from app.websocket import SEND_TIMEOUT_SECONDS, WebSocketManager
|
|
|
|
|
|
@pytest.fixture
def ws_manager():
    """Return a newly constructed WebSocketManager for the requesting test."""
    manager = WebSocketManager()
    return manager
|
|
|
|
|
|
@pytest.fixture
def mock_websocket():
    """Build an AsyncMock standing in for a WebSocket connection.

    ``send_text`` is attached explicitly as its own AsyncMock so tests can
    inspect or reconfigure it.
    """
    return AsyncMock(send_text=AsyncMock())
|
|
|
|
|
|
class TestWebSocketBroadcast:
    """Tests for the broadcast functionality.

    Clients are plain ``AsyncMock`` objects. AsyncMock creates awaitable child
    attributes on demand, so ``accept()`` and ``send_text()`` need no explicit
    setup unless a test wants a custom side effect.
    """

    @pytest.mark.asyncio
    async def test_broadcast_sends_to_all_clients(self, ws_manager: WebSocketManager):
        """Broadcast should send message to all connected clients."""
        ws1 = AsyncMock()
        ws2 = AsyncMock()

        await ws_manager.connect(ws1)
        await ws_manager.connect(ws2)

        await ws_manager.broadcast("test", {"key": "value"})

        # Every client gets exactly one frame, and it carries the JSON envelope.
        expected = json.dumps({"type": "test", "data": {"key": "value"}})
        ws1.send_text.assert_called_once_with(expected)
        ws2.send_text.assert_called_once_with(expected)

    @pytest.mark.asyncio
    async def test_broadcast_removes_failed_clients(self, ws_manager: WebSocketManager):
        """Clients that fail to receive should be removed."""
        good_ws = AsyncMock()
        bad_ws = AsyncMock()
        bad_ws.send_text.side_effect = Exception("Connection closed")

        await ws_manager.connect(good_ws)
        await ws_manager.connect(bad_ws)

        assert len(ws_manager.active_connections) == 2

        await ws_manager.broadcast("test", {})

        # Bad client should be removed, good client kept.
        assert len(ws_manager.active_connections) == 1
        assert good_ws in ws_manager.active_connections
        assert bad_ws not in ws_manager.active_connections

    @pytest.mark.asyncio
    async def test_broadcast_handles_timeout(self, ws_manager: WebSocketManager):
        """Clients that timeout should be removed."""
        good_ws = AsyncMock()
        slow_ws = AsyncMock()

        # Make slow_ws take longer than the send timeout allows.
        async def slow_send(_):
            await asyncio.sleep(SEND_TIMEOUT_SECONDS + 1)

        slow_ws.send_text.side_effect = slow_send

        await ws_manager.connect(good_ws)
        await ws_manager.connect(slow_ws)

        assert len(ws_manager.active_connections) == 2

        # Broadcast should complete despite slow client (due to timeout)
        await ws_manager.broadcast("test", {})

        # Slow client should be removed due to timeout; the good one stays.
        assert len(ws_manager.active_connections) == 1
        assert good_ws in ws_manager.active_connections
        assert slow_ws not in ws_manager.active_connections

    @pytest.mark.asyncio
    async def test_broadcast_concurrent_sends(self, ws_manager: WebSocketManager):
        """Verify that sends happen concurrently, not sequentially."""
        call_times = []

        def record_send_time(ws_name):
            """Return a send_text side effect that logs its start time, then works 0.1s."""

            async def _send(_):
                call_times.append((ws_name, asyncio.get_running_loop().time()))
                await asyncio.sleep(0.1)  # Simulate some work

            return _send

        for ws_name in ("ws1", "ws2", "ws3"):
            ws = AsyncMock()
            ws.send_text.side_effect = record_send_time(ws_name)
            await ws_manager.connect(ws)

        loop = asyncio.get_running_loop()
        start_time = loop.time()
        await ws_manager.broadcast("test", {})
        elapsed = loop.time() - start_time

        # If sequential: 3 * 0.1 = 0.3s
        # If concurrent: ~0.1s
        # Allow some margin for test overhead
        assert elapsed < 0.2, f"Sends should be concurrent, took {elapsed}s"

        # Every client must have been sent to, and all sends must have started
        # before the first one could have finished its 0.1s of simulated work.
        assert sorted(name for name, _ in call_times) == ["ws1", "ws2", "ws3"]
        started = [started_at for _, started_at in call_times]
        assert max(started) - min(started) < 0.1, "Sends did not start together"

    @pytest.mark.asyncio
    async def test_broadcast_does_not_block_on_slow_client(self, ws_manager: WebSocketManager):
        """A slow client should not block messages to fast clients."""
        fast_ws = AsyncMock()
        slow_ws = AsyncMock()

        fast_received_at = None

        async def fast_send(_):
            nonlocal fast_received_at
            fast_received_at = asyncio.get_running_loop().time()

        async def slow_send(_):
            await asyncio.sleep(0.2)  # Slow client

        fast_ws.send_text.side_effect = fast_send
        slow_ws.send_text.side_effect = slow_send

        # The slow client is connected first on purpose.
        await ws_manager.connect(slow_ws)
        await ws_manager.connect(fast_ws)

        start_time = asyncio.get_running_loop().time()
        await ws_manager.broadcast("test", {})

        # Fast client should receive message quickly, not waiting for slow client
        assert fast_received_at is not None
        assert fast_received_at - start_time < 0.1, "Fast client was blocked by slow client"

    @pytest.mark.asyncio
    async def test_broadcast_empty_connections(self, ws_manager: WebSocketManager):
        """Broadcast should handle empty connection list gracefully."""
        # Should not raise
        await ws_manager.broadcast("test", {"data": "value"})
|
|
|
|
|
|
class TestWebSocketConnectionManagement:
    """Exercise the connect()/disconnect() bookkeeping of WebSocketManager."""

    @pytest.mark.asyncio
    async def test_connect_adds_to_list(
        self, ws_manager: WebSocketManager, mock_websocket
    ):
        """After connect(), the websocket is tracked in active_connections."""
        # Precondition: a fresh manager tracks nothing.
        assert len(ws_manager.active_connections) == 0

        await ws_manager.connect(mock_websocket)

        assert len(ws_manager.active_connections) == 1
        assert mock_websocket in ws_manager.active_connections

    @pytest.mark.asyncio
    async def test_disconnect_removes_from_list(
        self, ws_manager: WebSocketManager, mock_websocket
    ):
        """After disconnect(), a previously connected websocket is no longer tracked."""
        await ws_manager.connect(mock_websocket)
        assert len(ws_manager.active_connections) == 1

        await ws_manager.disconnect(mock_websocket)

        assert len(ws_manager.active_connections) == 0

    @pytest.mark.asyncio
    async def test_disconnect_nonexistent_is_safe(
        self, ws_manager: WebSocketManager, mock_websocket
    ):
        """disconnect() on a websocket that never connected must not raise."""
        # Reaching the assertion below means no exception was raised.
        await ws_manager.disconnect(mock_websocket)

        assert len(ws_manager.active_connections) == 0
|
|
|
|
|
|
class TestBroadcastEventFanout:
    """Check that broadcast_event reaches both the WS manager and the fanout manager."""

    @pytest.mark.asyncio
    async def test_broadcast_event_dispatches_to_ws_and_fanout(self):
        """A `message` event goes to ws_manager.broadcast and fanout broadcast_message."""
        from app.websocket import broadcast_event

        event_data = {"id": 1, "text": "hello"}

        with (
            patch("app.websocket.ws_manager") as mock_ws,
            patch("app.fanout.manager.fanout_manager") as mock_fm,
        ):
            mock_ws.broadcast = AsyncMock()
            mock_fm.broadcast_message = AsyncMock()

            # Pass a copy so the assertions compare against an untouched dict.
            broadcast_event("message", dict(event_data))

            # Yield to the event loop once so the asyncio tasks can run.
            await asyncio.sleep(0)

            mock_ws.broadcast.assert_called_once_with("message", event_data)
            mock_fm.broadcast_message.assert_called_once_with(event_data)

    @pytest.mark.asyncio
    async def test_broadcast_event_raw_packet_dispatches_to_fanout(self):
        """A `raw_packet` event goes to ws_manager.broadcast and fanout broadcast_raw."""
        from app.websocket import broadcast_event

        packet = {"data": "ff00"}

        with (
            patch("app.websocket.ws_manager") as mock_ws,
            patch("app.fanout.manager.fanout_manager") as mock_fm,
        ):
            mock_ws.broadcast = AsyncMock()
            mock_fm.broadcast_raw = AsyncMock()

            # Pass a copy so the assertion compares against an untouched dict.
            broadcast_event("raw_packet", dict(packet))
            await asyncio.sleep(0)

            # Only the call count is checked for the WS side here.
            mock_ws.broadcast.assert_called_once()
            mock_fm.broadcast_raw.assert_called_once_with(packet)
|