"""WebSocket manager for real-time updates."""

import asyncio
import json
import logging
import os
from collections.abc import Coroutine
from typing import Any

from fastapi import WebSocket

from app.config import settings

logger = logging.getLogger(__name__)

# Timeout for individual WebSocket send operations (seconds)
# Prevents a slow client from blocking broadcasts to other clients
SEND_TIMEOUT_SECONDS = 5.0


class WebSocketManager:
    """Manages WebSocket connections and broadcasts events."""

    def __init__(self) -> None:
        # Currently connected clients; mutations happen under ``_lock``.
        self.active_connections: list[WebSocket] = []
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the WebSocket handshake and register the client."""
        await websocket.accept()
        async with self._lock:
            self.active_connections.append(websocket)
        logger.info("WebSocket client connected (%d total)", len(self.active_connections))

    async def disconnect(self, websocket: WebSocket) -> None:
        """Unregister a client; a no-op if it was already removed."""
        async with self._lock:
            if websocket in self.active_connections:
                self.active_connections.remove(websocket)
        logger.info("WebSocket client disconnected (%d remaining)", len(self.active_connections))

    async def broadcast(self, event_type: str, data: Any) -> None:
        """Broadcast an event to all connected clients.

        Uses a copy-then-send pattern to avoid holding the lock during I/O:
        1. Copy connection list while holding lock
        2. Release lock before sending
        3. Send to all clients concurrently with timeout
        4. Re-acquire lock to clean up disconnected clients

        Args:
            event_type: Value placed in the message's ``"type"`` field.
            data: JSON-serializable payload placed in the ``"data"`` field.

        Raises:
            TypeError: If ``data`` is not JSON-serializable (from ``json.dumps``).
        """
        # Cheap unlocked pre-check: skip serialization when nobody is listening.
        if not self.active_connections:
            return

        message = json.dumps({"type": event_type, "data": data})

        # Copy connection list under lock to avoid holding lock during I/O
        async with self._lock:
            connections = list(self.active_connections)
        if not connections:
            return

        # Clients whose send failed or timed out; pruned after the fan-out.
        disconnected: list[WebSocket] = []

        async def send_to_client(connection: WebSocket) -> None:
            try:
                # Timeout prevents blocking on slow/unresponsive clients
                await asyncio.wait_for(connection.send_text(message), timeout=SEND_TIMEOUT_SECONDS)
            except asyncio.TimeoutError:
                logger.debug("Timeout sending to WebSocket client, marking disconnected")
                disconnected.append(connection)
            except Exception as e:
                logger.debug("Failed to send to client: %s", e)
                disconnected.append(connection)

        # Send to all clients concurrently; send_to_client never raises
        # ordinary exceptions, so one bad client cannot abort the gather.
        await asyncio.gather(*(send_to_client(conn) for conn in connections))

        # Clean up disconnected clients (re-acquire lock)
        if disconnected:
            removed = 0
            async with self._lock:
                for conn in disconnected:
                    # disconnect() may have removed it while we were sending.
                    if conn in self.active_connections:
                        self.active_connections.remove(conn)
                        removed += 1
            if removed:
                logger.debug("Removed %d disconnected WebSocket clients", removed)

    async def send_personal(self, websocket: WebSocket, event_type: str, data: Any) -> None:
        """Send an event to a specific client.

        Send failures (including exceeding ``SEND_TIMEOUT_SECONDS``) are
        logged at debug level and otherwise ignored; the client is not
        unregistered here.
        """
        message = json.dumps({"type": event_type, "data": data})
        try:
            # Same bound as broadcast() so a stalled client cannot hang the caller.
            await asyncio.wait_for(websocket.send_text(message), timeout=SEND_TIMEOUT_SECONDS)
        except asyncio.TimeoutError:
            logger.debug("Timeout sending to WebSocket client")
        except Exception as e:
            logger.debug("Failed to send to client: %s", e)


# Global instance
ws_manager = WebSocketManager()

# Strong references to fire-and-forget tasks created by the broadcast_* helpers.
# The event loop keeps only weak references to tasks, so without this set a
# pending broadcast could be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task[None]] = set()


def _on_background_task_done(task: asyncio.Task[None]) -> None:
    """Release a finished task's reference and log any unhandled exception."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Background broadcast task failed", exc_info=exc)


def _schedule(coro: Coroutine[Any, Any, None]) -> None:
    """Run ``coro`` as a background task, retaining it until it completes.

    Must be called while an event loop is running: ``asyncio.create_task``
    raises ``RuntimeError`` otherwise.
    """
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)


def broadcast_event(event_type: str, data: dict) -> None:
    """Schedule a broadcast without blocking.

    Convenience function that creates an asyncio task to broadcast an event
    to all connected WebSocket clients.
    """
    _schedule(ws_manager.broadcast(event_type, data))


def _broadcast_notification(event_type: str, message: str, details: str | None) -> None:
    """Schedule a toast-style notification with an optional ``details`` field."""
    data = {"message": message}
    if details:
        data["details"] = details
    _schedule(ws_manager.broadcast(event_type, data))


def broadcast_error(message: str, details: str | None = None) -> None:
    """Broadcast an error notification to all connected clients.

    This appears as a toast notification in the frontend.
    """
    _broadcast_notification("error", message, details)


def broadcast_success(message: str, details: str | None = None) -> None:
    """Broadcast a success notification to all connected clients.

    This appears as a toast notification in the frontend.
    """
    _broadcast_notification("success", message, details)


def broadcast_health(radio_connected: bool, serial_port: str | None = None) -> None:
    """Broadcast health status change to all connected clients.

    The payload includes the database file size (MB, 0.0 if the file cannot
    be stat'ed) and the oldest undecrypted packet timestamp (``None`` if the
    repository raises ``RuntimeError``).
    """
    # Imported at call time; presumably avoids a circular import — confirm.
    from app.repository import RawPacketRepository

    async def _broadcast() -> None:
        # Get database file size in MB
        db_size_mb = 0.0
        try:
            db_size_bytes = os.path.getsize(settings.database_path)
            db_size_mb = round(db_size_bytes / (1024 * 1024), 2)
        except OSError:
            pass

        # Get oldest undecrypted packet info
        oldest_ts = None
        try:
            oldest_ts = await RawPacketRepository.get_oldest_undecrypted()
        except RuntimeError:
            pass  # Database not connected

        await ws_manager.broadcast(
            "health",
            {
                "status": "ok" if radio_connected else "degraded",
                "radio_connected": radio_connected,
                "serial_port": serial_port,
                "database_size_mb": db_size_mb,
                "oldest_undecrypted_timestamp": oldest_ts,
            },
        )

    _schedule(_broadcast())