Do an imitation of protecting our butts (race conditions in message loading, websocket defensiveness, optimistic UI update rollback handling)

This commit is contained in:
Jack Kingsman
2026-01-19 11:47:20 -08:00
parent bf03d76c33
commit 0138233743
12 changed files with 553 additions and 91 deletions

View File

@@ -12,6 +12,10 @@ from app.config import settings
logger = logging.getLogger(__name__)
# Timeout for individual WebSocket send operations (seconds)
# Prevents a slow client from blocking broadcasts to other clients
SEND_TIMEOUT_SECONDS = 5.0
class WebSocketManager:
"""Manages WebSocket connections and broadcasts events."""
@@ -33,25 +37,50 @@ class WebSocketManager:
logger.info("WebSocket client disconnected (%d remaining)", len(self.active_connections))
async def broadcast(self, event_type: str, data: Any) -> None:
"""Broadcast an event to all connected clients."""
"""Broadcast an event to all connected clients.
Uses a copy-then-send pattern to avoid holding the lock during I/O:
1. Copy connection list while holding lock
2. Release lock before sending
3. Send to all clients concurrently with timeout
4. Re-acquire lock to clean up disconnected clients
"""
if not self.active_connections:
return
message = json.dumps({"type": event_type, "data": data})
# Copy connection list under lock to avoid holding lock during I/O
async with self._lock:
disconnected = []
for connection in self.active_connections:
try:
await connection.send_text(message)
except Exception as e:
logger.debug("Failed to send to client: %s", e)
disconnected.append(connection)
connections = list(self.active_connections)
# Clean up disconnected clients
for conn in disconnected:
if conn in self.active_connections:
self.active_connections.remove(conn)
if not connections:
return
# Send to all clients concurrently, collect failures
disconnected: list[WebSocket] = []
async def send_to_client(connection: WebSocket) -> None:
try:
# Timeout prevents blocking on slow/unresponsive clients
await asyncio.wait_for(connection.send_text(message), timeout=SEND_TIMEOUT_SECONDS)
except asyncio.TimeoutError:
logger.debug("Timeout sending to WebSocket client, marking disconnected")
disconnected.append(connection)
except Exception as e:
logger.debug("Failed to send to client: %s", e)
disconnected.append(connection)
# Send to all clients concurrently
await asyncio.gather(*[send_to_client(conn) for conn in connections])
# Clean up disconnected clients (re-acquire lock)
if disconnected:
async with self._lock:
for conn in disconnected:
if conn in self.active_connections:
self.active_connections.remove(conn)
logger.debug("Removed %d disconnected WebSocket clients", len(disconnected))
async def send_personal(self, websocket: WebSocket, event_type: str, data: Any) -> None:
"""Send an event to a specific client."""

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -13,7 +13,7 @@
<link rel="shortcut icon" href="/favicon.ico" />
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
<link rel="manifest" href="/site.webmanifest" />
<script type="module" crossorigin src="/assets/index-BGhVMB5J.js"></script>
<script type="module" crossorigin src="/assets/index-BeiXDnDV.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-CnRBRJ10.css">
</head>
<body>

View File

