From 2107d6790df6ee6ebac5b338b810f16f9c9b7b9c Mon Sep 17 00:00:00 2001 From: l5y <220195275+l5yth@users.noreply.github.com> Date: Wed, 12 Nov 2025 12:39:36 +0100 Subject: [PATCH] Guard NodeInfo handler against missing IDs (#426) (#431) --- data/mesh_ingestor/interfaces.py | 248 ++++++++++++++++++++----------- tests/test_mesh.py | 51 +++++++ 2 files changed, 212 insertions(+), 87 deletions(-) diff --git a/data/mesh_ingestor/interfaces.py b/data/mesh_ingestor/interfaces.py index 17823e3..9de9fde 100644 --- a/data/mesh_ingestor/interfaces.py +++ b/data/mesh_ingestor/interfaces.py @@ -17,6 +17,7 @@ from __future__ import annotations import contextlib +import importlib import glob import ipaddress import re @@ -29,6 +30,108 @@ from meshtastic.tcp_interface import TCPInterface from . import channels, config, serialization + +def _ensure_mapping(value) -> Mapping | None: + """Return ``value`` as a mapping when conversion is possible.""" + + if isinstance(value, Mapping): + return value + if hasattr(value, "__dict__") and isinstance(value.__dict__, Mapping): + return value.__dict__ + with contextlib.suppress(Exception): + converted = serialization._node_to_dict(value) + if isinstance(converted, Mapping): + return converted + return None + + +def _candidate_node_id(mapping: Mapping | None) -> str | None: + """Extract a canonical node identifier from ``mapping`` when present.""" + + if mapping is None: + return None + + primary_keys = ( + "id", + "userId", + "user_id", + "fromId", + "from_id", + "from", + "nodeId", + "node_id", + "nodeNum", + "node_num", + "num", + ) + + for key in primary_keys: + with contextlib.suppress(Exception): + node_id = serialization._canonical_node_id(mapping.get(key)) + if node_id: + return node_id + + user_section = _ensure_mapping(mapping.get("user")) + if user_section is not None: + for key in ("id", "userId", "user_id", "num", "nodeNum", "node_num"): + with contextlib.suppress(Exception): + node_id = serialization._canonical_node_id(user_section.get(key)) + if node_id: + return node_id + + decoded_section = _ensure_mapping(mapping.get("decoded")) + if decoded_section is not None: + node_id = _candidate_node_id(decoded_section) + if node_id: + return node_id + + payload_section = _ensure_mapping(mapping.get("payload")) + if payload_section is not None: + node_id = _candidate_node_id(payload_section) + if node_id: + return node_id + + for key in ("packet", "meta", "info"): + node_id = _candidate_node_id(_ensure_mapping(mapping.get(key))) + if node_id: + return node_id + + for value in mapping.values(): + if isinstance(value, (list, tuple)): + for item in value: + node_id = _candidate_node_id(_ensure_mapping(item)) + if node_id: + return node_id + else: + node_id = _candidate_node_id(_ensure_mapping(value)) + if node_id: + return node_id + + return None + + +def _normalise_nodeinfo_packet(packet) -> dict | None: + """Return a dictionary view of ``packet`` with a guaranteed ``id`` when known.""" + + mapping = _ensure_mapping(packet) + if mapping is None: + return None + + try: + normalised: dict = dict(mapping) + except Exception: + try: + normalised = {key: mapping[key] for key in mapping} + except Exception: + return None + + node_id = _candidate_node_id(normalised) + if node_id and normalised.get("id") != node_id: + normalised["id"] = node_id + + return normalised + + if TYPE_CHECKING: # pragma: no cover - import only used for type checking from meshtastic.ble_interface import BLEInterface as _BLEInterface @@ -46,100 +149,25 @@ def _patch_meshtastic_nodeinfo_handler() -> None: original = getattr(meshtastic, "_onNodeInfoReceive", None) if not callable(original): return - if getattr(original, "_potato_mesh_safe_wrapper", False): - return - def _ensure_mapping(value) -> Mapping | None: - """Return ``value`` as a mapping when conversion is possible.""" - - if isinstance(value, Mapping): - return value - if hasattr(value, "__dict__") and isinstance(value.__dict__, Mapping): - return value.__dict__ + mesh_interface_module = getattr(meshtastic, "mesh_interface", None) + if mesh_interface_module is None: with contextlib.suppress(Exception): - converted = serialization._node_to_dict(value) - if isinstance(converted, Mapping): - return converted - return None + mesh_interface_module = importlib.import_module("meshtastic.mesh_interface") - def _candidate_node_id(mapping: Mapping | None) -> str | None: - """Extract a canonical node identifier from ``mapping`` when present.""" + if not getattr(original, "_potato_mesh_safe_wrapper", False): + meshtastic._onNodeInfoReceive = _build_safe_nodeinfo_callback(original) - if mapping is None: - return None + _patch_nodeinfo_handler_class(mesh_interface_module) - primary_keys = ( - "id", - "userId", - "user_id", - "fromId", - "from_id", - "from", - "nodeId", - "node_id", - "nodeNum", - "node_num", - "num", - ) - for key in primary_keys: - node_id = serialization._canonical_node_id(mapping.get(key)) - if node_id: - return node_id - - user_section = _ensure_mapping(mapping.get("user")) - if user_section is not None: - for key in ("id", "userId", "user_id", "num", "nodeNum", "node_num"): - node_id = serialization._canonical_node_id(user_section.get(key)) - if node_id: - return node_id - - decoded_section = _ensure_mapping(mapping.get("decoded")) - if decoded_section is not None: - node_id = _candidate_node_id(decoded_section) - if node_id: - return node_id - - payload_section = _ensure_mapping(mapping.get("payload")) - if payload_section is not None: - node_id = _candidate_node_id(payload_section) - if node_id: - return node_id - - for key in ("packet", "meta", "info"): - node_id = _candidate_node_id(_ensure_mapping(mapping.get(key))) - if node_id: - return node_id - - for value in mapping.values(): - if isinstance(value, (list, tuple)): - for item in value: - node_id = _candidate_node_id(_ensure_mapping(item)) - if node_id: - return node_id - else: - node_id = _candidate_node_id(_ensure_mapping(value)) - if node_id: - return node_id - - return None +def _build_safe_nodeinfo_callback(original): + """Return a wrapper that injects a missing ``id`` before dispatching.""" def _safe_on_node_info_receive(iface, packet): # type: ignore[override] - candidate_mapping = _ensure_mapping(packet) - - node_id = _candidate_node_id(candidate_mapping) - - if node_id and candidate_mapping is not None: - if not isinstance(candidate_mapping, dict): - try: - candidate_mapping = dict(candidate_mapping) - except Exception: - candidate_mapping = { - k: candidate_mapping[k] for k in candidate_mapping - } - if candidate_mapping.get("id") != node_id: - candidate_mapping["id"] = node_id - packet = candidate_mapping + normalised = _normalise_nodeinfo_packet(packet) + if normalised is not None: + packet = normalised try: return original(iface, packet) @@ -149,7 +177,53 @@ def _patch_meshtastic_nodeinfo_handler() -> None: raise _safe_on_node_info_receive._potato_mesh_safe_wrapper = True # type: ignore[attr-defined] - meshtastic._onNodeInfoReceive = _safe_on_node_info_receive + return _safe_on_node_info_receive + + +def _patch_nodeinfo_handler_class(mesh_interface_module) -> None: + """Wrap ``NodeInfoHandler.onReceive`` to normalise packets before callbacks.""" + + if mesh_interface_module is None: + return + + handler_class = getattr(mesh_interface_module, "NodeInfoHandler", None) + if handler_class is None: + return + if getattr(handler_class, "_potato_mesh_safe_wrapper", False): + return + + original_on_receive = getattr(handler_class, "onReceive", None) + if not callable(original_on_receive): + return + + class _SafeNodeInfoHandler(handler_class): # type: ignore[misc] + """Subclass that guards against missing node identifiers.""" + + def onReceive(self, iface, packet): # type: ignore[override] + normalised = _normalise_nodeinfo_packet(packet) + if normalised is not None: + packet = normalised + + try: + return super().onReceive(iface, packet) + except KeyError as exc: # pragma: no cover - defensive only + if exc.args and exc.args[0] == "id": + return None + raise + + _SafeNodeInfoHandler.__name__ = handler_class.__name__ + _SafeNodeInfoHandler.__qualname__ = getattr( + handler_class, "__qualname__", handler_class.__name__ + ) + _SafeNodeInfoHandler.__module__ = getattr( + handler_class, "__module__", mesh_interface_module.__name__ + ) + _SafeNodeInfoHandler.__doc__ = getattr( + handler_class, "__doc__", _SafeNodeInfoHandler.__doc__ + ) + _SafeNodeInfoHandler._potato_mesh_safe_wrapper = True # type: ignore[attr-defined] + + setattr(mesh_interface_module, "NodeInfoHandler", _SafeNodeInfoHandler) _patch_meshtastic_nodeinfo_handler() diff --git a/tests/test_mesh.py b/tests/test_mesh.py index c797198..06ca228 100644 --- a/tests/test_mesh.py +++ b/tests/test_mesh.py @@ -129,6 +129,32 @@ def mesh_module(monkeypatch): meshtastic_mod.serial_interface = serial_interface_mod meshtastic_mod.tcp_interface = tcp_interface_mod meshtastic_mod.ble_interface = ble_interface_mod + + mesh_interface_mod = types.ModuleType("meshtastic.mesh_interface") + + def _default_nodeinfo_callback(iface, packet): + iface.nodes[packet["id"]] = packet + return packet["id"] + + class DummyNodeInfoHandler: + """Stub that mimics Meshtastic's NodeInfo handler semantics.""" + + def __init__(self): + self.callback = getattr( + meshtastic_mod, "_onNodeInfoReceive", _default_nodeinfo_callback + ) + + def onReceive(self, iface, packet): + nodes = getattr(iface, "nodes", None) + if isinstance(nodes, dict): + nodes[packet["id"]] = packet + return self.callback(iface, packet) + + mesh_interface_mod.NodeInfoHandler = DummyNodeInfoHandler + meshtastic_mod.mesh_interface = mesh_interface_mod + monkeypatch.setitem(sys.modules, "meshtastic.mesh_interface", mesh_interface_mod) + + meshtastic_mod._onNodeInfoReceive = _default_nodeinfo_callback if real_protobuf is not None: meshtastic_mod.protobuf = real_protobuf else: @@ -1080,6 +1106,31 @@ def test_nodeinfo_wrapper_infers_missing_identifier(mesh_module, monkeypatch): assert packet["id"] == "!88776655" +def test_nodeinfo_handler_wrapper_prevents_key_error(mesh_module): + """The NodeInfo handler should operate safely when the ID field is absent.""" + + import meshtastic + from data.mesh_ingestor import interfaces + + interfaces._patch_meshtastic_nodeinfo_handler() + + assert getattr( + meshtastic.mesh_interface.NodeInfoHandler, + "_potato_mesh_safe_wrapper", + False, + ), "Expected NodeInfoHandler to be replaced with a safe subclass" + + handler = meshtastic.mesh_interface.NodeInfoHandler() + iface = types.SimpleNamespace(nodes={}) + + packet = {"decoded": {"user": {"id": "!01020304"}}} + + result = handler.onReceive(iface, packet) + + assert iface.nodes["!01020304"]["id"] == "!01020304" + assert result == "!01020304" + + def test_store_packet_dict_ignores_non_text(mesh_module, monkeypatch): mesh = mesh_module captured = []