Parallelize fanouts and add module watchdog

This commit is contained in:
Jack Kingsman
2026-03-05 23:05:57 -08:00
parent 5808504ee0
commit 55ac9df681
2 changed files with 124 additions and 12 deletions

View File

@@ -2,12 +2,14 @@
from __future__ import annotations
import asyncio
import logging
from typing import Any
from app.fanout.base import FanoutModule
logger = logging.getLogger(__name__)
# Upper bound, in seconds, on a single module handler call during dispatch.
# A handler that runs longer is logged as timed out and its module is restarted.
_DISPATCH_TIMEOUT_SECONDS = 30.0
# Type string -> module class mapping (extended in Phase 2/3)
_MODULE_TYPES: dict[str, type] = {}
@@ -77,6 +79,7 @@ class FanoutManager:
def __init__(self) -> None:
    """Create a manager with no modules registered and no restart locks."""
    # config id -> (module instance, scope dict)
    self._modules: dict[str, tuple[FanoutModule, dict]] = {}
    # config id -> lock held while that module is being restarted
    self._restart_locks: dict[str, asyncio.Lock] = {}
async def load_from_db(self) -> None:
"""Read enabled fanout_configs and instantiate modules."""
@@ -137,24 +140,77 @@ class FanoutManager:
await module.stop()
except Exception:
logger.exception("Error stopping fanout module %s", config_id)
self._restart_locks.pop(config_id, None)
async def _dispatch_matching(
    self,
    data: dict,
    *,
    matcher: Any,
    handler_name: str,
    log_label: str,
) -> None:
    """Run one handler on every module whose scope accepts ``data``.

    Args:
        data: Payload handed to each matching module's handler.
        matcher: Callable invoked as ``matcher(scope, data)``; a truthy
            result selects the module.
        handler_name: Name of the module coroutine method to invoke.
        log_label: Label forwarded to ``_run_handler`` for its log lines.

    All selected handlers are awaited together via ``asyncio.gather``.
    """
    # Iterate over a snapshot of the registry rather than the live dict.
    registered = list(self._modules.items())
    pending = [
        self._run_handler(module_id, instance, handler_name, data, log_label)
        for module_id, (instance, module_scope) in registered
        if matcher(module_scope, data)
    ]
    if not pending:
        return
    await asyncio.gather(*pending)
