feat: enhance trace processing and path handling in RepeaterDaemon and TraceHelper

- Added local_identity parameter to RepeaterDaemon for improved trace path matching.
- Refactored trace path handling in TraceHelper to support multi-byte hashes and structured hops.
- Updated methods to ensure compatibility with new trace data formats and improved logging.
- Enhanced tests to validate new trace processing logic and path handling.
This commit is contained in:
agessaman
2026-03-22 14:34:04 -07:00
parent 2e25467c5d
commit 3cb27d3310
4 changed files with 234 additions and 113 deletions

View File

@@ -9,11 +9,12 @@ of packets through the mesh network.
import asyncio
import logging
import time
from typing import Any, Dict
from typing import Any, Dict, List
from pymc_core.hardware.signal_utils import snr_register_to_db
from pymc_core.node.handlers.trace import TraceHandler
from pymc_core.protocol.constants import MAX_PATH_SIZE, ROUTE_TYPE_DIRECT
from pymc_core.protocol.packet_utils import PathUtils
logger = logging.getLogger("TraceHelper")
@@ -21,17 +22,33 @@ logger = logging.getLogger("TraceHelper")
class TraceHelper:
"""Helper class for processing trace packets in the repeater."""
def __init__(self, local_hash: int, repeater_handler, packet_injector=None, log_fn=None):
def __init__(
self,
local_hash: int,
repeater_handler,
packet_injector=None,
log_fn=None,
local_identity=None,
):
"""
Initialize the trace helper.
Args:
local_hash: The local node's hash identifier
local_hash: The local node's 1-byte hash (first byte of pubkey); legacy
repeater_handler: The RepeaterHandler instance
packet_injector: Callable to inject new packets into the router for sending
log_fn: Optional logging function for TraceHandler
local_identity: LocalIdentity (or any object with get_public_key()) for
multibyte TRACE path matching (Mesh.cpp isHashMatch with 1<<path_sz bytes)
"""
self.local_hash = local_hash
self.local_identity = local_identity
self._pubkey_bytes: bytes = b""
if local_identity is not None and hasattr(local_identity, "get_public_key"):
try:
self._pubkey_bytes = bytes(local_identity.get_public_key())
except Exception:
self._pubkey_bytes = b""
self.repeater_handler = repeater_handler
self.packet_injector = packet_injector # Function to inject packets into router
@@ -46,6 +63,11 @@ class TraceHelper:
# Create TraceHandler internally as a parsing utility
self.trace_handler = TraceHandler(log_fn=log_fn or logger.info)
def _pubkey_prefix(self, width: int) -> bytes:
if width <= 0 or not self._pubkey_bytes:
return b""
return self._pubkey_bytes[:width]
async def process_trace_packet(self, packet) -> None:
"""
Process an incoming trace packet.
@@ -57,8 +79,8 @@ class TraceHelper:
packet: The trace packet to process
"""
try:
# Only process direct route trace packets
if packet.get_route_type() != ROUTE_TYPE_DIRECT or packet.path_len >= MAX_PATH_SIZE:
# Only process direct route trace packets (SNR path uses len(packet.path))
if packet.get_route_type() != ROUTE_TYPE_DIRECT or len(packet.path) >= MAX_PATH_SIZE:
return
# Parse the trace payload
@@ -68,8 +90,12 @@ class TraceHelper:
logger.warning(f"Invalid trace packet: {parsed_data.get('error', 'Unknown error')}")
return
trace_path = parsed_data["trace_path"]
trace_path_len = len(trace_path)
trace_bytes: bytes = parsed_data.get("trace_path_bytes") or b""
flags = parsed_data.get("flags", 0)
hash_width = PathUtils.trace_payload_hash_width(flags)
trace_hops: List[bytes] = parsed_data.get("trace_hops") or []
num_hops = len(trace_hops)
legacy_trace_path = parsed_data.get("trace_path") or []
# Check if this is a response to one of our pings
trace_tag = parsed_data.get("tag")
@@ -82,9 +108,11 @@ class TraceHelper:
)
return # wait for a valid response or let timeout handle it
ping_info = self.pending_pings[trace_tag]
# Store response data
# Store response data (legacy path list + structured hops)
ping_info["result"] = {
"path": trace_path,
"path": legacy_trace_path,
"trace_hops": trace_hops,
"trace_path_bytes": trace_bytes,
"snr": packet.get_snr(),
"rssi": rssi_val,
"received_at": time.time(),
@@ -95,11 +123,11 @@ class TraceHelper:
# Record the trace packet for dashboard/statistics
if self.repeater_handler:
packet_record = self._create_trace_record(packet, trace_path, parsed_data)
packet_record = self._create_trace_record(packet, parsed_data)
self.repeater_handler.log_trace_record(packet_record)
# Extract and log path SNRs and hashes
path_snrs, path_hashes = self._extract_path_info(packet, trace_path)
path_snrs, path_hashes = self._extract_path_info(packet, parsed_data)
# Add packet metadata for logging
parsed_data["snr"] = packet.get_snr()
@@ -109,16 +137,18 @@ class TraceHelper:
logger.info(f"{formatted_response}")
logger.info(f"Path SNRs: [{', '.join(path_snrs)}], Hashes: [{', '.join(path_hashes)}]")
# Check if we should forward this trace packet
should_forward = self._should_forward_trace(packet, trace_path, trace_path_len)
should_forward = self._should_forward_trace(packet, trace_bytes, flags, hash_width)
if should_forward:
await self._forward_trace_packet(packet, trace_path_len)
await self._forward_trace_packet(packet, num_hops)
else:
# This is the final destination or can't forward - just log and record
self._log_no_forward_reason(packet, trace_path, trace_path_len)
# When trace completed (reached end of path), push PUSH_CODE_TRACE_DATA (0x89) to companions (firmware onTraceRecv)
if packet.path_len >= trace_path_len and self.on_trace_complete:
self._log_no_forward_reason(packet, trace_bytes, hash_width)
if (
self.on_trace_complete
and self._is_trace_complete(packet, trace_bytes, hash_width)
and self.repeater_handler
and not self.repeater_handler.is_duplicate(packet)
):
try:
await self.on_trace_complete(packet, parsed_data)
except Exception as e:
@@ -127,38 +157,56 @@ class TraceHelper:
except Exception as e:
logger.error(f"Error processing trace packet: {e}")
def _create_trace_record(self, packet, trace_path: list, parsed_data: dict) -> Dict[str, Any]:
def _is_trace_complete(self, packet, trace_bytes: bytes, hash_width: int) -> bool:
"""Mirror Mesh.cpp: offset = path_len<<path_sz >= len(trace hash bytes)."""
if not trace_bytes or hash_width <= 0:
return False
snr_count = len(packet.path)
return snr_count * hash_width >= len(trace_bytes)
def _create_trace_record(self, packet, parsed_data: dict) -> Dict[str, Any]:
"""
Create a packet record for trace packets to log to statistics.
Args:
packet: The trace packet
trace_path: The parsed trace path from the payload
parsed_data: The parsed trace data
parsed_data: Full parse result from TraceHandler
Returns:
A dictionary containing the packet record
"""
# Format trace path for display
trace_path_bytes = [f"{h:02X}" for h in trace_path[:8]]
if len(trace_path) > 8:
trace_hops: List[bytes] = parsed_data.get("trace_hops") or []
legacy = parsed_data.get("trace_path") or []
trace_path_bytes = [h.hex().upper() for h in trace_hops[:8]]
if len(trace_hops) > 8:
trace_path_bytes.append("...")
path_hash = "[" + ", ".join(trace_path_bytes) + "]"
# Extract SNR information from the path
# Extract SNR information from the path (one SNR byte per hop along trace)
path_snrs = []
path_snr_details = []
for i in range(packet.path_len):
if i < len(packet.path):
snr_val = packet.path[i]
snr_db = snr_register_to_db(snr_val)
path_snrs.append(f"{snr_val}({snr_db:.1f}dB)")
for i in range(len(packet.path)):
snr_val = packet.path[i]
snr_db = snr_register_to_db(snr_val)
path_snrs.append(f"{snr_val}({snr_db:.1f}dB)")
# Add detailed SNR info if we have the corresponding hash
if i < len(trace_path):
path_snr_details.append(
{"hash": f"{trace_path[i]:02X}", "snr_raw": snr_val, "snr_db": snr_db}
)
if i < len(trace_hops):
path_snr_details.append(
{
"hash": trace_hops[i].hex().upper(),
"snr_raw": snr_val,
"snr_db": snr_db,
}
)
elif i < len(legacy):
path_snr_details.append(
{
"hash": f"{legacy[i]:02X}",
"snr_raw": snr_val,
"snr_db": snr_db,
}
)
return {
"timestamp": time.time(),
@@ -195,69 +243,74 @@ class TraceHelper:
"path_hash": path_hash,
"src_hash": None,
"dst_hash": None,
"original_path": [f"{h:02X}" for h in trace_path],
"original_path": [h.hex() for h in trace_hops],
"forwarded_path": None,
# Add trace-specific SNR path information
"path_snrs": path_snrs, # ["58(14.5dB)", "19(4.8dB)"]
"path_snr_details": path_snr_details, # [{"hash": "29", "snr_raw": 58, "snr_db": 14.5}]
"path_snr_details": path_snr_details,
"is_trace": True,
"raw_packet": packet.write_to().hex() if hasattr(packet, "write_to") else None,
}
def _extract_path_info(self, packet, trace_path: list) -> tuple:
def _extract_path_info(self, packet, parsed_data: dict) -> tuple:
"""
Extract SNR and hash information from the packet path.
Args:
packet: The trace packet
trace_path: The parsed trace path from the payload
Returns:
A tuple of (path_snrs, path_hashes) lists
A tuple of (path_snrs, path_hashes) display lists
"""
trace_hops: List[bytes] = parsed_data.get("trace_hops") or []
path_snrs = []
path_hashes = []
for i in range(packet.path_len):
for i in range(len(packet.path)):
if i < len(packet.path):
snr_val = packet.path[i]
snr_db = snr_register_to_db(snr_val)
path_snrs.append(f"{snr_val}({snr_db:.1f}dB)")
if i < len(trace_path):
path_hashes.append(f"0x{trace_path[i]:02x}")
if i < len(trace_hops):
path_hashes.append(f"0x{trace_hops[i].hex()}")
return path_snrs, path_hashes
def _should_forward_trace(self, packet, trace_path: list, trace_path_len: int) -> bool:
def _should_forward_trace(
self, packet, trace_bytes: bytes, flags: int, hash_width: int
) -> bool:
"""
Determine if this node should forward the trace packet.
Uses the same logic as the original working implementation.
Args:
packet: The trace packet
trace_path: The parsed trace path from the payload
trace_path_len: The length of the trace path
Returns:
True if the packet should be forwarded, False otherwise
Mesh.cpp TRACE branch: forward if offset < len and next hash matches identity.
offset = pkt->path_len<<path_sz uses SNR count in packet.path (len(packet.path)).
"""
# Use the exact logic from the original working code
return (
packet.path_len < trace_path_len
and len(trace_path) > packet.path_len
and trace_path[packet.path_len] == self.local_hash
and self.repeater_handler
and not self.repeater_handler.is_duplicate(packet)
)
if not trace_bytes or hash_width <= 0:
return False
snr_count = len(packet.path)
byte_off = snr_count * hash_width
if byte_off >= len(trace_bytes):
return False
async def _forward_trace_packet(self, packet, trace_path_len: int) -> None:
next_hop = trace_bytes[byte_off : byte_off + hash_width]
if len(next_hop) != hash_width:
return False
pubkey_pfx = self._pubkey_prefix(hash_width)
if len(pubkey_pfx) >= hash_width:
match = next_hop == pubkey_pfx[:hash_width]
else:
match = hash_width == 1 and next_hop[0] == (self.local_hash & 0xFF)
if not match:
return False
if not self.repeater_handler:
return False
return not self.repeater_handler.is_duplicate(packet)
async def _forward_trace_packet(self, packet, num_hops: int) -> None:
"""
Forward a trace packet by appending SNR and sending via injection.
Args:
packet: The trace packet to forward
trace_path_len: The length of the trace path
num_hops: Total hops in trace path (for logging)
"""
# Update the packet record to show it will be transmitted
if self.repeater_handler and hasattr(self.repeater_handler, "recent_packets"):
@@ -290,7 +343,8 @@ class TraceHelper:
packet.path_len += 1
logger.info(
f"Forwarding trace, stored SNR {current_snr:.1f}dB at position {packet.path_len - 1}"
f"Forwarding trace ({num_hops} hop path), stored SNR {current_snr:.1f}dB "
f"at SNR index {packet.path_len - 1}"
)
# Inject packet into router for proper routing and transmission
@@ -299,26 +353,34 @@ class TraceHelper:
else:
logger.warning("No packet injector available - trace packet not forwarded")
def _log_no_forward_reason(self, packet, trace_path: list, trace_path_len: int) -> None:
"""
Log the reason why a trace packet was not forwarded.
Args:
packet: The trace packet
trace_path: The parsed trace path from the payload
trace_path_len: The length of the trace path
"""
if packet.path_len >= trace_path_len:
logger.info("Trace completed (reached end of path)")
elif len(trace_path) <= packet.path_len:
logger.info("Path index out of bounds")
elif trace_path[packet.path_len] != self.local_hash:
expected_hash = (
trace_path[packet.path_len] if packet.path_len < len(trace_path) else None
)
logger.info(f"Not our turn (next hop: 0x{expected_hash:02x})")
elif self.repeater_handler and self.repeater_handler.is_duplicate(packet):
def _log_no_forward_reason(self, packet, trace_bytes: bytes, hash_width: int) -> None:
"""Log the reason why this node did not forward the trace."""
if self.repeater_handler and self.repeater_handler.is_duplicate(packet):
logger.info("Duplicate packet, ignoring")
return
snr_count = len(packet.path)
if not trace_bytes or hash_width <= 0:
logger.info("Trace: empty path or invalid hash width")
return
if snr_count * hash_width >= len(trace_bytes):
logger.info("Trace completed (reached end of path)")
return
byte_off = snr_count * hash_width
next_hop = trace_bytes[byte_off : byte_off + hash_width]
pubkey_pfx = self._pubkey_prefix(hash_width)
if len(next_hop) == hash_width and len(pubkey_pfx) >= hash_width:
if next_hop != pubkey_pfx[:hash_width]:
logger.info(f"Not our turn (next hop: 0x{next_hop.hex()})")
return
elif hash_width == 1 and next_hop:
if (next_hop[0] & 0xFF) != (self.local_hash & 0xFF):
logger.info(f"Not our turn (next hop: 0x{next_hop.hex()})")
return
logger.info("Trace: not forwarded (internal)")
def register_ping(self, tag: int, target_hash: int) -> asyncio.Event:
"""Register a ping request and return an event to wait on.

View File

@@ -170,6 +170,7 @@ class RepeaterDaemon:
repeater_handler=self.repeater_handler,
packet_injector=self.router.inject_packet,
log_fn=logger.info,
local_identity=self.local_identity,
)
logger.info("Trace processing helper initialized")
@@ -745,21 +746,28 @@ class RepeaterDaemon:
async def _on_trace_complete_for_companions(self, packet, parsed_data) -> None:
"""Trace completed at this node: push PUSH_CODE_TRACE_DATA (0x89) to companion clients (firmware onTraceRecv)."""
path_len = len(parsed_data.get("trace_path", []))
if path_len == 0:
path_hashes = parsed_data.get("trace_path_bytes") or b""
if not path_hashes:
return
path_hashes = bytes(parsed_data["trace_path"])
flags = parsed_data.get("flags", 0)
path_sz = flags & 0x03
hash_len = len(path_hashes)
expected_snr_len = hash_len >> path_sz
if expected_snr_len <= 0:
return
tag = parsed_data.get("tag", 0)
auth_code = parsed_data.get("auth_code", 0)
# path_snrs: exactly path_len bytes = (path_len-1) from forwarding hops + 1 (our receive SNR)
snr_scaled = max(-128, min(127, int(round(packet.get_snr() * 4))))
snr_byte = snr_scaled if snr_scaled >= 0 else (256 + snr_scaled)
path_snrs = bytes(packet.path)[: path_len - 1] + bytes([snr_byte])
# Firmware: memcpy path_snrs from pkt->path (length hash_len >> path_sz), then final SNR byte
raw = bytes(packet.path)[:expected_snr_len]
if len(raw) < expected_snr_len:
raw = raw + b"\x00" * (expected_snr_len - len(raw))
path_snrs = raw
for fs in getattr(self, "companion_frame_servers", []):
try:
fs.push_trace_data(
path_len, flags, tag, auth_code, path_hashes, path_snrs, snr_byte
await fs.push_trace_data_async(
hash_len, flags, tag, auth_code, path_hashes, path_snrs, snr_byte
)
except Exception as e:
logger.debug("Push trace data to companion: %s", e)

View File

@@ -2197,17 +2197,20 @@ class APIEndpoints:
# Calculate round-trip time
rtt_ms = (result["received_at"] - ping_info["sent_at"]) * 1000
# result["path"] is a flat byte list from _parse_trace_payload.
# For multi-byte hash mode, group into byte_count-sized chunks
# before formatting (e.g. [0xb5, 0xd8] → ["0xb5d8"] for 2-byte mode).
raw_path = result["path"]
if byte_count > 1:
# Prefer structured hops from TraceHelper; else legacy flat list.
if result.get("trace_hops"):
grouped_path = [
int.from_bytes(bytes(raw_path[i:i + byte_count]), "big")
for i in range(0, len(raw_path), byte_count)
int.from_bytes(bytes(h), "big") for h in result["trace_hops"]
]
else:
grouped_path = raw_path
raw_path = result["path"]
if byte_count > 1:
grouped_path = [
int.from_bytes(bytes(raw_path[i : i + byte_count]), "big")
for i in range(0, len(raw_path), byte_count)
]
else:
grouped_path = raw_path
return self._success(
{

View File

@@ -684,27 +684,25 @@ class TestTracePayloadParsing:
assert result["trace_path"] == [0xAA, 0xBB, 0xCC]
assert result["path_length"] == 3
def test_parse_trace_with_multibyte_path_is_flat(self):
"""
Trace path is raw bytes in payload — _parse_trace_payload returns it flat.
Multi-byte grouping is NOT done at the trace parser level.
"""
def test_parse_trace_with_multibyte_path_grouped_by_flags(self):
"""flags=1 → 2-byte hashes per hop (Mesh.cpp 1<<path_sz)."""
th = self._make_trace_handler()
# 2 hops of 2-byte hashes → 4 flat bytes in the payload
path = bytes([0xAA, 0xBB, 0xCC, 0xDD])
payload = struct.pack("<IIB", 10, 20, 0) + path
payload = struct.pack("<IIB", 10, 20, 1) + path
result = th._parse_trace_payload(payload)
assert result["valid"]
# Returns flat list, not grouped
assert result["trace_path"] == [0xAA, 0xBB, 0xCC, 0xDD]
assert result["path_hash_width"] == 2
assert result["trace_hops"] == [b"\xaa\xbb", b"\xcc\xdd"]
assert result["trace_path"] == [0xAA, 0xCC] # legacy: first byte per hop
assert result["path_length"] == 4
assert result["path_hop_count"] == 2
def test_parse_from_real_packet(self):
"""Create a trace with PacketBuilder, serialize, deserialize, then parse."""
th = self._make_trace_handler()
trace_path = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66]
pkt = PacketBuilder.create_trace(
tag=100, auth_code=200, flags=7, path=trace_path
tag=100, auth_code=200, flags=0, path=trace_path
)
wire = pkt.write_to()
pkt2 = Packet()
@@ -713,8 +711,10 @@ class TestTracePayloadParsing:
assert result["valid"]
assert result["tag"] == 100
assert result["auth_code"] == 200
assert result["flags"] == 7
assert result["flags"] == 0
assert result["trace_path"] == trace_path
assert result["path_hop_count"] == 6
assert result["trace_path_bytes"] == bytes(trace_path)
def test_parse_too_short_payload(self):
th = self._make_trace_handler()
@@ -723,6 +723,54 @@ class TestTracePayloadParsing:
assert "too short" in result["error"].lower()
class TestTraceHelperMultibyte:
"""TraceHelper._should_forward_trace with 2-byte TRACE payload hashes."""
def test_should_forward_when_next_hop_matches_pubkey_prefix(self):
from repeater.handler_helpers.trace import TraceHelper
from pymc_core.protocol import LocalIdentity
identity = LocalIdentity()
pub = bytes(identity.get_public_key())
rh = MagicMock()
rh.is_duplicate = MagicMock(return_value=False)
th = TraceHelper(
local_hash=pub[0],
repeater_handler=rh,
local_identity=identity,
)
trace_bytes = pub[:2] + b"\x11\x22"
flags = 1
hash_width = 2
pkt = Packet()
pkt.path = bytearray()
pkt.path_len = 0
assert th._should_forward_trace(pkt, trace_bytes, flags, hash_width)
def test_should_not_forward_when_next_hop_mismatch(self):
from repeater.handler_helpers.trace import TraceHelper
from pymc_core.protocol import LocalIdentity
identity = LocalIdentity()
pub = bytes(identity.get_public_key())
rh = MagicMock()
rh.is_duplicate = MagicMock(return_value=False)
th = TraceHelper(
local_hash=pub[0],
repeater_handler=rh,
local_identity=identity,
)
trace_bytes = bytes([pub[0] ^ 0xFF, pub[1] ^ 0xFF])
flags = 1
hash_width = 2
pkt = Packet()
pkt.path = bytearray()
pkt.path_len = 0
assert not th._should_forward_trace(pkt, trace_bytes, flags, hash_width)
# ===================================================================
# 8. Wire-level verification — manual byte inspection
# ===================================================================