Guard NodeInfo handler against missing IDs (#426) (#431)

This commit is contained in:
l5y
2025-11-12 12:39:36 +01:00
committed by GitHub
parent 8823b7cb48
commit 2107d6790d
2 changed files with 212 additions and 87 deletions

View File

@@ -17,6 +17,7 @@
from __future__ import annotations
import contextlib
import importlib
import glob
import ipaddress
import re
@@ -29,6 +30,108 @@ from meshtastic.tcp_interface import TCPInterface
from . import channels, config, serialization
def _ensure_mapping(value) -> Mapping | None:
"""Return ``value`` as a mapping when conversion is possible."""
if isinstance(value, Mapping):
return value
if hasattr(value, "__dict__") and isinstance(value.__dict__, Mapping):
return value.__dict__
with contextlib.suppress(Exception):
converted = serialization._node_to_dict(value)
if isinstance(converted, Mapping):
return converted
return None
def _candidate_node_id(mapping: Mapping | None) -> str | None:
"""Extract a canonical node identifier from ``mapping`` when present."""
if mapping is None:
return None
primary_keys = (
"id",
"userId",
"user_id",
"fromId",
"from_id",
"from",
"nodeId",
"node_id",
"nodeNum",
"node_num",
"num",
)
for key in primary_keys:
with contextlib.suppress(Exception):
node_id = serialization._canonical_node_id(mapping.get(key))
if node_id:
return node_id
user_section = _ensure_mapping(mapping.get("user"))
if user_section is not None:
for key in ("id", "userId", "user_id", "num", "nodeNum", "node_num"):
with contextlib.suppress(Exception):
node_id = serialization._canonical_node_id(user_section.get(key))
if node_id:
return node_id
decoded_section = _ensure_mapping(mapping.get("decoded"))
if decoded_section is not None:
node_id = _candidate_node_id(decoded_section)
if node_id:
return node_id
payload_section = _ensure_mapping(mapping.get("payload"))
if payload_section is not None:
node_id = _candidate_node_id(payload_section)
if node_id:
return node_id
for key in ("packet", "meta", "info"):
node_id = _candidate_node_id(_ensure_mapping(mapping.get(key)))
if node_id:
return node_id
for value in mapping.values():
if isinstance(value, (list, tuple)):
for item in value:
node_id = _candidate_node_id(_ensure_mapping(item))
if node_id:
return node_id
else:
node_id = _candidate_node_id(_ensure_mapping(value))
if node_id:
return node_id
return None
def _normalise_nodeinfo_packet(packet) -> dict | None:
"""Return a dictionary view of ``packet`` with a guaranteed ``id`` when known."""
mapping = _ensure_mapping(packet)
if mapping is None:
return None
try:
normalised: dict = dict(mapping)
except Exception:
try:
normalised = {key: mapping[key] for key in mapping}
except Exception:
return None
node_id = _candidate_node_id(normalised)
if node_id and normalised.get("id") != node_id:
normalised["id"] = node_id
return normalised
if TYPE_CHECKING: # pragma: no cover - import only used for type checking
from meshtastic.ble_interface import BLEInterface as _BLEInterface
@@ -46,100 +149,25 @@ def _patch_meshtastic_nodeinfo_handler() -> None:
original = getattr(meshtastic, "_onNodeInfoReceive", None)
if not callable(original):
return
if getattr(original, "_potato_mesh_safe_wrapper", False):
return
def _ensure_mapping(value) -> Mapping | None:
"""Return ``value`` as a mapping when conversion is possible."""
if isinstance(value, Mapping):
return value
if hasattr(value, "__dict__") and isinstance(value.__dict__, Mapping):
return value.__dict__
mesh_interface_module = getattr(meshtastic, "mesh_interface", None)
if mesh_interface_module is None:
with contextlib.suppress(Exception):
converted = serialization._node_to_dict(value)
if isinstance(converted, Mapping):
return converted
return None
mesh_interface_module = importlib.import_module("meshtastic.mesh_interface")
def _candidate_node_id(mapping: Mapping | None) -> str | None:
"""Extract a canonical node identifier from ``mapping`` when present."""
if not getattr(original, "_potato_mesh_safe_wrapper", False):
meshtastic._onNodeInfoReceive = _build_safe_nodeinfo_callback(original)
if mapping is None:
return None
_patch_nodeinfo_handler_class(mesh_interface_module)
primary_keys = (
"id",
"userId",
"user_id",
"fromId",
"from_id",
"from",
"nodeId",
"node_id",
"nodeNum",
"node_num",
"num",
)
for key in primary_keys:
node_id = serialization._canonical_node_id(mapping.get(key))
if node_id:
return node_id
user_section = _ensure_mapping(mapping.get("user"))
if user_section is not None:
for key in ("id", "userId", "user_id", "num", "nodeNum", "node_num"):
node_id = serialization._canonical_node_id(user_section.get(key))
if node_id:
return node_id
decoded_section = _ensure_mapping(mapping.get("decoded"))
if decoded_section is not None:
node_id = _candidate_node_id(decoded_section)
if node_id:
return node_id
payload_section = _ensure_mapping(mapping.get("payload"))
if payload_section is not None:
node_id = _candidate_node_id(payload_section)
if node_id:
return node_id
for key in ("packet", "meta", "info"):
node_id = _candidate_node_id(_ensure_mapping(mapping.get(key)))
if node_id:
return node_id
for value in mapping.values():
if isinstance(value, (list, tuple)):
for item in value:
node_id = _candidate_node_id(_ensure_mapping(item))
if node_id:
return node_id
else:
node_id = _candidate_node_id(_ensure_mapping(value))
if node_id:
return node_id
return None
def _build_safe_nodeinfo_callback(original):
"""Return a wrapper that injects a missing ``id`` before dispatching."""
def _safe_on_node_info_receive(iface, packet): # type: ignore[override]
candidate_mapping = _ensure_mapping(packet)
node_id = _candidate_node_id(candidate_mapping)
if node_id and candidate_mapping is not None:
if not isinstance(candidate_mapping, dict):
try:
candidate_mapping = dict(candidate_mapping)
except Exception:
candidate_mapping = {
k: candidate_mapping[k] for k in candidate_mapping
}
if candidate_mapping.get("id") != node_id:
candidate_mapping["id"] = node_id
packet = candidate_mapping
normalised = _normalise_nodeinfo_packet(packet)
if normalised is not None:
packet = normalised
try:
return original(iface, packet)
@@ -149,7 +177,53 @@ def _patch_meshtastic_nodeinfo_handler() -> None:
raise
_safe_on_node_info_receive._potato_mesh_safe_wrapper = True # type: ignore[attr-defined]
meshtastic._onNodeInfoReceive = _safe_on_node_info_receive
return _safe_on_node_info_receive
def _patch_nodeinfo_handler_class(mesh_interface_module) -> None:
"""Wrap ``NodeInfoHandler.onReceive`` to normalise packets before callbacks."""
if mesh_interface_module is None:
return
handler_class = getattr(mesh_interface_module, "NodeInfoHandler", None)
if handler_class is None:
return
if getattr(handler_class, "_potato_mesh_safe_wrapper", False):
return
original_on_receive = getattr(handler_class, "onReceive", None)
if not callable(original_on_receive):
return
class _SafeNodeInfoHandler(handler_class): # type: ignore[misc]
"""Subclass that guards against missing node identifiers."""
def onReceive(self, iface, packet): # type: ignore[override]
normalised = _normalise_nodeinfo_packet(packet)
if normalised is not None:
packet = normalised
try:
return super().onReceive(iface, packet)
except KeyError as exc: # pragma: no cover - defensive only
if exc.args and exc.args[0] == "id":
return None
raise
_SafeNodeInfoHandler.__name__ = handler_class.__name__
_SafeNodeInfoHandler.__qualname__ = getattr(
handler_class, "__qualname__", handler_class.__name__
)
_SafeNodeInfoHandler.__module__ = getattr(
handler_class, "__module__", mesh_interface_module.__name__
)
_SafeNodeInfoHandler.__doc__ = getattr(
handler_class, "__doc__", _SafeNodeInfoHandler.__doc__
)
_SafeNodeInfoHandler._potato_mesh_safe_wrapper = True # type: ignore[attr-defined]
setattr(mesh_interface_module, "NodeInfoHandler", _SafeNodeInfoHandler)
_patch_meshtastic_nodeinfo_handler()

View File

@@ -129,6 +129,32 @@ def mesh_module(monkeypatch):
meshtastic_mod.serial_interface = serial_interface_mod
meshtastic_mod.tcp_interface = tcp_interface_mod
meshtastic_mod.ble_interface = ble_interface_mod
mesh_interface_mod = types.ModuleType("meshtastic.mesh_interface")
def _default_nodeinfo_callback(iface, packet):
iface.nodes[packet["id"]] = packet
return packet["id"]
class DummyNodeInfoHandler:
"""Stub that mimics Meshtastic's NodeInfo handler semantics."""
def __init__(self):
self.callback = getattr(
meshtastic_mod, "_onNodeInfoReceive", _default_nodeinfo_callback
)
def onReceive(self, iface, packet):
nodes = getattr(iface, "nodes", None)
if isinstance(nodes, dict):
nodes[packet["id"]] = packet
return self.callback(iface, packet)
mesh_interface_mod.NodeInfoHandler = DummyNodeInfoHandler
meshtastic_mod.mesh_interface = mesh_interface_mod
monkeypatch.setitem(sys.modules, "meshtastic.mesh_interface", mesh_interface_mod)
meshtastic_mod._onNodeInfoReceive = _default_nodeinfo_callback
if real_protobuf is not None:
meshtastic_mod.protobuf = real_protobuf
else:
@@ -1080,6 +1106,31 @@ def test_nodeinfo_wrapper_infers_missing_identifier(mesh_module, monkeypatch):
assert packet["id"] == "!88776655"
def test_nodeinfo_handler_wrapper_prevents_key_error(mesh_module):
"""The NodeInfo handler should operate safely when the ID field is absent."""
import meshtastic
from data.mesh_ingestor import interfaces
interfaces._patch_meshtastic_nodeinfo_handler()
assert getattr(
meshtastic.mesh_interface.NodeInfoHandler,
"_potato_mesh_safe_wrapper",
False,
), "Expected NodeInfoHandler to be replaced with a safe subclass"
handler = meshtastic.mesh_interface.NodeInfoHandler()
iface = types.SimpleNamespace(nodes={})
packet = {"decoded": {"user": {"id": "!01020304"}}}
result = handler.onReceive(iface, packet)
assert iface.nodes["!01020304"]["id"] == "!01020304"
assert result == "!01020304"
def test_store_packet_dict_ignores_non_text(mesh_module, monkeypatch):
mesh = mesh_module
captured = []