mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-05-04 04:23:04 +02:00
Do an imitation of protecting our butts (race conditions in message loading, websocket defensiveness, optimistic UI update rollback handling
This commit is contained in:
@@ -12,6 +12,10 @@ from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Timeout for individual WebSocket send operations (seconds)
|
||||
# Prevents a slow client from blocking broadcasts to other clients
|
||||
SEND_TIMEOUT_SECONDS = 5.0
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""Manages WebSocket connections and broadcasts events."""
|
||||
@@ -33,25 +37,50 @@ class WebSocketManager:
|
||||
logger.info("WebSocket client disconnected (%d remaining)", len(self.active_connections))
|
||||
|
||||
async def broadcast(self, event_type: str, data: Any) -> None:
|
||||
"""Broadcast an event to all connected clients."""
|
||||
"""Broadcast an event to all connected clients.
|
||||
|
||||
Uses a copy-then-send pattern to avoid holding the lock during I/O:
|
||||
1. Copy connection list while holding lock
|
||||
2. Release lock before sending
|
||||
3. Send to all clients concurrently with timeout
|
||||
4. Re-acquire lock to clean up disconnected clients
|
||||
"""
|
||||
if not self.active_connections:
|
||||
return
|
||||
|
||||
message = json.dumps({"type": event_type, "data": data})
|
||||
|
||||
# Copy connection list under lock to avoid holding lock during I/O
|
||||
async with self._lock:
|
||||
disconnected = []
|
||||
for connection in self.active_connections:
|
||||
try:
|
||||
await connection.send_text(message)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to send to client: %s", e)
|
||||
disconnected.append(connection)
|
||||
connections = list(self.active_connections)
|
||||
|
||||
# Clean up disconnected clients
|
||||
for conn in disconnected:
|
||||
if conn in self.active_connections:
|
||||
self.active_connections.remove(conn)
|
||||
if not connections:
|
||||
return
|
||||
|
||||
# Send to all clients concurrently, collect failures
|
||||
disconnected: list[WebSocket] = []
|
||||
|
||||
async def send_to_client(connection: WebSocket) -> None:
|
||||
try:
|
||||
# Timeout prevents blocking on slow/unresponsive clients
|
||||
await asyncio.wait_for(connection.send_text(message), timeout=SEND_TIMEOUT_SECONDS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("Timeout sending to WebSocket client, marking disconnected")
|
||||
disconnected.append(connection)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to send to client: %s", e)
|
||||
disconnected.append(connection)
|
||||
|
||||
# Send to all clients concurrently
|
||||
await asyncio.gather(*[send_to_client(conn) for conn in connections])
|
||||
|
||||
# Clean up disconnected clients (re-acquire lock)
|
||||
if disconnected:
|
||||
async with self._lock:
|
||||
for conn in disconnected:
|
||||
if conn in self.active_connections:
|
||||
self.active_connections.remove(conn)
|
||||
logger.debug("Removed %d disconnected WebSocket clients", len(disconnected))
|
||||
|
||||
async def send_personal(self, websocket: WebSocket, event_type: str, data: Any) -> None:
|
||||
"""Send an event to a specific client."""
|
||||
|
||||
1
frontend/dist/assets/index-BGhVMB5J.js.map
vendored
1
frontend/dist/assets/index-BGhVMB5J.js.map
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
frontend/dist/assets/index-BeiXDnDV.js.map
vendored
Normal file
1
frontend/dist/assets/index-BeiXDnDV.js.map
vendored
Normal file
File diff suppressed because one or more lines are too long
2
frontend/dist/index.html
vendored
2
frontend/dist/index.html
vendored
@@ -13,7 +13,7 @@
|
||||
<link rel="shortcut icon" href="/favicon.ico" />
|
||||
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
|
||||
<link rel="manifest" href="/site.webmanifest" />
|
||||
<script type="module" crossorigin src="/assets/index-BGhVMB5J.js"></script>
|
||||
<script type="module" crossorigin src="/assets/index-BeiXDnDV.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-CnRBRJ10.css">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
@@ -637,22 +637,26 @@ export function App() {
|
||||
);
|
||||
|
||||
// Handle sort order change via API with optimistic update
|
||||
const handleSortOrderChange = useCallback(async (order: 'recent' | 'alpha') => {
|
||||
// Optimistic update for responsive UI
|
||||
setAppSettings((prev) => (prev ? { ...prev, sidebar_sort_order: order } : prev));
|
||||
const handleSortOrderChange = useCallback(
|
||||
async (order: 'recent' | 'alpha') => {
|
||||
// Capture previous value for rollback on error
|
||||
const previousOrder = appSettings?.sidebar_sort_order ?? 'recent';
|
||||
|
||||
try {
|
||||
const updatedSettings = await api.updateSettings({ sidebar_sort_order: order });
|
||||
setAppSettings(updatedSettings);
|
||||
} catch (err) {
|
||||
console.error('Failed to update sort order:', err);
|
||||
// Revert on error
|
||||
setAppSettings((prev) =>
|
||||
prev ? { ...prev, sidebar_sort_order: order === 'recent' ? 'alpha' : 'recent' } : prev
|
||||
);
|
||||
toast.error('Failed to save sort preference');
|
||||
}
|
||||
}, []);
|
||||
// Optimistic update for responsive UI
|
||||
setAppSettings((prev) => (prev ? { ...prev, sidebar_sort_order: order } : prev));
|
||||
|
||||
try {
|
||||
const updatedSettings = await api.updateSettings({ sidebar_sort_order: order });
|
||||
setAppSettings(updatedSettings);
|
||||
} catch (err) {
|
||||
console.error('Failed to update sort order:', err);
|
||||
// Revert to previous value on error (not inverting the new value)
|
||||
setAppSettings((prev) => (prev ? { ...prev, sidebar_sort_order: previousOrder } : prev));
|
||||
toast.error('Failed to save sort preference');
|
||||
}
|
||||
},
|
||||
[appSettings?.sidebar_sort_order]
|
||||
);
|
||||
|
||||
// Sidebar content (shared between desktop and mobile)
|
||||
const sidebarContent = (
|
||||
|
||||
@@ -45,6 +45,16 @@ async function fetchJson<T>(url: string, options?: RequestInit): Promise<T> {
|
||||
return res.json();
|
||||
}
|
||||
|
||||
/** Check if an error is an AbortError (request was cancelled) */
|
||||
export function isAbortError(err: unknown): boolean {
|
||||
// DOMException is thrown by fetch when aborted, and it's not an Error subclass
|
||||
if (err instanceof DOMException && err.name === 'AbortError') {
|
||||
return true;
|
||||
}
|
||||
// Also check for Error with AbortError name (for compatibility)
|
||||
return err instanceof Error && err.name === 'AbortError';
|
||||
}
|
||||
|
||||
interface DecryptResult {
|
||||
started: boolean;
|
||||
total_packets: number;
|
||||
@@ -134,19 +144,22 @@ export const api = {
|
||||
}),
|
||||
|
||||
// Messages
|
||||
getMessages: (params?: {
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
type?: 'PRIV' | 'CHAN';
|
||||
conversation_key?: string;
|
||||
}) => {
|
||||
getMessages: (
|
||||
params?: {
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
type?: 'PRIV' | 'CHAN';
|
||||
conversation_key?: string;
|
||||
},
|
||||
signal?: AbortSignal
|
||||
) => {
|
||||
const searchParams = new URLSearchParams();
|
||||
if (params?.limit) searchParams.set('limit', params.limit.toString());
|
||||
if (params?.offset) searchParams.set('offset', params.offset.toString());
|
||||
if (params?.type) searchParams.set('type', params.type);
|
||||
if (params?.conversation_key) searchParams.set('conversation_key', params.conversation_key);
|
||||
const query = searchParams.toString();
|
||||
return fetchJson<Message[]>(`/messages${query ? `?${query}` : ''}`);
|
||||
return fetchJson<Message[]>(`/messages${query ? `?${query}` : ''}`, { signal });
|
||||
},
|
||||
getMessagesBulk: (
|
||||
conversations: Array<{ type: 'PRIV' | 'CHAN'; conversation_key: string }>,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useState, useCallback, useEffect, useRef } from 'react';
|
||||
import { toast } from '../components/ui/sonner';
|
||||
import { api } from '../api';
|
||||
import { api, isAbortError } from '../api';
|
||||
import type { Conversation, Message, MessagePath } from '../types';
|
||||
|
||||
const MESSAGE_PAGE_SIZE = 200;
|
||||
@@ -33,26 +33,49 @@ export function useConversationMessages(
|
||||
// Track seen message content for deduplication
|
||||
const seenMessageContent = useRef<Set<string>>(new Set());
|
||||
|
||||
// AbortController for cancelling in-flight requests on conversation change
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
// Ref to track the conversation ID being fetched to prevent stale responses
|
||||
const fetchingConversationIdRef = useRef<string | null>(null);
|
||||
|
||||
// Fetch messages for active conversation
|
||||
// Note: This is called manually and from the useEffect. The useEffect handles
|
||||
// cancellation via AbortController; manual calls (e.g., after sending a message)
|
||||
// don't need cancellation.
|
||||
const fetchMessages = useCallback(
|
||||
async (showLoading = false) => {
|
||||
async (showLoading = false, signal?: AbortSignal) => {
|
||||
if (!activeConversation || activeConversation.type === 'raw') {
|
||||
setMessages([]);
|
||||
setHasOlderMessages(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Track which conversation we're fetching for
|
||||
const conversationId = activeConversation.id;
|
||||
|
||||
if (showLoading) {
|
||||
setMessagesLoading(true);
|
||||
// Clear messages first so MessageList resets scroll state for new conversation
|
||||
setMessages([]);
|
||||
}
|
||||
try {
|
||||
const data = await api.getMessages({
|
||||
type: activeConversation.type === 'channel' ? 'CHAN' : 'PRIV',
|
||||
conversation_key: activeConversation.id,
|
||||
limit: MESSAGE_PAGE_SIZE,
|
||||
});
|
||||
const data = await api.getMessages(
|
||||
{
|
||||
type: activeConversation.type === 'channel' ? 'CHAN' : 'PRIV',
|
||||
conversation_key: activeConversation.id,
|
||||
limit: MESSAGE_PAGE_SIZE,
|
||||
},
|
||||
signal
|
||||
);
|
||||
|
||||
// Check if this response is still for the current conversation
|
||||
// This handles the race where the conversation changed while awaiting
|
||||
if (fetchingConversationIdRef.current !== conversationId) {
|
||||
// Stale response - conversation changed while we were fetching
|
||||
return;
|
||||
}
|
||||
|
||||
setMessages(data);
|
||||
// Track seen content for new messages
|
||||
seenMessageContent.current.clear();
|
||||
@@ -62,6 +85,10 @@ export function useConversationMessages(
|
||||
// If we got a full page, there might be more
|
||||
setHasOlderMessages(data.length >= MESSAGE_PAGE_SIZE);
|
||||
} catch (err) {
|
||||
// Don't show error toast for aborted requests (user switched conversations)
|
||||
if (isAbortError(err)) {
|
||||
return;
|
||||
}
|
||||
console.error('Failed to fetch messages:', err);
|
||||
toast.error('Failed to load messages', {
|
||||
description: err instanceof Error ? err.message : 'Check your connection',
|
||||
@@ -114,10 +141,40 @@ export function useConversationMessages(
|
||||
}
|
||||
}, [activeConversation, loadingOlder, hasOlderMessages, messages.length]);
|
||||
|
||||
// Fetch messages when conversation changes
|
||||
// Fetch messages when conversation changes, with proper cancellation
|
||||
useEffect(() => {
|
||||
fetchMessages(true);
|
||||
}, [fetchMessages]);
|
||||
// Abort any previous in-flight request
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
}
|
||||
|
||||
// Track which conversation we're now fetching
|
||||
fetchingConversationIdRef.current = activeConversation?.id ?? null;
|
||||
|
||||
// Clear state for new conversation
|
||||
if (!activeConversation || activeConversation.type === 'raw') {
|
||||
setMessages([]);
|
||||
setHasOlderMessages(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create new AbortController for this fetch
|
||||
const controller = new AbortController();
|
||||
abortControllerRef.current = controller;
|
||||
|
||||
// Fetch messages with the abort signal
|
||||
fetchMessages(true, controller.signal);
|
||||
|
||||
// Cleanup: abort request if conversation changes or component unmounts
|
||||
return () => {
|
||||
controller.abort();
|
||||
};
|
||||
// NOTE: Intentionally omitting fetchMessages and activeConversation from deps:
|
||||
// - fetchMessages is recreated when activeConversation changes, which would cause infinite loops
|
||||
// - activeConversation object identity changes on every render; we only care about id/type
|
||||
// - We use fetchingConversationIdRef and AbortController to handle stale responses safely
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [activeConversation?.id, activeConversation?.type]);
|
||||
|
||||
// Add a message if it's new (deduplication)
|
||||
// Returns true if the message was added, false if it was a duplicate
|
||||
|
||||
62
frontend/src/test/api.test.ts
Normal file
62
frontend/src/test/api.test.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
/**
|
||||
* Tests for API utilities.
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { isAbortError } from '../api';
|
||||
|
||||
describe('isAbortError', () => {
|
||||
it('returns true for AbortError', () => {
|
||||
const controller = new AbortController();
|
||||
controller.abort();
|
||||
|
||||
// Create an error that mimics fetch abort
|
||||
const error = new DOMException('The operation was aborted', 'AbortError');
|
||||
|
||||
expect(isAbortError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('returns true for Error with name AbortError', () => {
|
||||
const error = new Error('Request cancelled');
|
||||
error.name = 'AbortError';
|
||||
|
||||
expect(isAbortError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false for regular Error', () => {
|
||||
const error = new Error('Something went wrong');
|
||||
|
||||
expect(isAbortError(error)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for TypeError', () => {
|
||||
const error = new TypeError('Network failure');
|
||||
|
||||
expect(isAbortError(error)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for null', () => {
|
||||
expect(isAbortError(null)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for undefined', () => {
|
||||
expect(isAbortError(undefined)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for non-Error objects', () => {
|
||||
expect(isAbortError({ message: 'error' })).toBe(false);
|
||||
expect(isAbortError('error string')).toBe(false);
|
||||
expect(isAbortError(42)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for Error subclasses with different names', () => {
|
||||
class CustomError extends Error {
|
||||
constructor() {
|
||||
super('Custom error');
|
||||
this.name = 'CustomError';
|
||||
}
|
||||
}
|
||||
|
||||
expect(isAbortError(new CustomError())).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -193,3 +193,80 @@ describe('parseWebSocketMessage', () => {
|
||||
expect(onRawPacket).toHaveBeenCalledWith(packetData);
|
||||
});
|
||||
});
|
||||
|
||||
describe('useWebSocket ref-based handler pattern', () => {
|
||||
/**
|
||||
* These tests verify the pattern used in useWebSocket to avoid stale closures.
|
||||
* The hook stores handlers in a ref and accesses them through the ref in callbacks.
|
||||
* This ensures that when handlers are updated, the WebSocket still calls the latest version.
|
||||
*/
|
||||
|
||||
it('demonstrates ref pattern prevents stale closure', () => {
|
||||
// Simulate the ref pattern used in useWebSocket
|
||||
interface Handlers {
|
||||
onMessage?: (msg: string) => void;
|
||||
}
|
||||
|
||||
// This simulates what the hook does: store handlers in a ref
|
||||
const handlersRef: { current: Handlers } = { current: {} };
|
||||
|
||||
// First handler version
|
||||
const firstHandler = vi.fn();
|
||||
handlersRef.current = { onMessage: firstHandler };
|
||||
|
||||
// Simulate what onmessage does: access handlers through ref
|
||||
const processMessage = (data: string) => {
|
||||
// This is the pattern: access through ref.current, not closed-over variable
|
||||
handlersRef.current.onMessage?.(data);
|
||||
};
|
||||
|
||||
// Send first message
|
||||
processMessage('message1');
|
||||
expect(firstHandler).toHaveBeenCalledWith('message1');
|
||||
|
||||
// Update handler (simulates React re-render with new handler)
|
||||
const secondHandler = vi.fn();
|
||||
handlersRef.current = { onMessage: secondHandler };
|
||||
|
||||
// Send second message
|
||||
processMessage('message2');
|
||||
|
||||
// First handler should NOT be called again (would happen with stale closure)
|
||||
expect(firstHandler).toHaveBeenCalledTimes(1);
|
||||
// Second handler should be called (ref pattern works)
|
||||
expect(secondHandler).toHaveBeenCalledWith('message2');
|
||||
});
|
||||
|
||||
it('demonstrates stale closure problem without ref pattern', () => {
|
||||
// This demonstrates the bug we fixed - without refs, handlers become stale
|
||||
interface Handlers {
|
||||
onMessage?: (msg: string) => void;
|
||||
}
|
||||
|
||||
// First handler version
|
||||
const firstHandler = vi.fn();
|
||||
let handlers: Handlers = { onMessage: firstHandler };
|
||||
|
||||
// BAD PATTERN: capture handlers in closure (this is what we fixed)
|
||||
const capturedHandlers = handlers;
|
||||
const processMessageBad = (data: string) => {
|
||||
// This captures `capturedHandlers` at creation time - STALE!
|
||||
capturedHandlers.onMessage?.(data);
|
||||
};
|
||||
|
||||
// Send first message
|
||||
processMessageBad('message1');
|
||||
expect(firstHandler).toHaveBeenCalledWith('message1');
|
||||
|
||||
// Update handler
|
||||
const secondHandler = vi.fn();
|
||||
handlers = { onMessage: secondHandler };
|
||||
|
||||
// Send second message - BUG: still calls first handler!
|
||||
processMessageBad('message2');
|
||||
|
||||
// This demonstrates the stale closure bug
|
||||
expect(firstHandler).toHaveBeenCalledTimes(2); // Called twice - bug!
|
||||
expect(secondHandler).not.toHaveBeenCalled(); // Never called - bug!
|
||||
});
|
||||
});
|
||||
|
||||
@@ -33,6 +33,19 @@ export function useWebSocket(options: UseWebSocketOptions) {
|
||||
const reconnectTimeoutRef = useRef<number | null>(null);
|
||||
const [connected, setConnected] = useState(false);
|
||||
|
||||
// Store options in ref to avoid stale closures in WebSocket handlers.
|
||||
// The onmessage callback captures this ref, and we keep the ref updated
|
||||
// with the latest handlers. This way, even though the WebSocket connection
|
||||
// is only created once, it always calls the current handlers.
|
||||
const optionsRef = useRef<UseWebSocketOptions>(options);
|
||||
|
||||
// Keep the ref updated with latest options
|
||||
useEffect(() => {
|
||||
optionsRef.current = options;
|
||||
}, [options]);
|
||||
|
||||
// Connect function - uses ref for handlers to avoid stale closures
|
||||
// No dependencies needed since we access handlers through ref
|
||||
const connect = useCallback(() => {
|
||||
// Determine WebSocket URL based on current location
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
@@ -68,25 +81,27 @@ export function useWebSocket(options: UseWebSocketOptions) {
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
const msg: WebSocketMessage = JSON.parse(event.data);
|
||||
// Access handlers through ref to always use current versions
|
||||
const handlers = optionsRef.current;
|
||||
|
||||
switch (msg.type) {
|
||||
case 'health':
|
||||
options.onHealth?.(msg.data as HealthStatus);
|
||||
handlers.onHealth?.(msg.data as HealthStatus);
|
||||
break;
|
||||
case 'contacts':
|
||||
options.onContacts?.(msg.data as Contact[]);
|
||||
handlers.onContacts?.(msg.data as Contact[]);
|
||||
break;
|
||||
case 'channels':
|
||||
options.onChannels?.(msg.data as Channel[]);
|
||||
handlers.onChannels?.(msg.data as Channel[]);
|
||||
break;
|
||||
case 'message':
|
||||
options.onMessage?.(msg.data as Message);
|
||||
handlers.onMessage?.(msg.data as Message);
|
||||
break;
|
||||
case 'contact':
|
||||
options.onContact?.(msg.data as Contact);
|
||||
handlers.onContact?.(msg.data as Contact);
|
||||
break;
|
||||
case 'raw_packet':
|
||||
options.onRawPacket?.(msg.data as RawPacket);
|
||||
handlers.onRawPacket?.(msg.data as RawPacket);
|
||||
break;
|
||||
case 'message_acked': {
|
||||
const ackData = msg.data as {
|
||||
@@ -94,14 +109,14 @@ export function useWebSocket(options: UseWebSocketOptions) {
|
||||
ack_count: number;
|
||||
paths?: MessagePath[];
|
||||
};
|
||||
options.onMessageAcked?.(ackData.message_id, ackData.ack_count, ackData.paths);
|
||||
handlers.onMessageAcked?.(ackData.message_id, ackData.ack_count, ackData.paths);
|
||||
break;
|
||||
}
|
||||
case 'error':
|
||||
options.onError?.(msg.data as ErrorEvent);
|
||||
handlers.onError?.(msg.data as ErrorEvent);
|
||||
break;
|
||||
case 'success':
|
||||
options.onSuccess?.(msg.data as SuccessEvent);
|
||||
handlers.onSuccess?.(msg.data as SuccessEvent);
|
||||
break;
|
||||
case 'pong':
|
||||
// Heartbeat response, ignore
|
||||
@@ -115,7 +130,7 @@ export function useWebSocket(options: UseWebSocketOptions) {
|
||||
};
|
||||
|
||||
wsRef.current = ws;
|
||||
}, [options]);
|
||||
}, []); // No dependencies - handlers accessed through ref
|
||||
|
||||
useEffect(() => {
|
||||
connect();
|
||||
|
||||
205
tests/test_websocket.py
Normal file
205
tests/test_websocket.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Tests for WebSocket manager functionality."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.websocket import SEND_TIMEOUT_SECONDS, WebSocketManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ws_manager():
|
||||
"""Create a fresh WebSocketManager for each test."""
|
||||
return WebSocketManager()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
"""Create a mock WebSocket connection."""
|
||||
ws = AsyncMock()
|
||||
ws.send_text = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
class TestWebSocketBroadcast:
|
||||
"""Tests for the broadcast functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_sends_to_all_clients(self, ws_manager: WebSocketManager):
|
||||
"""Broadcast should send message to all connected clients."""
|
||||
ws1 = AsyncMock()
|
||||
ws2 = AsyncMock()
|
||||
ws1.accept = AsyncMock()
|
||||
ws2.accept = AsyncMock()
|
||||
|
||||
await ws_manager.connect(ws1)
|
||||
await ws_manager.connect(ws2)
|
||||
|
||||
await ws_manager.broadcast("test", {"key": "value"})
|
||||
|
||||
# Both clients should receive the message
|
||||
ws1.send_text.assert_called_once()
|
||||
ws2.send_text.assert_called_once()
|
||||
|
||||
# Verify the message format
|
||||
import json
|
||||
|
||||
expected = json.dumps({"type": "test", "data": {"key": "value"}})
|
||||
ws1.send_text.assert_called_with(expected)
|
||||
ws2.send_text.assert_called_with(expected)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_removes_failed_clients(self, ws_manager: WebSocketManager):
|
||||
"""Clients that fail to receive should be removed."""
|
||||
good_ws = AsyncMock()
|
||||
bad_ws = AsyncMock()
|
||||
good_ws.accept = AsyncMock()
|
||||
bad_ws.accept = AsyncMock()
|
||||
bad_ws.send_text.side_effect = Exception("Connection closed")
|
||||
|
||||
await ws_manager.connect(good_ws)
|
||||
await ws_manager.connect(bad_ws)
|
||||
|
||||
assert len(ws_manager.active_connections) == 2
|
||||
|
||||
await ws_manager.broadcast("test", {})
|
||||
|
||||
# Bad client should be removed
|
||||
assert len(ws_manager.active_connections) == 1
|
||||
assert good_ws in ws_manager.active_connections
|
||||
assert bad_ws not in ws_manager.active_connections
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_handles_timeout(self, ws_manager: WebSocketManager):
|
||||
"""Clients that timeout should be removed."""
|
||||
good_ws = AsyncMock()
|
||||
slow_ws = AsyncMock()
|
||||
good_ws.accept = AsyncMock()
|
||||
slow_ws.accept = AsyncMock()
|
||||
|
||||
# Make slow_ws hang indefinitely
|
||||
async def slow_send(_):
|
||||
await asyncio.sleep(SEND_TIMEOUT_SECONDS + 1)
|
||||
|
||||
slow_ws.send_text.side_effect = slow_send
|
||||
|
||||
await ws_manager.connect(good_ws)
|
||||
await ws_manager.connect(slow_ws)
|
||||
|
||||
assert len(ws_manager.active_connections) == 2
|
||||
|
||||
# Broadcast should complete despite slow client (due to timeout)
|
||||
await ws_manager.broadcast("test", {})
|
||||
|
||||
# Slow client should be removed due to timeout
|
||||
assert len(ws_manager.active_connections) == 1
|
||||
assert good_ws in ws_manager.active_connections
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_concurrent_sends(self, ws_manager: WebSocketManager):
|
||||
"""Verify that sends happen concurrently, not sequentially."""
|
||||
call_times = []
|
||||
|
||||
async def record_send_time(ws_name):
|
||||
async def _send(_):
|
||||
call_times.append((ws_name, asyncio.get_event_loop().time()))
|
||||
await asyncio.sleep(0.1) # Simulate some work
|
||||
|
||||
return _send
|
||||
|
||||
ws1 = AsyncMock()
|
||||
ws2 = AsyncMock()
|
||||
ws3 = AsyncMock()
|
||||
ws1.accept = AsyncMock()
|
||||
ws2.accept = AsyncMock()
|
||||
ws3.accept = AsyncMock()
|
||||
|
||||
ws1.send_text.side_effect = await record_send_time("ws1")
|
||||
ws2.send_text.side_effect = await record_send_time("ws2")
|
||||
ws3.send_text.side_effect = await record_send_time("ws3")
|
||||
|
||||
await ws_manager.connect(ws1)
|
||||
await ws_manager.connect(ws2)
|
||||
await ws_manager.connect(ws3)
|
||||
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
await ws_manager.broadcast("test", {})
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
# If sequential: 3 * 0.1 = 0.3s
|
||||
# If concurrent: ~0.1s
|
||||
# Allow some margin for test overhead
|
||||
assert elapsed < 0.2, f"Sends should be concurrent, took {elapsed}s"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_does_not_block_on_slow_client(self, ws_manager: WebSocketManager):
|
||||
"""A slow client should not block messages to fast clients."""
|
||||
fast_ws = AsyncMock()
|
||||
slow_ws = AsyncMock()
|
||||
fast_ws.accept = AsyncMock()
|
||||
slow_ws.accept = AsyncMock()
|
||||
|
||||
fast_received_at = None
|
||||
slow_received_at = None
|
||||
|
||||
async def fast_send(_):
|
||||
nonlocal fast_received_at
|
||||
fast_received_at = asyncio.get_event_loop().time()
|
||||
|
||||
async def slow_send(_):
|
||||
nonlocal slow_received_at
|
||||
await asyncio.sleep(0.2) # Slow client
|
||||
slow_received_at = asyncio.get_event_loop().time()
|
||||
|
||||
fast_ws.send_text.side_effect = fast_send
|
||||
slow_ws.send_text.side_effect = slow_send
|
||||
|
||||
await ws_manager.connect(slow_ws)
|
||||
await ws_manager.connect(fast_ws)
|
||||
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
await ws_manager.broadcast("test", {})
|
||||
|
||||
# Fast client should receive message quickly, not waiting for slow client
|
||||
assert fast_received_at is not None
|
||||
assert fast_received_at - start_time < 0.1, "Fast client was blocked by slow client"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_empty_connections(self, ws_manager: WebSocketManager):
|
||||
"""Broadcast should handle empty connection list gracefully."""
|
||||
# Should not raise
|
||||
await ws_manager.broadcast("test", {"data": "value"})
|
||||
|
||||
|
||||
class TestWebSocketConnectionManagement:
|
||||
"""Tests for connection/disconnection."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_adds_to_list(self, ws_manager: WebSocketManager, mock_websocket):
|
||||
"""Connect should add websocket to active connections."""
|
||||
assert len(ws_manager.active_connections) == 0
|
||||
|
||||
await ws_manager.connect(mock_websocket)
|
||||
|
||||
assert len(ws_manager.active_connections) == 1
|
||||
assert mock_websocket in ws_manager.active_connections
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_removes_from_list(self, ws_manager: WebSocketManager, mock_websocket):
|
||||
"""Disconnect should remove websocket from active connections."""
|
||||
await ws_manager.connect(mock_websocket)
|
||||
assert len(ws_manager.active_connections) == 1
|
||||
|
||||
await ws_manager.disconnect(mock_websocket)
|
||||
|
||||
assert len(ws_manager.active_connections) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_nonexistent_is_safe(
|
||||
self, ws_manager: WebSocketManager, mock_websocket
|
||||
):
|
||||
"""Disconnecting a non-connected websocket should not raise."""
|
||||
# Should not raise
|
||||
await ws_manager.disconnect(mock_websocket)
|
||||
assert len(ws_manager.active_connections) == 0
|
||||
Reference in New Issue
Block a user