async def _run_handler(
    self,
    config_id: str,
    module: FanoutModule,
    handler_name: str,
    data: dict,
    log_label: str,
) -> None:
    """Invoke a single module handler, containing its failures.

    The handler is bounded by ``_DISPATCH_TIMEOUT_SECONDS``. On timeout the
    event is logged and the module is restarted; any other ``Exception`` is
    logged and swallowed so it does not propagate to the caller.

    Args:
        config_id: Identifier of the module, used in log output and restart.
        module: Module instance whose handler is called.
        handler_name: Attribute name of the coroutine method on ``module``.
        data: Payload passed to the handler.
        log_label: Short label naming the dispatch kind in log lines.
    """
    try:
        # The lookup stays inside the try so a missing handler attribute is
        # logged like any other handler failure.
        bound_handler = getattr(module, handler_name)
        pending_call = bound_handler(data)
        await asyncio.wait_for(pending_call, timeout=_DISPATCH_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        logger.error(
            "Fanout %s %s timed out after %.1fs; restarting module",
            config_id,
            log_label,
            _DISPATCH_TIMEOUT_SECONDS,
        )
        await self._restart_module(config_id, module)
    except Exception:
        logger.exception("Fanout %s %s error", config_id, log_label)
async def _restart_module(self, config_id: str, module: FanoutModule) -> None:
    """Restart a timed-out module if it is still the active instance.

    Args:
        config_id: Identifier under which the module is registered.
        module: The instance whose handler timed out.

    Nothing happens when ``config_id`` is no longer registered or is now
    bound to a different instance. Errors raised by ``stop()``/``start()``
    are logged and not re-raised.
    """

    def _is_active() -> bool:
        entry = self._modules.get(config_id)
        return entry is not None and entry[0] is module

    # Check before creating a lock: a timeout that fires after the module was
    # removed must not re-insert a lock entry that nothing will clean up.
    if not _is_active():
        return

    lock = self._restart_locks.setdefault(config_id, asyncio.Lock())
    async with lock:
        # Re-check under the lock; the registry may have changed while this
        # coroutine was waiting for a concurrent restart to finish.
        if not _is_active():
            return
        try:
            await module.stop()
            await module.start()
        except Exception:
            logger.exception("Failed to restart timed-out fanout module %s", config_id)
async def broadcast_message(self, data: dict) -> None:
    """Dispatch a decoded message to modules whose scope matches.

    Args:
        data: Decoded message payload passed to each module's ``on_message``.

    Matching modules are selected with ``_scope_matches_message`` and run
    through ``_dispatch_matching``. The former sequential loop is removed:
    kept alongside the concurrent dispatch, it would invoke every matching
    module's ``on_message`` a second time.
    """
    await self._dispatch_matching(
        data,
        matcher=_scope_matches_message,
        handler_name="on_message",
        log_label="on_message",
    )
async def broadcast_raw(self, data: dict) -> None:
    """Dispatch a raw packet to modules whose scope matches.

    Args:
        data: Raw packet payload passed to each module's ``on_raw``.

    Matching modules are selected with ``_scope_matches_raw`` and run
    through ``_dispatch_matching``. The former sequential loop is removed:
    kept alongside the concurrent dispatch, it would invoke every matching
    module's ``on_raw`` a second time.
    """
    await self._dispatch_matching(
        data,
        matcher=_scope_matches_raw,
        handler_name="on_raw",
        log_label="on_raw",
    )
async def stop_all(self) -> None:
"""Shutdown all modules."""
@@ -164,6 +220,7 @@ class FanoutManager:
except Exception:
logger.exception("Error stopping fanout module %s", config_id)
self._modules.clear()
self._restart_locks.clear()
def get_statuses(self) -> dict[str, dict[str, str]]:
"""Return status info for each active module."""

View File

@@ -1,5 +1,6 @@
"""Tests for fanout bus: manager, scope matching, repository, and modules."""
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
@@ -7,6 +8,7 @@ import pytest
from app.database import Database
from app.fanout.base import FanoutModule
from app.fanout.manager import (
_DISPATCH_TIMEOUT_SECONDS,
FanoutManager,
_scope_matches_message,
_scope_matches_raw,
@@ -199,6 +201,59 @@ class TestFanoutManagerDispatch:
# Good module should still receive the message despite the bad one failing
assert len(good_mod.message_calls) == 1
@pytest.mark.asyncio
async def test_broadcast_message_dispatches_matching_modules_concurrently(self):
    """A blocked module must not delay delivery to other matching modules."""
    manager = FanoutManager()

    class BlockingModule(StubModule):
        """Module whose on_message parks until the test releases it."""

        def __init__(self):
            super().__init__()
            self.started = asyncio.Event()
            self.release = asyncio.Event()

        async def on_message(self, data: dict) -> None:
            self.started.set()
            await self.release.wait()
            self.message_calls.append(data)

    slow_mod = BlockingModule()
    fast_mod = StubModule()
    manager._modules["slow"] = (slow_mod, {"messages": "all"})
    manager._modules["fast"] = (fast_mod, {"messages": "all"})

    broadcast_task = asyncio.create_task(
        manager.broadcast_message({"type": "PRIV", "conversation_key": "pk1"})
    )
    try:
        # Bounded wait: if the slow handler never starts, fail fast instead
        # of hanging the whole test run.
        await asyncio.wait_for(slow_mod.started.wait(), timeout=1.0)
        # Yield once so the fast module's handler gets a chance to run.
        await asyncio.sleep(0)

        assert len(fast_mod.message_calls) == 1
        assert not broadcast_task.done()
    finally:
        # Always unblock the slow handler and drain the broadcast so a failed
        # assertion does not leave a pending task behind.
        slow_mod.release.set()
        await asyncio.wait_for(broadcast_task, timeout=1.0)

    assert len(slow_mod.message_calls) == 1
@pytest.mark.asyncio
async def test_timed_out_module_is_restarted(self):
    """A handler exceeding the dispatch timeout triggers stop() then start()."""
    manager = FanoutManager()
    mod = StubModule()
    mod.start = AsyncMock()
    mod.stop = AsyncMock()

    async def slow_message(data: dict) -> None:
        # Sleeps well past the patched timeout; the dispatcher cancels it.
        await asyncio.sleep(_DISPATCH_TIMEOUT_SECONDS * 2)

    mod.on_message = slow_message
    manager._modules["slow"] = (mod, {"messages": "all"})

    # Shrink the timeout so the test finishes in milliseconds.
    with patch("app.fanout.manager._DISPATCH_TIMEOUT_SECONDS", 0.01):
        await manager.broadcast_message({"type": "PRIV", "conversation_key": "pk1"})

    # assert_awaited_once also fails if the coroutine was created but never
    # awaited, which assert_called_once would let through.
    mod.stop.assert_awaited_once()
    mod.start.assert_awaited_once()
def test_get_statuses(self):
manager = FanoutManager()
mod = StubModule()