@@ -637,22 +637,26 @@ export function App() {
);
// Handle sort order change via API with optimistic update
const handleSortOrderChange = useCallback(async (order: 'recent' | 'alpha') => {
// Optimistic update for responsive UI
setAppSettings((prev) => (prev ? { ...prev, sidebar_sort_order: order } : prev));
const handleSortOrderChange = useCallback(
async (order: 'recent' | 'alpha') => {
// Capture previous value for rollback on error
const previousOrder = appSettings?.sidebar_sort_order ?? 'recent';
try {
const updatedSettings = await api.updateSettings({ sidebar_sort_order: order });
setAppSettings(updatedSettings);
} catch (err) {
console.error('Failed to update sort order:', err);
// Revert on error
setAppSettings((prev) =>
prev ? { ...prev, sidebar_sort_order: order === 'recent' ? 'alpha' : 'recent' } : prev
);
toast.error('Failed to save sort preference');
}
}, []);
// Optimistic update for responsive UI
setAppSettings((prev) => (prev ? { ...prev, sidebar_sort_order: order } : prev));
try {
const updatedSettings = await api.updateSettings({ sidebar_sort_order: order });
setAppSettings(updatedSettings);
} catch (err) {
console.error('Failed to update sort order:', err);
// Revert to previous value on error (not inverting the new value)
setAppSettings((prev) => (prev ? { ...prev, sidebar_sort_order: previousOrder } : prev));
toast.error('Failed to save sort preference');
}
},
[appSettings?.sidebar_sort_order]
);
// Sidebar content (shared between desktop and mobile)
const sidebarContent = (

View File

@@ -45,6 +45,16 @@ async function fetchJson<T>(url: string, options?: RequestInit): Promise<T> {
return res.json();
}
/**
 * Report whether `err` represents a cancelled (aborted) request.
 *
 * Accepts either a DOMException or an Error instance whose `name` is
 * 'AbortError'. Both types are checked explicitly so the result does not
 * depend on whether the runtime's DOMException passes `instanceof Error`.
 * Anything else (null, undefined, strings, plain objects) yields false.
 */
export function isAbortError(err: unknown): boolean {
  // Guard clause: only real DOMException / Error instances can qualify
  if (!(err instanceof DOMException) && !(err instanceof Error)) {
    return false;
  }
  // Both accepted types expose `name`; abort is identified by name alone
  return err.name === 'AbortError';
}
interface DecryptResult {
started: boolean;
total_packets: number;
@@ -134,19 +144,22 @@ export const api = {
}),
// Messages
getMessages: (params?: {
limit?: number;
offset?: number;
type?: 'PRIV' | 'CHAN';
conversation_key?: string;
}) => {
getMessages: (
params?: {
limit?: number;
offset?: number;
type?: 'PRIV' | 'CHAN';
conversation_key?: string;
},
signal?: AbortSignal
) => {
const searchParams = new URLSearchParams();
if (params?.limit) searchParams.set('limit', params.limit.toString());
if (params?.offset) searchParams.set('offset', params.offset.toString());
if (params?.type) searchParams.set('type', params.type);
if (params?.conversation_key) searchParams.set('conversation_key', params.conversation_key);
const query = searchParams.toString();
return fetchJson<Message[]>(`/messages${query ? `?${query}` : ''}`);
return fetchJson<Message[]>(`/messages${query ? `?${query}` : ''}`, { signal });
},
getMessagesBulk: (
conversations: Array<{ type: 'PRIV' | 'CHAN'; conversation_key: string }>,

View File

@@ -1,6 +1,6 @@
import { useState, useCallback, useEffect, useRef } from 'react';
import { toast } from '../components/ui/sonner';
import { api } from '../api';
import { api, isAbortError } from '../api';
import type { Conversation, Message, MessagePath } from '../types';
const MESSAGE_PAGE_SIZE = 200;
@@ -33,26 +33,49 @@ export function useConversationMessages(
// Track seen message content for deduplication
const seenMessageContent = useRef<Set<string>>(new Set());
// AbortController for cancelling in-flight requests on conversation change
const abortControllerRef = useRef<AbortController | null>(null);
// Ref to track the conversation ID being fetched to prevent stale responses
const fetchingConversationIdRef = useRef<string | null>(null);
// Fetch messages for active conversation
// Note: This is called manually and from the useEffect. The useEffect handles
// cancellation via AbortController; manual calls (e.g., after sending a message)
// don't need cancellation.
const fetchMessages = useCallback(
async (showLoading = false) => {
async (showLoading = false, signal?: AbortSignal) => {
if (!activeConversation || activeConversation.type === 'raw') {
setMessages([]);
setHasOlderMessages(false);
return;
}
// Track which conversation we're fetching for
const conversationId = activeConversation.id;
if (showLoading) {
setMessagesLoading(true);
// Clear messages first so MessageList resets scroll state for new conversation
setMessages([]);
}
try {
const data = await api.getMessages({
type: activeConversation.type === 'channel' ? 'CHAN' : 'PRIV',
conversation_key: activeConversation.id,
limit: MESSAGE_PAGE_SIZE,
});
const data = await api.getMessages(
{
type: activeConversation.type === 'channel' ? 'CHAN' : 'PRIV',
conversation_key: activeConversation.id,
limit: MESSAGE_PAGE_SIZE,
},
signal
);
// Check if this response is still for the current conversation
// This handles the race where the conversation changed while awaiting
if (fetchingConversationIdRef.current !== conversationId) {
// Stale response - conversation changed while we were fetching
return;
}
setMessages(data);
// Track seen content for new messages
seenMessageContent.current.clear();
@@ -62,6 +85,10 @@ export function useConversationMessages(
// If we got a full page, there might be more
setHasOlderMessages(data.length >= MESSAGE_PAGE_SIZE);
} catch (err) {
// Don't show error toast for aborted requests (user switched conversations)
if (isAbortError(err)) {
return;
}
console.error('Failed to fetch messages:', err);
toast.error('Failed to load messages', {
description: err instanceof Error ? err.message : 'Check your connection',
@@ -114,10 +141,40 @@ export function useConversationMessages(
}
}, [activeConversation, loadingOlder, hasOlderMessages, messages.length]);
// Fetch messages when conversation changes
// Fetch messages when conversation changes, with proper cancellation
useEffect(() => {
fetchMessages(true);
}, [fetchMessages]);
// Abort any previous in-flight request
if (abortControllerRef.current) {
abortControllerRef.current.abort();
}
// Track which conversation we're now fetching
fetchingConversationIdRef.current = activeConversation?.id ?? null;
// Clear state for new conversation
if (!activeConversation || activeConversation.type === 'raw') {
setMessages([]);
setHasOlderMessages(false);
return;
}
// Create new AbortController for this fetch
const controller = new AbortController();
abortControllerRef.current = controller;
// Fetch messages with the abort signal
fetchMessages(true, controller.signal);
// Cleanup: abort request if conversation changes or component unmounts
return () => {
controller.abort();
};
// NOTE: Intentionally omitting fetchMessages and activeConversation from deps:
// - fetchMessages is recreated when activeConversation changes, which would cause infinite loops
// - activeConversation object identity changes on every render; we only care about id/type
// - We use fetchingConversationIdRef and AbortController to handle stale responses safely
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [activeConversation?.id, activeConversation?.type]);
// Add a message if it's new (deduplication)
// Returns true if the message was added, false if it was a duplicate

View File

@@ -0,0 +1,62 @@
/**
* Tests for API utilities.
*/
import { describe, it, expect } from 'vitest';
import { isAbortError } from '../api';
describe('isAbortError', () => {
  // --- Positive cases: the two shapes an aborted request can surface as ---

  it('returns true for AbortError', () => {
    // Build the DOMException shape directly; no AbortController is needed
    // because the helper only inspects the error object it is given.
    const error = new DOMException('The operation was aborted', 'AbortError');

    expect(isAbortError(error)).toBe(true);
  });

  it('returns true for Error with name AbortError', () => {
    const error = new Error('Request cancelled');
    error.name = 'AbortError';

    expect(isAbortError(error)).toBe(true);
  });

  // --- Negative cases: real errors that are not aborts ---

  it('returns false for regular Error', () => {
    expect(isAbortError(new Error('Something went wrong'))).toBe(false);
  });

  it('returns false for TypeError', () => {
    expect(isAbortError(new TypeError('Network failure'))).toBe(false);
  });

  // --- Negative cases: values that are not error instances at all ---

  it('returns false for null', () => {
    expect(isAbortError(null)).toBe(false);
  });

  it('returns false for undefined', () => {
    expect(isAbortError(undefined)).toBe(false);
  });

  it('returns false for non-Error objects', () => {
    // The last entry pins that a matching `name` alone is not enough:
    // the value must also be an Error/DOMException instance.
    const nonErrors: unknown[] = [
      { message: 'error' },
      'error string',
      42,
      { name: 'AbortError' },
    ];

    for (const value of nonErrors) {
      expect(isAbortError(value)).toBe(false);
    }
  });

  it('returns false for Error subclasses with different names', () => {
    class CustomError extends Error {
      constructor() {
        super('Custom error');
        this.name = 'CustomError';
      }
    }

    expect(isAbortError(new CustomError())).toBe(false);
  });
});

View File

@@ -193,3 +193,80 @@ describe('parseWebSocketMessage', () => {
expect(onRawPacket).toHaveBeenCalledWith(packetData);
});
});
describe('useWebSocket ref-based handler pattern', () => {
  /**
   * These tests exercise the closure pattern itself, not the hook:
   * handlers are kept in a ref-like `{ current }` box and looked up at call
   * time, so replacing the handlers is seen by a callback created earlier.
   * The second test shows the contrasting behavior when the handlers object
   * is captured directly.
   */

  // Minimal stand-in for the handler bag passed to the hook
  interface Handlers {
    onMessage?: (msg: string) => void;
  }

  it('demonstrates ref pattern prevents stale closure', () => {
    const initialHandler = vi.fn();
    const replacementHandler = vi.fn();

    // Ref-like box holding the current handlers, starting with the first version
    const handlersRef: { current: Handlers } = { current: { onMessage: initialHandler } };

    // Stand-in for onmessage: always dereferences the box at call time
    const dispatch = (payload: string) => {
      handlersRef.current.onMessage?.(payload);
    };

    dispatch('message1');
    expect(initialHandler).toHaveBeenCalledWith('message1');

    // Swap handlers, as a React re-render with new callbacks would
    handlersRef.current = { onMessage: replacementHandler };
    dispatch('message2');

    // Old handler saw only the first message; new handler saw the second
    expect(initialHandler).toHaveBeenCalledTimes(1);
    expect(replacementHandler).toHaveBeenCalledWith('message2');
  });

  it('demonstrates stale closure problem without ref pattern', () => {
    const initialHandler = vi.fn();
    const replacementHandler = vi.fn();

    let handlers: Handlers = { onMessage: initialHandler };

    // BAD PATTERN: snapshot the handlers object once and close over the snapshot
    const snapshot = handlers;
    const dispatchStale = (payload: string) => {
      snapshot.onMessage?.(payload);
    };

    dispatchStale('message1');
    expect(initialHandler).toHaveBeenCalledWith('message1');

    // Rebinding `handlers` does not affect the snapshot the callback holds
    handlers = { onMessage: replacementHandler };
    dispatchStale('message2');

    // Stale behavior: the first handler receives both messages
    expect(initialHandler).toHaveBeenCalledTimes(2);
    expect(replacementHandler).not.toHaveBeenCalled();
  });
});

View File

@@ -33,6 +33,19 @@ export function useWebSocket(options: UseWebSocketOptions) {
const reconnectTimeoutRef = useRef<number | null>(null);
const [connected, setConnected] = useState(false);
// Store options in ref to avoid stale closures in WebSocket handlers.
// The onmessage callback captures this ref, and we keep the ref updated
// with the latest handlers. This way, even though the WebSocket connection
// is only created once, it always calls the current handlers.
const optionsRef = useRef<UseWebSocketOptions>(options);
// Keep the ref updated with latest options
useEffect(() => {
optionsRef.current = options;
}, [options]);
// Connect function - uses ref for handlers to avoid stale closures
// No dependencies needed since we access handlers through ref
const connect = useCallback(() => {
// Determine WebSocket URL based on current location
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
@@ -68,25 +81,27 @@ export function useWebSocket(options: UseWebSocketOptions) {
ws.onmessage = (event) => {
try {
const msg: WebSocketMessage = JSON.parse(event.data);
// Access handlers through ref to always use current versions
const handlers = optionsRef.current;
switch (msg.type) {
case 'health':
options.onHealth?.(msg.data as HealthStatus);
handlers.onHealth?.(msg.data as HealthStatus);
break;
case 'contacts':
options.onContacts?.(msg.data as Contact[]);
handlers.onContacts?.(msg.data as Contact[]);
break;
case 'channels':
options.onChannels?.(msg.data as Channel[]);
handlers.onChannels?.(msg.data as Channel[]);
break;
case 'message':
options.onMessage?.(msg.data as Message);
handlers.onMessage?.(msg.data as Message);
break;
case 'contact':
options.onContact?.(msg.data as Contact);
handlers.onContact?.(msg.data as Contact);
break;
case 'raw_packet':
options.onRawPacket?.(msg.data as RawPacket);
handlers.onRawPacket?.(msg.data as RawPacket);
break;
case 'message_acked': {
const ackData = msg.data as {
@@ -94,14 +109,14 @@ export function useWebSocket(options: UseWebSocketOptions) {
ack_count: number;
paths?: MessagePath[];
};
options.onMessageAcked?.(ackData.message_id, ackData.ack_count, ackData.paths);
handlers.onMessageAcked?.(ackData.message_id, ackData.ack_count, ackData.paths);
break;
}
case 'error':
options.onError?.(msg.data as ErrorEvent);
handlers.onError?.(msg.data as ErrorEvent);
break;
case 'success':
options.onSuccess?.(msg.data as SuccessEvent);
handlers.onSuccess?.(msg.data as SuccessEvent);
break;
case 'pong':
// Heartbeat response, ignore
@@ -115,7 +130,7 @@ export function useWebSocket(options: UseWebSocketOptions) {
};
wsRef.current = ws;
}, [options]);
}, []); // No dependencies - handlers accessed through ref
useEffect(() => {
connect();

205
tests/test_websocket.py Normal file
View File

@@ -0,0 +1,205 @@
"""Tests for WebSocket manager functionality."""
import asyncio
from unittest.mock import AsyncMock
import pytest
from app.websocket import SEND_TIMEOUT_SECONDS, WebSocketManager
@pytest.fixture
def ws_manager():
    """Provide a brand-new WebSocketManager so tests do not share connection state."""
    manager = WebSocketManager()
    return manager
@pytest.fixture
def mock_websocket():
    """Provide an AsyncMock standing in for a WebSocket, with an awaitable send_text."""
    # Keyword arguments to a mock constructor are set as attributes on the mock
    return AsyncMock(send_text=AsyncMock())
class TestWebSocketBroadcast:
    """Tests for WebSocketManager.broadcast.

    Covers fan-out to every client, removal of clients whose send fails or
    times out, concurrency of the sends, and the empty-connection case.
    """

    @pytest.mark.asyncio
    async def test_broadcast_sends_to_all_clients(self, ws_manager: WebSocketManager):
        """Broadcast should send message to all connected clients."""
        import json

        ws1 = AsyncMock()
        ws2 = AsyncMock()
        ws1.accept = AsyncMock()
        ws2.accept = AsyncMock()

        await ws_manager.connect(ws1)
        await ws_manager.connect(ws2)

        await ws_manager.broadcast("test", {"key": "value"})

        # Each client receives exactly one frame carrying the JSON envelope
        expected = json.dumps({"type": "test", "data": {"key": "value"}})
        ws1.send_text.assert_called_once_with(expected)
        ws2.send_text.assert_called_once_with(expected)

    @pytest.mark.asyncio
    async def test_broadcast_removes_failed_clients(self, ws_manager: WebSocketManager):
        """Clients that fail to receive should be removed."""
        good_ws = AsyncMock()
        bad_ws = AsyncMock()
        good_ws.accept = AsyncMock()
        bad_ws.accept = AsyncMock()
        bad_ws.send_text.side_effect = Exception("Connection closed")

        await ws_manager.connect(good_ws)
        await ws_manager.connect(bad_ws)
        assert len(ws_manager.active_connections) == 2

        await ws_manager.broadcast("test", {})

        # Only the failing client is dropped
        assert len(ws_manager.active_connections) == 1
        assert good_ws in ws_manager.active_connections
        assert bad_ws not in ws_manager.active_connections

    @pytest.mark.asyncio
    async def test_broadcast_handles_timeout(self, ws_manager: WebSocketManager, monkeypatch):
        """Clients that timeout should be removed."""
        # Shrink the send timeout for this test so it does not wait the real
        # SEND_TIMEOUT_SECONDS; monkeypatch restores the value afterwards.
        monkeypatch.setattr("app.websocket.SEND_TIMEOUT_SECONDS", 0.05)

        good_ws = AsyncMock()
        slow_ws = AsyncMock()
        good_ws.accept = AsyncMock()
        slow_ws.accept = AsyncMock()

        async def slow_send(_):
            # Much longer than the patched timeout, so the send is cut short
            await asyncio.sleep(SEND_TIMEOUT_SECONDS + 1)

        slow_ws.send_text.side_effect = slow_send

        await ws_manager.connect(good_ws)
        await ws_manager.connect(slow_ws)
        assert len(ws_manager.active_connections) == 2

        # Broadcast should complete despite slow client (due to timeout)
        await ws_manager.broadcast("test", {})

        # Slow client is removed; the responsive one stays
        assert len(ws_manager.active_connections) == 1
        assert good_ws in ws_manager.active_connections
        assert slow_ws not in ws_manager.active_connections

    @pytest.mark.asyncio
    async def test_broadcast_concurrent_sends(self, ws_manager: WebSocketManager):
        """Verify that sends happen concurrently, not sequentially."""
        call_times = []

        def make_recording_send(ws_name):
            """Build a send_text side effect that records when it started."""

            async def _send(_):
                call_times.append((ws_name, asyncio.get_running_loop().time()))
                await asyncio.sleep(0.1)  # Simulate some work

            return _send

        clients = []
        for name in ("ws1", "ws2", "ws3"):
            ws = AsyncMock()
            ws.accept = AsyncMock()
            ws.send_text.side_effect = make_recording_send(name)
            await ws_manager.connect(ws)
            clients.append(ws)

        loop = asyncio.get_running_loop()
        start_time = loop.time()
        await ws_manager.broadcast("test", {})
        elapsed = loop.time() - start_time

        # Every client was sent to, and all sends started before any finished
        assert sorted(name for name, _ in call_times) == ["ws1", "ws2", "ws3"]
        start_stamps = [stamp for _, stamp in call_times]
        assert max(start_stamps) - min(start_stamps) < 0.1

        # If sequential: 3 * 0.1 = 0.3s
        # If concurrent: ~0.1s
        # Allow some margin for test overhead
        assert elapsed < 0.2, f"Sends should be concurrent, took {elapsed}s"

    @pytest.mark.asyncio
    async def test_broadcast_does_not_block_on_slow_client(self, ws_manager: WebSocketManager):
        """A slow client should not block messages to fast clients."""
        fast_ws = AsyncMock()
        slow_ws = AsyncMock()
        fast_ws.accept = AsyncMock()
        slow_ws.accept = AsyncMock()

        fast_received_at = None

        async def fast_send(_):
            nonlocal fast_received_at
            fast_received_at = asyncio.get_running_loop().time()

        async def slow_send(_):
            await asyncio.sleep(0.2)  # Slow client

        fast_ws.send_text.side_effect = fast_send
        slow_ws.send_text.side_effect = slow_send

        # Connect the slow client first so a sequential send would delay the fast one
        await ws_manager.connect(slow_ws)
        await ws_manager.connect(fast_ws)

        start_time = asyncio.get_running_loop().time()
        await ws_manager.broadcast("test", {})

        # Fast client should receive message quickly, not waiting for slow client
        assert fast_received_at is not None
        assert fast_received_at - start_time < 0.1, "Fast client was blocked by slow client"

    @pytest.mark.asyncio
    async def test_broadcast_empty_connections(self, ws_manager: WebSocketManager):
        """Broadcast should handle empty connection list gracefully."""
        # Should not raise
        await ws_manager.broadcast("test", {"data": "value"})
class TestWebSocketConnectionManagement:
    """Tests for adding and removing clients via connect() / disconnect()."""

    @pytest.mark.asyncio
    async def test_connect_adds_to_list(self, ws_manager: WebSocketManager, mock_websocket):
        """A connected websocket appears in active_connections."""
        tracked = ws_manager.active_connections
        assert len(tracked) == 0

        await ws_manager.connect(mock_websocket)

        # Re-read the attribute rather than reusing `tracked`, in case connect rebinds it
        assert len(ws_manager.active_connections) == 1
        assert mock_websocket in ws_manager.active_connections

    @pytest.mark.asyncio
    async def test_disconnect_removes_from_list(self, ws_manager: WebSocketManager, mock_websocket):
        """A disconnected websocket is no longer tracked."""
        await ws_manager.connect(mock_websocket)
        count_after_connect = len(ws_manager.active_connections)
        assert count_after_connect == 1

        await ws_manager.disconnect(mock_websocket)

        count_after_disconnect = len(ws_manager.active_connections)
        assert count_after_disconnect == 0

    @pytest.mark.asyncio
    async def test_disconnect_nonexistent_is_safe(
        self, ws_manager: WebSocketManager, mock_websocket
    ):
        """Disconnecting a websocket that was never connected is a no-op."""
        # Must complete without raising
        await ws_manager.disconnect(mock_websocket)

        assert len(ws_manager.active_connections) == 0