add typed websocket event contracts

This commit is contained in:
Jack Kingsman
2026-03-09 17:47:31 -07:00
parent 088dcb39d6
commit 2d781cad56
6 changed files with 246 additions and 11 deletions

97
app/events.py Normal file
View File

@@ -0,0 +1,97 @@
"""Typed WebSocket event contracts and serialization helpers."""
import json
from typing import Any, Literal
from pydantic import TypeAdapter
from typing_extensions import NotRequired, TypedDict
from app.models import Channel, Contact, Message, MessagePath, RawPacketBroadcast
from app.routers.health import HealthResponse
WsEventType = Literal[
"health",
"message",
"contact",
"channel",
"contact_deleted",
"channel_deleted",
"raw_packet",
"message_acked",
"error",
"success",
]
class ContactDeletedPayload(TypedDict):
public_key: str
class ChannelDeletedPayload(TypedDict):
key: str
class MessageAckedPayload(TypedDict):
message_id: int
ack_count: int
paths: NotRequired[list[MessagePath]]
class ToastPayload(TypedDict):
message: str
details: NotRequired[str]
WsEventPayload = (
HealthResponse
| Message
| Contact
| Channel
| ContactDeletedPayload
| ChannelDeletedPayload
| RawPacketBroadcast
| MessageAckedPayload
| ToastPayload
)
_PAYLOAD_ADAPTERS: dict[WsEventType, TypeAdapter[Any]] = {
"health": TypeAdapter(HealthResponse),
"message": TypeAdapter(Message),
"contact": TypeAdapter(Contact),
"channel": TypeAdapter(Channel),
"contact_deleted": TypeAdapter(ContactDeletedPayload),
"channel_deleted": TypeAdapter(ChannelDeletedPayload),
"raw_packet": TypeAdapter(RawPacketBroadcast),
"message_acked": TypeAdapter(MessageAckedPayload),
"error": TypeAdapter(ToastPayload),
"success": TypeAdapter(ToastPayload),
}
def validate_ws_event_payload(event_type: str, data: Any) -> WsEventPayload | Any:
"""Validate known WebSocket payloads; pass unknown events through unchanged."""
adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type]
if adapter is None:
return data
return adapter.validate_python(data)
def dump_ws_event(event_type: str, data: Any) -> str:
"""Serialize a WebSocket event envelope with validation for known event types."""
adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type]
if adapter is None:
return json.dumps({"type": event_type, "data": data})
validated = adapter.validate_python(data)
payload = adapter.dump_python(validated, mode="json")
return json.dumps({"type": event_type, "data": payload})
def dump_ws_event_payload(event_type: str, data: Any) -> Any:
"""Return the JSON-serializable payload for a WebSocket event."""
adapter = _PAYLOAD_ADAPTERS.get(event_type) # type: ignore[arg-type]
if adapter is None:
return data
validated = adapter.validate_python(data)
return adapter.dump_python(validated, mode="json")

View File

@@ -1,12 +1,13 @@
"""WebSocket manager for real-time updates."""
import asyncio
import json
import logging
from typing import Any
from fastapi import WebSocket
from app.events import dump_ws_event
logger = logging.getLogger(__name__)
# Timeout for individual WebSocket send operations (seconds)
@@ -45,7 +46,7 @@ class WebSocketManager:
if not self.active_connections:
return
message = json.dumps({"type": event_type, "data": data})
message = dump_ws_event(event_type, data)
# Copy connection list under lock to avoid holding lock during I/O
async with self._lock:
@@ -81,7 +82,7 @@ class WebSocketManager:
async def send_personal(self, websocket: WebSocket, event_type: str, data: Any) -> None:
"""Send an event to a specific client."""
message = json.dumps({"type": event_type, "data": data})
message = dump_ws_event(event_type, data)
try:
await websocket.send_text(message)
except Exception as e:

View File

@@ -0,0 +1,41 @@
import { describe, expect, it } from 'vitest';
import { parseWsEvent } from '../wsEvents';
describe('wsEvents', () => {
it('parses contact_deleted events', () => {
const event = parseWsEvent(
JSON.stringify({ type: 'contact_deleted', data: { public_key: 'aa' } })
);
expect(event).toEqual({
type: 'contact_deleted',
data: { public_key: 'aa' },
});
});
it('parses channel_deleted events', () => {
const event = parseWsEvent(JSON.stringify({ type: 'channel_deleted', data: { key: 'bb' } }));
expect(event).toEqual({
type: 'channel_deleted',
data: { key: 'bb' },
});
});
it('returns unknown events with rawType preserved', () => {
const event = parseWsEvent(JSON.stringify({ type: 'mystery', data: { ok: true } }));
expect(event).toEqual({
type: 'unknown',
rawType: 'mystery',
data: { ok: true },
});
});
it('rejects invalid envelopes', () => {
expect(() => parseWsEvent(JSON.stringify({ data: {} }))).toThrow(
'Invalid WebSocket event envelope'
);
});
});

