add typed websocket event contracts

This commit is contained in:
Jack Kingsman
2026-03-09 17:47:31 -07:00
parent 088dcb39d6
commit 2d781cad56
6 changed files with 246 additions and 11 deletions

97
app/events.py Normal file
View File

@@ -0,0 +1,97 @@
"""Typed WebSocket event contracts and serialization helpers."""
import json
from typing import Any, Literal
from pydantic import TypeAdapter
from typing_extensions import NotRequired, TypedDict
from app.models import Channel, Contact, Message, MessagePath, RawPacketBroadcast
from app.routers.health import HealthResponse
WsEventType = Literal[
"health",
"message",
"contact",
"channel",
"contact_deleted",
"channel_deleted",
"raw_packet",
"message_acked",
"error",
"success",
]
class ContactDeletedPayload(TypedDict):
public_key: str
class ChannelDeletedPayload(TypedDict):
key: str
class MessageAckedPayload(TypedDict):
message_id: int
ack_count: int
paths: NotRequired[list[MessagePath]]
class ToastPayload(TypedDict):
message: str
details: NotRequired[str]
WsEventPayload = (
HealthResponse
| Message
| Contact
| Channel
| ContactDeletedPayload
| ChannelDeletedPayload
| RawPacketBroadcast
| MessageAckedPayload
| ToastPayload
)
_PAYLOAD_ADAPTERS: dict[WsEventType, TypeAdapter[Any]] = {
"health": TypeAdapter(HealthResponse),
"message": TypeAdapter(Message),
"contact": TypeAdapter(Contact),
"channel": TypeAdapter(Channel),
"contact_deleted": TypeAdapter(ContactDeletedPayload),
"channel_deleted": TypeAdapter(ChannelDeletedPayload),
"raw_packet": TypeAdapter(RawPacketBroadcast),
"message_acked": TypeAdapter(MessageAckedPayload),
"error": TypeAdapter(ToastPayload),
"success": TypeAdapter(ToastPayload),
}
def validate_ws_event_payload(event_type: str, data: Any) -> WsEventPayload | Any:
"""Validate known WebSocket payloads; pass unknown events through unchanged."""
adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type]
if adapter is None:
return data
return adapter.validate_python(data)
def dump_ws_event(event_type: str, data: Any) -> str:
"""Serialize a WebSocket event envelope with validation for known event types."""
adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type]
if adapter is None:
return json.dumps({"type": event_type, "data": data})
validated = adapter.validate_python(data)
payload = adapter.dump_python(validated, mode="json")
return json.dumps({"type": event_type, "data": payload})
def dump_ws_event_payload(event_type: str, data: Any) -> Any:
"""Return the JSON-serializable payload for a WebSocket event."""
adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type]
if adapter is None:
return data
validated = adapter.validate_python(data)
return adapter.dump_python(validated, mode="json")

View File

@@ -1,12 +1,13 @@
"""WebSocket manager for real-time updates.""" """WebSocket manager for real-time updates."""
import asyncio import asyncio
import json
import logging import logging
from typing import Any from typing import Any
from fastapi import WebSocket from fastapi import WebSocket
from app.events import dump_ws_event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Timeout for individual WebSocket send operations (seconds) # Timeout for individual WebSocket send operations (seconds)
@@ -45,7 +46,7 @@ class WebSocketManager:
if not self.active_connections: if not self.active_connections:
return return
message = json.dumps({"type": event_type, "data": data}) message = dump_ws_event(event_type, data)
# Copy connection list under lock to avoid holding lock during I/O # Copy connection list under lock to avoid holding lock during I/O
async with self._lock: async with self._lock:
@@ -81,7 +82,7 @@ class WebSocketManager:
async def send_personal(self, websocket: WebSocket, event_type: str, data: Any) -> None: async def send_personal(self, websocket: WebSocket, event_type: str, data: Any) -> None:
"""Send an event to a specific client.""" """Send an event to a specific client."""
message = json.dumps({"type": event_type, "data": data}) message = dump_ws_event(event_type, data)
try: try:
await websocket.send_text(message) await websocket.send_text(message)
except Exception as e: except Exception as e:

View File

@@ -0,0 +1,41 @@
import { describe, expect, it } from 'vitest';
import { parseWsEvent } from '../wsEvents';
describe('wsEvents', () => {
it('parses contact_deleted events', () => {
const event = parseWsEvent(
JSON.stringify({ type: 'contact_deleted', data: { public_key: 'aa' } })
);
expect(event).toEqual({
type: 'contact_deleted',
data: { public_key: 'aa' },
});
});
it('parses channel_deleted events', () => {
const event = parseWsEvent(JSON.stringify({ type: 'channel_deleted', data: { key: 'bb' } }));
expect(event).toEqual({
type: 'channel_deleted',
data: { key: 'bb' },
});
});
it('returns unknown events with rawType preserved', () => {
const event = parseWsEvent(JSON.stringify({ type: 'mystery', data: { ok: true } }));
expect(event).toEqual({
type: 'unknown',
rawType: 'mystery',
data: { ok: true },
});
});
it('rejects invalid envelopes', () => {
expect(() => parseWsEvent(JSON.stringify({ data: {} }))).toThrow(
'Invalid WebSocket event envelope'
);
});
});

