From 2d781cad562330f870f20f0aeeccde2ff49b5328 Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Mon, 9 Mar 2026 17:47:31 -0700 Subject: [PATCH] add typed websocket event contracts --- app/events.py | 97 ++++++++++++++++++++++++++++++ app/websocket.py | 7 ++- frontend/src/test/wsEvents.test.ts | 41 +++++++++++++ frontend/src/useWebSocket.ts | 12 ++-- frontend/src/wsEvents.ts | 78 ++++++++++++++++++++++++ tests/test_websocket.py | 22 +++++++ 6 files changed, 246 insertions(+), 11 deletions(-) create mode 100644 app/events.py create mode 100644 frontend/src/test/wsEvents.test.ts create mode 100644 frontend/src/wsEvents.ts diff --git a/app/events.py b/app/events.py new file mode 100644 index 0000000..59fb163 --- /dev/null +++ b/app/events.py @@ -0,0 +1,97 @@ +"""Typed WebSocket event contracts and serialization helpers.""" + +import json +from typing import Any, Literal + +from pydantic import TypeAdapter +from typing_extensions import NotRequired, TypedDict + +from app.models import Channel, Contact, Message, MessagePath, RawPacketBroadcast +from app.routers.health import HealthResponse + +WsEventType = Literal[ + "health", + "message", + "contact", + "channel", + "contact_deleted", + "channel_deleted", + "raw_packet", + "message_acked", + "error", + "success", +] + + +class ContactDeletedPayload(TypedDict): + public_key: str + + +class ChannelDeletedPayload(TypedDict): + key: str + + +class MessageAckedPayload(TypedDict): + message_id: int + ack_count: int + paths: NotRequired[list[MessagePath]] + + +class ToastPayload(TypedDict): + message: str + details: NotRequired[str] + + +WsEventPayload = ( + HealthResponse + | Message + | Contact + | Channel + | ContactDeletedPayload + | ChannelDeletedPayload + | RawPacketBroadcast + | MessageAckedPayload + | ToastPayload +) + +_PAYLOAD_ADAPTERS: dict[WsEventType, TypeAdapter[Any]] = { + "health": TypeAdapter(HealthResponse), + "message": TypeAdapter(Message), + "contact": TypeAdapter(Contact), + "channel": TypeAdapter(Channel), + "contact_deleted": TypeAdapter(ContactDeletedPayload), + "channel_deleted": TypeAdapter(ChannelDeletedPayload), + "raw_packet": TypeAdapter(RawPacketBroadcast), + "message_acked": TypeAdapter(MessageAckedPayload), + "error": TypeAdapter(ToastPayload), + "success": TypeAdapter(ToastPayload), +} + + +def validate_ws_event_payload(event_type: str, data: Any) -> WsEventPayload | Any: + """Validate known WebSocket payloads; pass unknown events through unchanged.""" + adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type] + if adapter is None: + return data + return adapter.validate_python(data) + + +def dump_ws_event(event_type: str, data: Any) -> str: + """Serialize a WebSocket event envelope with validation for known event types.""" + adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type] + if adapter is None: + return json.dumps({"type": event_type, "data": data}) + + validated = adapter.validate_python(data) + payload = adapter.dump_python(validated, mode="json") + return json.dumps({"type": event_type, "data": payload}) + + +def dump_ws_event_payload(event_type: str, data: Any) -> Any: + """Return the JSON-serializable payload for a WebSocket event.""" + adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type] + if adapter is None: + return data + + validated = adapter.validate_python(data) + return adapter.dump_python(validated, mode="json") diff --git a/app/websocket.py b/app/websocket.py index 3ceb705..27ebdb0 100644 --- a/app/websocket.py +++ b/app/websocket.py @@ -1,12 +1,13 @@ """WebSocket manager for real-time updates.""" import asyncio -import json import logging from typing import Any from fastapi import WebSocket +from app.events import dump_ws_event + logger = logging.getLogger(__name__) # Timeout for individual WebSocket send operations (seconds) @@ -45,7 +46,7 @@ class WebSocketManager: if not self.active_connections: return - message = json.dumps({"type": event_type, "data": data}) + message = dump_ws_event(event_type, data) # Copy connection list under lock to avoid holding lock during I/O async with self._lock: @@ -81,7 +82,7 @@ class WebSocketManager: async def send_personal(self, websocket: WebSocket, event_type: str, data: Any) -> None: """Send an event to a specific client.""" - message = json.dumps({"type": event_type, "data": data}) + message = dump_ws_event(event_type, data) try: await websocket.send_text(message) except Exception as e: diff --git a/frontend/src/test/wsEvents.test.ts b/frontend/src/test/wsEvents.test.ts new file mode 100644 index 0000000..4ff03be --- /dev/null +++ b/frontend/src/test/wsEvents.test.ts @@ -0,0 +1,41 @@ +import { describe, expect, it } from 'vitest'; + +import { parseWsEvent } from '../wsEvents'; + +describe('wsEvents', () => { + it('parses contact_deleted events', () => { + const event = parseWsEvent( + JSON.stringify({ type: 'contact_deleted', data: { public_key: 'aa' } }) + ); + + expect(event).toEqual({ + type: 'contact_deleted', + data: { public_key: 'aa' }, + }); + }); + + it('parses channel_deleted events', () => { + const event = parseWsEvent(JSON.stringify({ type: 'channel_deleted', data: { key: 'bb' } })); + + expect(event).toEqual({ + type: 'channel_deleted', + data: { key: 'bb' }, + }); + }); + + it('returns unknown events with rawType preserved', () => { + const event = parseWsEvent(JSON.stringify({ type: 'mystery', data: { ok: true } })); + + expect(event).toEqual({ + type: 'unknown', + rawType: 'mystery', + data: { ok: true }, + }); + }); + + it('rejects invalid envelopes', () => { + expect(() => parseWsEvent(JSON.stringify({ data: {} }))).toThrow( + 'Invalid WebSocket event envelope' + ); + }); +}); diff --git a/frontend/src/useWebSocket.ts b/frontend/src/useWebSocket.ts index 19a740b..70548b2 100644 --- a/frontend/src/useWebSocket.ts +++ b/frontend/src/useWebSocket.ts @@ -1,10 +1,6 @@ import { useEffect, useRef, useCallback } from 'react'; import type { Channel, HealthStatus, Contact, Message, MessagePath, RawPacket } from './types'; - -interface WebSocketMessage { - type: string; - data: unknown; -} +import { parseWsEvent } from './wsEvents'; interface ErrorEvent { message: string; @@ -92,7 +88,7 @@ export function useWebSocket(options: UseWebSocketOptions) { ws.onmessage = (event) => { try { - const msg: WebSocketMessage = JSON.parse(event.data); + const msg = parseWsEvent(event.data); // Access handlers through ref to always use current versions const handlers = optionsRef.current; @@ -136,8 +132,8 @@ export function useWebSocket(options: UseWebSocketOptions) { case 'pong': // Heartbeat response, ignore break; - default: - console.warn('Unknown WebSocket message type:', msg.type); + case 'unknown': + console.warn('Unknown WebSocket message type:', msg.rawType); } } catch (e) { console.error('Failed to parse WebSocket message:', e); diff --git a/frontend/src/wsEvents.ts b/frontend/src/wsEvents.ts new file mode 100644 index 0000000..bd0719d --- /dev/null +++ b/frontend/src/wsEvents.ts @@ -0,0 +1,78 @@ +import type { Channel, Contact, HealthStatus, Message, MessagePath, RawPacket } from './types'; + +export interface MessageAckedPayload { + message_id: number; + ack_count: number; + paths?: MessagePath[]; +} + +export interface ContactDeletedPayload { + public_key: string; +} + +export interface ChannelDeletedPayload { + key: string; +} + +export interface ToastPayload { + message: string; + details?: string; +} + +export type KnownWsEvent = + | { type: 'health'; data: HealthStatus } + | { type: 'message'; data: Message } + | { type: 'contact'; data: Contact } + | { type: 'channel'; data: Channel } + | { type: 'contact_deleted'; data: ContactDeletedPayload } + | { type: 'channel_deleted'; data: ChannelDeletedPayload } + | { type: 'raw_packet'; data: RawPacket } + | { type: 'message_acked'; data: MessageAckedPayload } + | { type: 'error'; data: ToastPayload } + | { type: 'success'; data: ToastPayload } + | { type: 'pong'; data?: null }; + +export interface UnknownWsEvent { + type: 'unknown'; + rawType: string; + data: unknown; +} + +export type ParsedWsEvent = KnownWsEvent | UnknownWsEvent; + +interface RawWsEnvelope { + type?: unknown; + data?: unknown; +} + +export function parseWsEvent(raw: string): ParsedWsEvent { + const parsed: RawWsEnvelope = JSON.parse(raw); + if (!parsed || typeof parsed !== 'object' || typeof parsed.type !== 'string') { + throw new Error('Invalid WebSocket event envelope'); + } + + switch (parsed.type) { + case 'health': + case 'message': + case 'contact': + case 'channel': + case 'contact_deleted': + case 'channel_deleted': + case 'raw_packet': + case 'message_acked': + case 'error': + case 'success': + return { + type: parsed.type, + data: parsed.data, + } as KnownWsEvent; + case 'pong': + return { type: 'pong', data: parsed.data as null | undefined }; + default: + return { + type: 'unknown', + rawType: parsed.type, + data: parsed.data, + }; + } +} diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 7bd361f..49930ed 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,9 +1,11 @@ """Tests for WebSocket manager functionality.""" import asyncio +import json from unittest.mock import AsyncMock, patch import pytest +from pydantic import ValidationError from app.websocket import SEND_TIMEOUT_SECONDS, WebSocketManager @@ -245,3 +247,23 @@ class TestBroadcastEventFanout: mock_ws.broadcast.assert_called_once() mock_fm.broadcast_raw.assert_called_once_with({"data": "ff00"}) + + +class TestTypedEventSerialization: + """Tests for typed websocket event serialization.""" + + def test_dump_ws_event_preserves_optional_message_acked_shape(self): + from app.events import dump_ws_event + + serialized = dump_ws_event("message_acked", {"message_id": 7, "ack_count": 2}) + + assert json.loads(serialized) == { + "type": "message_acked", + "data": {"message_id": 7, "ack_count": 2}, + } + + def test_dump_ws_event_validates_supported_payloads(self): + from app.events import dump_ws_event + + with pytest.raises(ValidationError): + dump_ws_event("message_acked", {"ack_count": 2})