mirror of
https://github.com/l5yth/potato-mesh.git
synced 2026-07-03 16:31:57 +02:00
Prevent message ids from being treated as node identifiers (#475)
* Prevent message ids from being treated as nodes (#)
* Cover node id candidate edge cases
* Revert "address missing id field ingestor bug (#469)"
This reverts commit 546e009867.
This commit is contained in:
@@ -48,16 +48,37 @@ def _ensure_mapping(value) -> Mapping | None:
|
||||
return None
|
||||
|
||||
|
||||
def _is_nodeish_identifier(value: Any) -> bool:
|
||||
"""Return ``True`` when ``value`` resembles a Meshtastic node identifier."""
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return False
|
||||
|
||||
trimmed = value.strip()
|
||||
if not trimmed:
|
||||
return False
|
||||
if trimmed.startswith("^"):
|
||||
return True
|
||||
if trimmed.startswith("!"):
|
||||
trimmed = trimmed[1:]
|
||||
elif trimmed.lower().startswith("0x"):
|
||||
trimmed = trimmed[2:]
|
||||
elif not re.search(r"[a-fA-F]", trimmed):
|
||||
# Bare decimal strings should not be treated as node ids when labelled "id".
|
||||
return False
|
||||
|
||||
return bool(re.fullmatch(r"[0-9a-fA-F]{1,8}", trimmed))
|
||||
|
||||
|
||||
def _candidate_node_id(mapping: Mapping | None) -> str | None:
|
||||
"""Extract a canonical node identifier from ``mapping`` when present."""
|
||||
|
||||
if mapping is None:
|
||||
return None
|
||||
|
||||
primary_keys = (
|
||||
"id",
|
||||
"userId",
|
||||
"user_id",
|
||||
node_keys = (
|
||||
"fromId",
|
||||
"from_id",
|
||||
"from",
|
||||
@@ -66,21 +87,36 @@ def _candidate_node_id(mapping: Mapping | None) -> str | None:
|
||||
"nodeNum",
|
||||
"node_num",
|
||||
"num",
|
||||
"userId",
|
||||
"user_id",
|
||||
)
|
||||
|
||||
for key in primary_keys:
|
||||
for key in node_keys:
|
||||
with contextlib.suppress(Exception):
|
||||
node_id = serialization._canonical_node_id(mapping.get(key))
|
||||
if node_id:
|
||||
return node_id
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
value = mapping.get("id")
|
||||
if _is_nodeish_identifier(value):
|
||||
node_id = serialization._canonical_node_id(value)
|
||||
if node_id:
|
||||
return node_id
|
||||
|
||||
user_section = _ensure_mapping(mapping.get("user"))
|
||||
if user_section is not None:
|
||||
for key in ("id", "userId", "user_id", "num", "nodeNum", "node_num"):
|
||||
for key in ("userId", "user_id", "num", "nodeNum", "node_num"):
|
||||
with contextlib.suppress(Exception):
|
||||
node_id = serialization._canonical_node_id(user_section.get(key))
|
||||
if node_id:
|
||||
return node_id
|
||||
with contextlib.suppress(Exception):
|
||||
user_id_value = user_section.get("id")
|
||||
if _is_nodeish_identifier(user_id_value):
|
||||
node_id = serialization._canonical_node_id(user_id_value)
|
||||
if node_id:
|
||||
return node_id
|
||||
|
||||
decoded_section = _ensure_mapping(mapping.get("decoded"))
|
||||
if decoded_section is not None:
|
||||
@@ -173,17 +209,6 @@ def _normalise_nodeinfo_packet(packet) -> dict | None:
|
||||
if node_id and normalised.get("id") != node_id:
|
||||
normalised["id"] = node_id
|
||||
|
||||
decoded_section = _ensure_mapping(normalised.get("decoded"))
|
||||
if decoded_section is not None:
|
||||
decoded_dict = dict(decoded_section)
|
||||
user_section = _ensure_mapping(decoded_dict.get("user"))
|
||||
if user_section is not None:
|
||||
user_dict = dict(user_section)
|
||||
if node_id and user_dict.get("id") != node_id:
|
||||
user_dict["id"] = node_id
|
||||
decoded_dict["user"] = user_dict
|
||||
normalised["decoded"] = decoded_dict
|
||||
|
||||
return normalised
|
||||
|
||||
|
||||
@@ -213,18 +238,8 @@ def _patch_meshtastic_nodeinfo_handler() -> None:
|
||||
with contextlib.suppress(Exception):
|
||||
mesh_interface_module = importlib.import_module("meshtastic.mesh_interface")
|
||||
|
||||
safe_callback = original
|
||||
if not getattr(original, "_potato_mesh_safe_wrapper", False):
|
||||
safe_callback = _build_safe_nodeinfo_callback(original)
|
||||
module._onNodeInfoReceive = safe_callback
|
||||
if (
|
||||
mesh_interface_module is not None
|
||||
and getattr(mesh_interface_module, "_onNodeInfoReceive", None) is original
|
||||
):
|
||||
mesh_interface_module._onNodeInfoReceive = safe_callback
|
||||
|
||||
_patch_protocol_nodeinfo_callback(module, original, safe_callback)
|
||||
_patch_protocol_nodeinfo_callback(mesh_interface_module, original, safe_callback)
|
||||
module._onNodeInfoReceive = _build_safe_nodeinfo_callback(original)
|
||||
|
||||
_patch_nodeinfo_handler_class(mesh_interface_module, module)
|
||||
|
||||
@@ -248,49 +263,6 @@ def _build_safe_nodeinfo_callback(original):
|
||||
return _safe_on_node_info_receive
|
||||
|
||||
|
||||
def _replace_known_protocol_callback(protocol, replacement):
|
||||
"""Return ``protocol`` with ``onReceive`` set to ``replacement``."""
|
||||
|
||||
replacer = getattr(protocol, "_replace", None)
|
||||
if callable(replacer):
|
||||
try:
|
||||
return replacer(onReceive=replacement)
|
||||
except Exception:
|
||||
pass
|
||||
protocol_cls = getattr(protocol, "__class__", None)
|
||||
try:
|
||||
return protocol_cls(
|
||||
getattr(protocol, "name", None),
|
||||
getattr(protocol, "protobufFactory", None),
|
||||
replacement,
|
||||
)
|
||||
except Exception:
|
||||
return protocol
|
||||
|
||||
|
||||
def _patch_protocol_nodeinfo_callback(module, original, replacement) -> None:
|
||||
"""Swap the NodeInfo protocol callback to ``replacement`` when needed."""
|
||||
|
||||
if module is None or replacement is None:
|
||||
return
|
||||
|
||||
protocols = getattr(module, "protocols", None)
|
||||
if not isinstance(protocols, Mapping):
|
||||
return
|
||||
|
||||
portnums = getattr(module, "portnums_pb2", None)
|
||||
portnum_enum = getattr(portnums, "PortNum", None)
|
||||
try:
|
||||
nodeinfo_key = getattr(portnum_enum, "NODEINFO_APP")
|
||||
except Exception:
|
||||
nodeinfo_key = None
|
||||
|
||||
for key, protocol in list(protocols.items()):
|
||||
on_receive = getattr(protocol, "onReceive", None)
|
||||
if key == nodeinfo_key or on_receive is original:
|
||||
protocols[key] = _replace_known_protocol_callback(protocol, replacement)
|
||||
|
||||
|
||||
def _update_nodeinfo_handler_aliases(original, replacement) -> None:
|
||||
"""Ensure Meshtastic modules reference the patched ``NodeInfoHandler``."""
|
||||
|
||||
|
||||
@@ -722,7 +722,7 @@ def _nodeinfo_user_dict(node_info, decoded_user):
|
||||
use_integers_for_enums=False,
|
||||
)
|
||||
except Exception:
|
||||
user_dict = None
|
||||
user_dict = _node_to_dict(node_info.user)
|
||||
|
||||
if isinstance(decoded_user, ProtoMessage):
|
||||
try:
|
||||
|
||||
@@ -134,6 +134,22 @@ def test_candidate_node_id_and_normaliser():
|
||||
node_id = interfaces._candidate_node_id(nested)
|
||||
assert node_id == "!0000002a"
|
||||
|
||||
telemetry_packet = {"id": 123456, "from": "!0000000b"}
|
||||
node_id = interfaces._candidate_node_id(telemetry_packet)
|
||||
assert node_id == "!0000000b"
|
||||
|
||||
unknown_packet = {"id": "123456"}
|
||||
assert interfaces._candidate_node_id(unknown_packet) is None
|
||||
|
||||
preferred_hex_packet = {"id": "0x2a"}
|
||||
assert interfaces._candidate_node_id(preferred_hex_packet) == "!0000002a"
|
||||
|
||||
caret_alias_packet = {"id": "^abc"}
|
||||
assert interfaces._candidate_node_id(caret_alias_packet) == "^abc"
|
||||
|
||||
non_node_numeric = {"id": 42.0}
|
||||
assert interfaces._candidate_node_id(non_node_numeric) is None
|
||||
|
||||
packet = {"user": {"id": "!0000002a"}, "userId": None}
|
||||
normalised = interfaces._normalise_nodeinfo_packet(packet)
|
||||
assert normalised["id"] == "!0000002a"
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
from collections import namedtuple
|
||||
import enum
|
||||
import importlib
|
||||
import json
|
||||
@@ -1306,162 +1305,6 @@ def test_interfaces_patch_handles_preimported_serial():
|
||||
sys.modules[name] = module
|
||||
|
||||
|
||||
def test_nodeinfo_patch_updates_known_protocols(monkeypatch):
|
||||
"""Ensure NodeInfo protocol callbacks are replaced with safe wrappers."""
|
||||
|
||||
from data.mesh_ingestor import interfaces
|
||||
|
||||
Protocol = namedtuple("Protocol", ("name", "protobufFactory", "onReceive"))
|
||||
|
||||
callbacks: list[dict] = []
|
||||
|
||||
def _unsafe_handler(iface, packet):
|
||||
callbacks.append(packet)
|
||||
user = packet["decoded"]["user"]
|
||||
iface.nodes[user["id"]] = {"user": user}
|
||||
return user["id"]
|
||||
|
||||
nodeinfo_value = 42
|
||||
PortNum = enum.IntEnum("PortNum", {"NODEINFO_APP": nodeinfo_value})
|
||||
portnums_pb2 = types.SimpleNamespace(PortNum=PortNum)
|
||||
protocols = {PortNum.NODEINFO_APP: Protocol("user", object, _unsafe_handler)}
|
||||
|
||||
meshtastic_mod = types.ModuleType("meshtastic")
|
||||
meshtastic_mod._onNodeInfoReceive = _unsafe_handler
|
||||
meshtastic_mod.portnums_pb2 = portnums_pb2
|
||||
meshtastic_mod.protocols = protocols
|
||||
|
||||
mesh_interface_mod = types.ModuleType("meshtastic.mesh_interface")
|
||||
mesh_interface_mod.protocols = protocols
|
||||
mesh_interface_mod.portnums_pb2 = portnums_pb2
|
||||
|
||||
monkeypatch.setitem(sys.modules, "meshtastic", meshtastic_mod)
|
||||
monkeypatch.setitem(sys.modules, "meshtastic.mesh_interface", mesh_interface_mod)
|
||||
monkeypatch.setattr(interfaces, "meshtastic", meshtastic_mod, raising=False)
|
||||
|
||||
interfaces._patch_meshtastic_nodeinfo_handler()
|
||||
|
||||
handler = meshtastic_mod.protocols[PortNum.NODEINFO_APP].onReceive
|
||||
iface = types.SimpleNamespace(nodes={})
|
||||
|
||||
handler(iface, {"decoded": {"user": {"shortName": "anon"}}, "from": 0x01020304})
|
||||
|
||||
assert getattr(handler, "_potato_mesh_safe_wrapper", False)
|
||||
assert callbacks, "Expected patched handler to call original callback"
|
||||
assert iface.nodes["!01020304"]["user"]["id"] == "!01020304"
|
||||
|
||||
|
||||
def test_nodeinfo_patch_updates_protocols_without_replace(monkeypatch):
|
||||
"""Fallback protocol replacement path should still wrap unsafe callbacks."""
|
||||
|
||||
from data.mesh_ingestor import interfaces
|
||||
|
||||
class DummyProtocol:
|
||||
def __init__(self, name, factory, on_receive):
|
||||
self.name = name
|
||||
self.protobufFactory = factory
|
||||
self.onReceive = on_receive
|
||||
|
||||
callbacks: list[dict] = []
|
||||
|
||||
def _unsafe_handler(iface, packet):
|
||||
callbacks.append(packet)
|
||||
iface.nodes[packet["from"]] = {"user": packet["decoded"]["user"]}
|
||||
return packet["from"]
|
||||
|
||||
nodeinfo_value = 7
|
||||
PortNum = enum.IntEnum("PortNum", {"NODEINFO_APP": nodeinfo_value})
|
||||
portnums_pb2 = types.SimpleNamespace(PortNum=PortNum)
|
||||
protocol_obj = DummyProtocol("user", object, _unsafe_handler)
|
||||
protocols = {
|
||||
PortNum.NODEINFO_APP: protocol_obj,
|
||||
99: DummyProtocol("other", object, lambda *_: None),
|
||||
}
|
||||
|
||||
meshtastic_mod = types.ModuleType("meshtastic")
|
||||
meshtastic_mod._onNodeInfoReceive = _unsafe_handler
|
||||
meshtastic_mod.portnums_pb2 = portnums_pb2
|
||||
meshtastic_mod.protocols = protocols
|
||||
|
||||
mesh_interface_mod = types.ModuleType("meshtastic.mesh_interface")
|
||||
mesh_interface_mod.protocols = dict(protocols)
|
||||
mesh_interface_mod.portnums_pb2 = portnums_pb2
|
||||
mesh_interface_mod._onNodeInfoReceive = _unsafe_handler
|
||||
|
||||
monkeypatch.setitem(sys.modules, "meshtastic", meshtastic_mod)
|
||||
monkeypatch.setitem(sys.modules, "meshtastic.mesh_interface", mesh_interface_mod)
|
||||
monkeypatch.setattr(interfaces, "meshtastic", meshtastic_mod, raising=False)
|
||||
|
||||
interfaces._patch_meshtastic_nodeinfo_handler()
|
||||
|
||||
handler = meshtastic_mod.protocols[PortNum.NODEINFO_APP].onReceive
|
||||
iface = types.SimpleNamespace(nodes={})
|
||||
|
||||
handler(iface, {"decoded": {"user": {"shortName": "anon"}}, "from": 0x01020304})
|
||||
|
||||
assert getattr(handler, "_potato_mesh_safe_wrapper", False)
|
||||
assert callbacks, "Expected patched handler to call original callback"
|
||||
assert callbacks[0]["decoded"]["user"]["id"] == "!01020304"
|
||||
assert iface.nodes[0x01020304]["user"]["id"] == "!01020304"
|
||||
assert (
|
||||
getattr(mesh_interface_mod, "_onNodeInfoReceive").__name__ == handler.__name__
|
||||
)
|
||||
assert getattr(
|
||||
mesh_interface_mod.protocols[nodeinfo_value].onReceive,
|
||||
"_potato_mesh_safe_wrapper",
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
def test_normalise_nodeinfo_packet_injects_decoded_user_id():
|
||||
"""Ensure decoded user payloads inherit the inferred node id."""
|
||||
|
||||
from data.mesh_ingestor import interfaces
|
||||
|
||||
packet = {"decoded": {"user": {"shortName": "anon"}}, "from": 0x0A0B0C0D}
|
||||
|
||||
normalised = interfaces._normalise_nodeinfo_packet(packet)
|
||||
|
||||
assert normalised["id"] == "!0a0b0c0d"
|
||||
assert normalised["decoded"]["user"]["id"] == "!0a0b0c0d"
|
||||
|
||||
|
||||
def test_patch_protocol_nodeinfo_callback_without_portnum(monkeypatch):
|
||||
"""Protocols lacking PortNum constants should still be wrapped."""
|
||||
|
||||
from data.mesh_ingestor import interfaces
|
||||
|
||||
captured: list[dict] = []
|
||||
|
||||
def _original(iface, packet):
|
||||
captured.append(packet)
|
||||
iface.nodes = {"observed": packet}
|
||||
return packet.get("id")
|
||||
|
||||
class DummyProtocol:
|
||||
def __init__(self, name, factory, on_receive):
|
||||
self.name = name
|
||||
self.protobufFactory = factory
|
||||
self.onReceive = on_receive
|
||||
|
||||
module = types.SimpleNamespace(
|
||||
protocols={123: DummyProtocol("user", object, _original)},
|
||||
portnums_pb2=None,
|
||||
)
|
||||
|
||||
safe_callback = interfaces._build_safe_nodeinfo_callback(_original)
|
||||
interfaces._patch_protocol_nodeinfo_callback(module, _original, safe_callback)
|
||||
|
||||
handler = module.protocols[123].onReceive
|
||||
iface = types.SimpleNamespace(nodes={})
|
||||
|
||||
handler(iface, {"decoded": {"user": {"shortName": "anon"}}, "from": 0x01020304})
|
||||
|
||||
assert getattr(handler, "_potato_mesh_safe_wrapper", False)
|
||||
assert captured[0]["id"] == "!01020304"
|
||||
assert iface.nodes["observed"]["decoded"]["user"]["id"] == "!01020304"
|
||||
|
||||
|
||||
def test_store_packet_dict_ignores_non_text(mesh_module, monkeypatch):
|
||||
mesh = mesh_module
|
||||
captured = []
|
||||
|
||||
Reference in New Issue
Block a user