View File

@@ -1,10 +1,6 @@
import { useEffect, useRef, useCallback } from 'react'; import { useEffect, useRef, useCallback } from 'react';
import type { Channel, HealthStatus, Contact, Message, MessagePath, RawPacket } from './types'; import type { Channel, HealthStatus, Contact, Message, MessagePath, RawPacket } from './types';
import { parseWsEvent } from './wsEvents';
interface WebSocketMessage {
type: string;
data: unknown;
}
interface ErrorEvent { interface ErrorEvent {
message: string; message: string;
@@ -92,7 +88,7 @@ export function useWebSocket(options: UseWebSocketOptions) {
ws.onmessage = (event) => { ws.onmessage = (event) => {
try { try {
const msg: WebSocketMessage = JSON.parse(event.data); const msg = parseWsEvent(event.data);
// Access handlers through ref to always use current versions // Access handlers through ref to always use current versions
const handlers = optionsRef.current; const handlers = optionsRef.current;
@@ -136,8 +132,8 @@ export function useWebSocket(options: UseWebSocketOptions) {
case 'pong': case 'pong':
// Heartbeat response, ignore // Heartbeat response, ignore
break; break;
default: case 'unknown':
console.warn('Unknown WebSocket message type:', msg.type); console.warn('Unknown WebSocket message type:', msg.rawType);
} }
} catch (e) { } catch (e) {
console.error('Failed to parse WebSocket message:', e); console.error('Failed to parse WebSocket message:', e);

78
frontend/src/wsEvents.ts Normal file
View File

@@ -0,0 +1,78 @@
import type { Channel, Contact, HealthStatus, Message, MessagePath, RawPacket } from './types';
export interface MessageAckedPayload {
message_id: number;
ack_count: number;
paths?: MessagePath[];
}
export interface ContactDeletedPayload {
public_key: string;
}
export interface ChannelDeletedPayload {
key: string;
}
export interface ToastPayload {
message: string;
details?: string;
}
export type KnownWsEvent =
| { type: 'health'; data: HealthStatus }
| { type: 'message'; data: Message }
| { type: 'contact'; data: Contact }
| { type: 'channel'; data: Channel }
| { type: 'contact_deleted'; data: ContactDeletedPayload }
| { type: 'channel_deleted'; data: ChannelDeletedPayload }
| { type: 'raw_packet'; data: RawPacket }
| { type: 'message_acked'; data: MessageAckedPayload }
| { type: 'error'; data: ToastPayload }
| { type: 'success'; data: ToastPayload }
| { type: 'pong'; data?: null };
export interface UnknownWsEvent {
type: 'unknown';
rawType: string;
data: unknown;
}
export type ParsedWsEvent = KnownWsEvent | UnknownWsEvent;
interface RawWsEnvelope {
type?: unknown;
data?: unknown;
}
export function parseWsEvent(raw: string): ParsedWsEvent {
const parsed: RawWsEnvelope = JSON.parse(raw);
if (!parsed || typeof parsed !== 'object' || typeof parsed.type !== 'string') {
throw new Error('Invalid WebSocket event envelope');
}
switch (parsed.type) {
case 'health':
case 'message':
case 'contact':
case 'channel':
case 'contact_deleted':
case 'channel_deleted':
case 'raw_packet':
case 'message_acked':
case 'error':
case 'success':
return {
type: parsed.type,
data: parsed.data,
} as KnownWsEvent;
case 'pong':
return { type: 'pong', data: parsed.data as null | undefined };
default:
return {
type: 'unknown',
rawType: parsed.type,
data: parsed.data,
};
}
}

View File

@@ -1,9 +1,11 @@
"""Tests for WebSocket manager functionality.""" """Tests for WebSocket manager functionality."""
import asyncio import asyncio
import json
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from pydantic import ValidationError
from app.websocket import SEND_TIMEOUT_SECONDS, WebSocketManager from app.websocket import SEND_TIMEOUT_SECONDS, WebSocketManager
@@ -245,3 +247,23 @@ class TestBroadcastEventFanout:
mock_ws.broadcast.assert_called_once() mock_ws.broadcast.assert_called_once()
mock_fm.broadcast_raw.assert_called_once_with({"data": "ff00"}) mock_fm.broadcast_raw.assert_called_once_with({"data": "ff00"})
class TestTypedEventSerialization:
"""Tests for typed websocket event serialization."""
def test_dump_ws_event_preserves_optional_message_acked_shape(self):
from app.events import dump_ws_event
serialized = dump_ws_event("message_acked", {"message_id": 7, "ack_count": 2})
assert json.loads(serialized) == {
"type": "message_acked",
"data": {"message_id": 7, "ack_count": 2},
}
def test_dump_ws_event_validates_supported_payloads(self):
from app.events import dump_ws_event
with pytest.raises(ValidationError):
dump_ws_event("message_acked", {"ack_count": 2})