View File

@@ -1,10 +1,6 @@
import { useEffect, useRef, useCallback } from 'react';
import type { Channel, HealthStatus, Contact, Message, MessagePath, RawPacket } from './types';
interface WebSocketMessage {
type: string;
data: unknown;
}
import { parseWsEvent } from './wsEvents';
interface ErrorEvent {
message: string;
@@ -92,7 +88,7 @@ export function useWebSocket(options: UseWebSocketOptions) {
ws.onmessage = (event) => {
try {
const msg: WebSocketMessage = JSON.parse(event.data);
const msg = parseWsEvent(event.data);
// Access handlers through ref to always use current versions
const handlers = optionsRef.current;
@@ -136,8 +132,8 @@ export function useWebSocket(options: UseWebSocketOptions) {
case 'pong':
// Heartbeat response, ignore
break;
default:
console.warn('Unknown WebSocket message type:', msg.type);
case 'unknown':
console.warn('Unknown WebSocket message type:', msg.rawType);
}
} catch (e) {
console.error('Failed to parse WebSocket message:', e);

78
frontend/src/wsEvents.ts Normal file
View File

@@ -0,0 +1,78 @@
import type { Channel, Contact, HealthStatus, Message, MessagePath, RawPacket } from './types';
export interface MessageAckedPayload {
message_id: number;
ack_count: number;
paths?: MessagePath[];
}
export interface ContactDeletedPayload {
public_key: string;
}
export interface ChannelDeletedPayload {
key: string;
}
export interface ToastPayload {
message: string;
details?: string;
}
export type KnownWsEvent =
| { type: 'health'; data: HealthStatus }
| { type: 'message'; data: Message }
| { type: 'contact'; data: Contact }
| { type: 'channel'; data: Channel }
| { type: 'contact_deleted'; data: ContactDeletedPayload }
| { type: 'channel_deleted'; data: ChannelDeletedPayload }
| { type: 'raw_packet'; data: RawPacket }
| { type: 'message_acked'; data: MessageAckedPayload }
| { type: 'error'; data: ToastPayload }
| { type: 'success'; data: ToastPayload }
| { type: 'pong'; data?: null };
export interface UnknownWsEvent {
type: 'unknown';
rawType: string;
data: unknown;
}
export type ParsedWsEvent = KnownWsEvent | UnknownWsEvent;
interface RawWsEnvelope {
type?: unknown;
data?: unknown;
}
export function parseWsEvent(raw: string): ParsedWsEvent {
const parsed: RawWsEnvelope = JSON.parse(raw);
if (!parsed || typeof parsed !== 'object' || typeof parsed.type !== 'string') {
throw new Error('Invalid WebSocket event envelope');
}
switch (parsed.type) {
case 'health':
case 'message':
case 'contact':
case 'channel':
case 'contact_deleted':
case 'channel_deleted':
case 'raw_packet':
case 'message_acked':
case 'error':
case 'success':
return {
type: parsed.type,
data: parsed.data,
} as KnownWsEvent;
case 'pong':
return { type: 'pong', data: parsed.data as null | undefined };
default:
return {
type: 'unknown',
rawType: parsed.type,
data: parsed.data,
};
}
}

View File

@@ -1,9 +1,11 @@
"""Tests for WebSocket manager functionality."""
import asyncio
import json
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import ValidationError
from app.websocket import SEND_TIMEOUT_SECONDS, WebSocketManager
@@ -245,3 +247,23 @@ class TestBroadcastEventFanout:
mock_ws.broadcast.assert_called_once()
mock_fm.broadcast_raw.assert_called_once_with({"data": "ff00"})
class TestTypedEventSerialization:
"""Tests for typed websocket event serialization."""
def test_dump_ws_event_preserves_optional_message_acked_shape(self):
from app.events import dump_ws_event
serialized = dump_ws_event("message_acked", {"message_id": 7, "ack_count": 2})
assert json.loads(serialized) == {
"type": "message_acked",
"data": {"message_id": 7, "ack_count": 2},
}
def test_dump_ws_event_validates_supported_payloads(self):
from app.events import dump_ws_event
with pytest.raises(ValidationError):
dump_ws_event("message_acked", {"ack_count": 2})