diff --git a/app/fanout/manager.py b/app/fanout/manager.py index c083360..644b643 100644 --- a/app/fanout/manager.py +++ b/app/fanout/manager.py @@ -2,12 +2,14 @@ from __future__ import annotations +import asyncio import logging from typing import Any from app.fanout.base import FanoutModule logger = logging.getLogger(__name__) +_DISPATCH_TIMEOUT_SECONDS = 30.0 # Type string -> module class mapping (extended in Phase 2/3) _MODULE_TYPES: dict[str, type] = {} @@ -77,6 +79,7 @@ class FanoutManager: def __init__(self) -> None: self._modules: dict[str, tuple[FanoutModule, dict]] = {} # id -> (module, scope) + self._restart_locks: dict[str, asyncio.Lock] = {} async def load_from_db(self) -> None: """Read enabled fanout_configs and instantiate modules.""" @@ -137,24 +140,77 @@ class FanoutManager: await module.stop() except Exception: logger.exception("Error stopping fanout module %s", config_id) + self._restart_locks.pop(config_id, None) + + async def _dispatch_matching( + self, + data: dict, + *, + matcher: Any, + handler_name: str, + log_label: str, + ) -> None: + """Dispatch to all matching modules concurrently.""" + tasks = [] + for config_id, (module, scope) in list(self._modules.items()): + if matcher(scope, data): + tasks.append(self._run_handler(config_id, module, handler_name, data, log_label)) + if tasks: + await asyncio.gather(*tasks) + + async def _run_handler( + self, + config_id: str, + module: FanoutModule, + handler_name: str, + data: dict, + log_label: str, + ) -> None: + """Run one module handler with per-module exception isolation.""" + try: + handler = getattr(module, handler_name) + await asyncio.wait_for(handler(data), timeout=_DISPATCH_TIMEOUT_SECONDS) + except asyncio.TimeoutError: + logger.error( + "Fanout %s %s timed out after %.1fs; restarting module", + config_id, + log_label, + _DISPATCH_TIMEOUT_SECONDS, + ) + await self._restart_module(config_id, module) + except Exception: + logger.exception("Fanout %s %s error", config_id, log_label) + + async def _restart_module(self, config_id: str, module: FanoutModule) -> None: + """Restart a timed-out module if it is still the active instance.""" + lock = self._restart_locks.setdefault(config_id, asyncio.Lock()) + async with lock: + entry = self._modules.get(config_id) + if entry is None or entry[0] is not module: + return + try: + await module.stop() + await module.start() + except Exception: + logger.exception("Failed to restart timed-out fanout module %s", config_id) async def broadcast_message(self, data: dict) -> None: """Dispatch a decoded message to modules whose scope matches.""" - for config_id, (module, scope) in list(self._modules.items()): - if _scope_matches_message(scope, data): - try: - await module.on_message(data) - except Exception: - logger.exception("Fanout %s on_message error", config_id) + await self._dispatch_matching( + data, + matcher=_scope_matches_message, + handler_name="on_message", + log_label="on_message", + ) async def broadcast_raw(self, data: dict) -> None: """Dispatch a raw packet to modules whose scope matches.""" - for config_id, (module, scope) in list(self._modules.items()): - if _scope_matches_raw(scope, data): - try: - await module.on_raw(data) - except Exception: - logger.exception("Fanout %s on_raw error", config_id) + await self._dispatch_matching( + data, + matcher=_scope_matches_raw, + handler_name="on_raw", + log_label="on_raw", + ) async def stop_all(self) -> None: """Shutdown all modules.""" @@ -164,6 +220,7 @@ class FanoutManager: except Exception: logger.exception("Error stopping fanout module %s", config_id) self._modules.clear() + self._restart_locks.clear() def get_statuses(self) -> dict[str, dict[str, str]]: """Return status info for each active module.""" diff --git a/tests/test_fanout.py b/tests/test_fanout.py index ea0ada9..1d33cab 100644 --- a/tests/test_fanout.py +++ b/tests/test_fanout.py @@ -1,5 +1,6 @@ """Tests for fanout bus: manager, scope matching, repository, and modules.""" +import asyncio from unittest.mock import AsyncMock, patch import pytest @@ -7,6 +8,7 @@ import pytest from app.database import Database from app.fanout.base import FanoutModule from app.fanout.manager import ( + _DISPATCH_TIMEOUT_SECONDS, FanoutManager, _scope_matches_message, _scope_matches_raw, @@ -199,6 +201,59 @@ class TestFanoutManagerDispatch: # Good module should still receive the message despite the bad one failing assert len(good_mod.message_calls) == 1 + @pytest.mark.asyncio + async def test_broadcast_message_dispatches_matching_modules_concurrently(self): + manager = FanoutManager() + + class BlockingModule(StubModule): + def __init__(self): + super().__init__() + self.started = asyncio.Event() + self.release = asyncio.Event() + + async def on_message(self, data: dict) -> None: + self.started.set() + await self.release.wait() + self.message_calls.append(data) + + slow_mod = BlockingModule() + fast_mod = StubModule() + + manager._modules["slow"] = (slow_mod, {"messages": "all"}) + manager._modules["fast"] = (fast_mod, {"messages": "all"}) + + broadcast_task = asyncio.create_task( + manager.broadcast_message({"type": "PRIV", "conversation_key": "pk1"}) + ) + + await slow_mod.started.wait() + await asyncio.sleep(0) + + assert len(fast_mod.message_calls) == 1 + assert not broadcast_task.done() + + slow_mod.release.set() + await broadcast_task + + @pytest.mark.asyncio + async def test_timed_out_module_is_restarted(self): + manager = FanoutManager() + mod = StubModule() + mod.start = AsyncMock() + mod.stop = AsyncMock() + + async def slow_message(data: dict) -> None: + await asyncio.sleep(_DISPATCH_TIMEOUT_SECONDS * 2) + + mod.on_message = slow_message + manager._modules["slow"] = (mod, {"messages": "all"}) + + with patch("app.fanout.manager._DISPATCH_TIMEOUT_SECONDS", 0.01): + await manager.broadcast_message({"type": "PRIV", "conversation_key": "pk1"}) + + mod.stop.assert_called_once() + mod.start.assert_called_once() + def test_get_statuses(self): manager = FanoutManager() mod = StubModule()