From 3cb27d33108a320a845f0cd6dbf429236b3d1938 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sun, 22 Mar 2026 14:34:04 -0700 Subject: [PATCH] feat: enhance trace processing and path handling in RepeaterDaemon and TraceHelper - Added local_identity parameter to RepeaterDaemon for improved trace path matching. - Refactored trace path handling in TraceHelper to support multi-byte hashes and structured hops. - Updated methods to ensure compatibility with new trace data formats and improved logging. - Enhanced tests to validate new trace processing logic and path handling. --- repeater/handler_helpers/trace.py | 236 +++++++++++++++++++----------- repeater/main.py | 22 ++- repeater/web/api_endpoints.py | 19 ++- tests/test_path_hash_protocol.py | 70 +++++++-- 4 files changed, 234 insertions(+), 113 deletions(-) diff --git a/repeater/handler_helpers/trace.py b/repeater/handler_helpers/trace.py index be06264..448f6b0 100644 --- a/repeater/handler_helpers/trace.py +++ b/repeater/handler_helpers/trace.py @@ -9,11 +9,12 @@ of packets through the mesh network. import asyncio import logging import time -from typing import Any, Dict +from typing import Any, Dict, List from pymc_core.hardware.signal_utils import snr_register_to_db from pymc_core.node.handlers.trace import TraceHandler from pymc_core.protocol.constants import MAX_PATH_SIZE, ROUTE_TYPE_DIRECT +from pymc_core.protocol.packet_utils import PathUtils logger = logging.getLogger("TraceHelper") @@ -21,17 +22,33 @@ logger = logging.getLogger("TraceHelper") class TraceHelper: """Helper class for processing trace packets in the repeater.""" - def __init__(self, local_hash: int, repeater_handler, packet_injector=None, log_fn=None): + def __init__( + self, + local_hash: int, + repeater_handler, + packet_injector=None, + log_fn=None, + local_identity=None, + ): """ Initialize the trace helper. Args: - local_hash: The local node's hash identifier + local_hash: The local node's 1-byte hash (first byte of pubkey); legacy repeater_handler: The RepeaterHandler instance packet_injector: Callable to inject new packets into the router for sending log_fn: Optional logging function for TraceHandler + local_identity: LocalIdentity (or any object with get_public_key()) for + multibyte TRACE path matching (Mesh.cpp isHashMatch with 1< bytes: + if width <= 0 or not self._pubkey_bytes: + return b"" + return self._pubkey_bytes[:width] + async def process_trace_packet(self, packet) -> None: """ Process an incoming trace packet. @@ -57,8 +79,8 @@ class TraceHelper: packet: The trace packet to process """ try: - # Only process direct route trace packets - if packet.get_route_type() != ROUTE_TYPE_DIRECT or packet.path_len >= MAX_PATH_SIZE: + # Only process direct route trace packets (SNR path uses len(packet.path)) + if packet.get_route_type() != ROUTE_TYPE_DIRECT or len(packet.path) >= MAX_PATH_SIZE: return # Parse the trace payload @@ -68,8 +90,12 @@ class TraceHelper: logger.warning(f"Invalid trace packet: {parsed_data.get('error', 'Unknown error')}") return - trace_path = parsed_data["trace_path"] - trace_path_len = len(trace_path) + trace_bytes: bytes = parsed_data.get("trace_path_bytes") or b"" + flags = parsed_data.get("flags", 0) + hash_width = PathUtils.trace_payload_hash_width(flags) + trace_hops: List[bytes] = parsed_data.get("trace_hops") or [] + num_hops = len(trace_hops) + legacy_trace_path = parsed_data.get("trace_path") or [] # Check if this is a response to one of our pings trace_tag = parsed_data.get("tag") @@ -82,9 +108,11 @@ class TraceHelper: ) return # wait for a valid response or let timeout handle it ping_info = self.pending_pings[trace_tag] - # Store response data + # Store response data (legacy path list + structured hops) ping_info["result"] = { - "path": trace_path, + "path": legacy_trace_path, + "trace_hops": trace_hops, + "trace_path_bytes": trace_bytes, "snr": packet.get_snr(), "rssi": rssi_val, "received_at": time.time(), @@ -95,11 +123,11 @@ class TraceHelper: # Record the trace packet for dashboard/statistics if self.repeater_handler: - packet_record = self._create_trace_record(packet, trace_path, parsed_data) + packet_record = self._create_trace_record(packet, parsed_data) self.repeater_handler.log_trace_record(packet_record) # Extract and log path SNRs and hashes - path_snrs, path_hashes = self._extract_path_info(packet, trace_path) + path_snrs, path_hashes = self._extract_path_info(packet, parsed_data) # Add packet metadata for logging parsed_data["snr"] = packet.get_snr() @@ -109,16 +137,18 @@ class TraceHelper: logger.info(f"{formatted_response}") logger.info(f"Path SNRs: [{', '.join(path_snrs)}], Hashes: [{', '.join(path_hashes)}]") - # Check if we should forward this trace packet - should_forward = self._should_forward_trace(packet, trace_path, trace_path_len) + should_forward = self._should_forward_trace(packet, trace_bytes, flags, hash_width) if should_forward: - await self._forward_trace_packet(packet, trace_path_len) + await self._forward_trace_packet(packet, num_hops) else: - # This is the final destination or can't forward - just log and record - self._log_no_forward_reason(packet, trace_path, trace_path_len) - # When trace completed (reached end of path), push PUSH_CODE_TRACE_DATA (0x89) to companions (firmware onTraceRecv) - if packet.path_len >= trace_path_len and self.on_trace_complete: + self._log_no_forward_reason(packet, trace_bytes, hash_width) + if ( + self.on_trace_complete + and self._is_trace_complete(packet, trace_bytes, hash_width) + and self.repeater_handler + and not self.repeater_handler.is_duplicate(packet) + ): try: await self.on_trace_complete(packet, parsed_data) except Exception as e: @@ -127,38 +157,56 @@ class TraceHelper: except Exception as e: logger.error(f"Error processing trace packet: {e}") - def _create_trace_record(self, packet, trace_path: list, parsed_data: dict) -> Dict[str, Any]: + def _is_trace_complete(self, packet, trace_bytes: bytes, hash_width: int) -> bool: + """Mirror Mesh.cpp: offset = path_len<= len(trace hash bytes).""" + if not trace_bytes or hash_width <= 0: + return False + snr_count = len(packet.path) + return snr_count * hash_width >= len(trace_bytes) + + def _create_trace_record(self, packet, parsed_data: dict) -> Dict[str, Any]: """ Create a packet record for trace packets to log to statistics. Args: packet: The trace packet - trace_path: The parsed trace path from the payload - parsed_data: The parsed trace data + parsed_data: Full parse result from TraceHandler Returns: A dictionary containing the packet record """ - # Format trace path for display - trace_path_bytes = [f"{h:02X}" for h in trace_path[:8]] - if len(trace_path) > 8: + trace_hops: List[bytes] = parsed_data.get("trace_hops") or [] + legacy = parsed_data.get("trace_path") or [] + + trace_path_bytes = [h.hex().upper() for h in trace_hops[:8]] + if len(trace_hops) > 8: trace_path_bytes.append("...") path_hash = "[" + ", ".join(trace_path_bytes) + "]" - # Extract SNR information from the path + # Extract SNR information from the path (one SNR byte per hop along trace) path_snrs = [] path_snr_details = [] - for i in range(packet.path_len): - if i < len(packet.path): - snr_val = packet.path[i] - snr_db = snr_register_to_db(snr_val) - path_snrs.append(f"{snr_val}({snr_db:.1f}dB)") + for i in range(len(packet.path)): + snr_val = packet.path[i] + snr_db = snr_register_to_db(snr_val) + path_snrs.append(f"{snr_val}({snr_db:.1f}dB)") - # Add detailed SNR info if we have the corresponding hash - if i < len(trace_path): - path_snr_details.append( - {"hash": f"{trace_path[i]:02X}", "snr_raw": snr_val, "snr_db": snr_db} - ) + if i < len(trace_hops): + path_snr_details.append( + { + "hash": trace_hops[i].hex().upper(), + "snr_raw": snr_val, + "snr_db": snr_db, + } + ) + elif i < len(legacy): + path_snr_details.append( + { + "hash": f"{legacy[i]:02X}", + "snr_raw": snr_val, + "snr_db": snr_db, + } + ) return { "timestamp": time.time(), @@ -195,69 +243,74 @@ class TraceHelper: "path_hash": path_hash, "src_hash": None, "dst_hash": None, - "original_path": [f"{h:02X}" for h in trace_path], + "original_path": [h.hex() for h in trace_hops], "forwarded_path": None, # Add trace-specific SNR path information "path_snrs": path_snrs, # ["58(14.5dB)", "19(4.8dB)"] - "path_snr_details": path_snr_details, # [{"hash": "29", "snr_raw": 58, "snr_db": 14.5}] + "path_snr_details": path_snr_details, "is_trace": True, "raw_packet": packet.write_to().hex() if hasattr(packet, "write_to") else None, } - def _extract_path_info(self, packet, trace_path: list) -> tuple: + def _extract_path_info(self, packet, parsed_data: dict) -> tuple: """ Extract SNR and hash information from the packet path. - Args: - packet: The trace packet - trace_path: The parsed trace path from the payload - Returns: - A tuple of (path_snrs, path_hashes) lists + A tuple of (path_snrs, path_hashes) display lists """ + trace_hops: List[bytes] = parsed_data.get("trace_hops") or [] path_snrs = [] path_hashes = [] - for i in range(packet.path_len): + for i in range(len(packet.path)): if i < len(packet.path): snr_val = packet.path[i] snr_db = snr_register_to_db(snr_val) path_snrs.append(f"{snr_val}({snr_db:.1f}dB)") - if i < len(trace_path): - path_hashes.append(f"0x{trace_path[i]:02x}") + if i < len(trace_hops): + path_hashes.append(f"0x{trace_hops[i].hex()}") return path_snrs, path_hashes - def _should_forward_trace(self, packet, trace_path: list, trace_path_len: int) -> bool: + def _should_forward_trace( + self, packet, trace_bytes: bytes, flags: int, hash_width: int + ) -> bool: """ - Determine if this node should forward the trace packet. - Uses the same logic as the original working implementation. - - Args: - packet: The trace packet - trace_path: The parsed trace path from the payload - trace_path_len: The length of the trace path - - Returns: - True if the packet should be forwarded, False otherwise + Mesh.cpp TRACE branch: forward if offset < len and next hash matches identity. + offset = pkt->path_len< packet.path_len - and trace_path[packet.path_len] == self.local_hash - and self.repeater_handler - and not self.repeater_handler.is_duplicate(packet) - ) + if not trace_bytes or hash_width <= 0: + return False + snr_count = len(packet.path) + byte_off = snr_count * hash_width + if byte_off >= len(trace_bytes): + return False - async def _forward_trace_packet(self, packet, trace_path_len: int) -> None: + next_hop = trace_bytes[byte_off : byte_off + hash_width] + if len(next_hop) != hash_width: + return False + + pubkey_pfx = self._pubkey_prefix(hash_width) + if len(pubkey_pfx) >= hash_width: + match = next_hop == pubkey_pfx[:hash_width] + else: + match = hash_width == 1 and next_hop[0] == (self.local_hash & 0xFF) + + if not match: + return False + if not self.repeater_handler: + return False + return not self.repeater_handler.is_duplicate(packet) + + async def _forward_trace_packet(self, packet, num_hops: int) -> None: """ Forward a trace packet by appending SNR and sending via injection. Args: packet: The trace packet to forward - trace_path_len: The length of the trace path + num_hops: Total hops in trace path (for logging) """ # Update the packet record to show it will be transmitted if self.repeater_handler and hasattr(self.repeater_handler, "recent_packets"): @@ -290,7 +343,8 @@ class TraceHelper: packet.path_len += 1 logger.info( - f"Forwarding trace, stored SNR {current_snr:.1f}dB at position {packet.path_len - 1}" + f"Forwarding trace ({num_hops} hop path), stored SNR {current_snr:.1f}dB " + f"at SNR index {packet.path_len - 1}" ) # Inject packet into router for proper routing and transmission @@ -299,26 +353,34 @@ class TraceHelper: else: logger.warning("No packet injector available - trace packet not forwarded") - def _log_no_forward_reason(self, packet, trace_path: list, trace_path_len: int) -> None: - """ - Log the reason why a trace packet was not forwarded. - - Args: - packet: The trace packet - trace_path: The parsed trace path from the payload - trace_path_len: The length of the trace path - """ - if packet.path_len >= trace_path_len: - logger.info("Trace completed (reached end of path)") - elif len(trace_path) <= packet.path_len: - logger.info("Path index out of bounds") - elif trace_path[packet.path_len] != self.local_hash: - expected_hash = ( - trace_path[packet.path_len] if packet.path_len < len(trace_path) else None - ) - logger.info(f"Not our turn (next hop: 0x{expected_hash:02x})") - elif self.repeater_handler and self.repeater_handler.is_duplicate(packet): + def _log_no_forward_reason(self, packet, trace_bytes: bytes, hash_width: int) -> None: + """Log the reason why this node did not forward the trace.""" + if self.repeater_handler and self.repeater_handler.is_duplicate(packet): logger.info("Duplicate packet, ignoring") + return + + snr_count = len(packet.path) + if not trace_bytes or hash_width <= 0: + logger.info("Trace: empty path or invalid hash width") + return + + if snr_count * hash_width >= len(trace_bytes): + logger.info("Trace completed (reached end of path)") + return + + byte_off = snr_count * hash_width + next_hop = trace_bytes[byte_off : byte_off + hash_width] + pubkey_pfx = self._pubkey_prefix(hash_width) + if len(next_hop) == hash_width and len(pubkey_pfx) >= hash_width: + if next_hop != pubkey_pfx[:hash_width]: + logger.info(f"Not our turn (next hop: 0x{next_hop.hex()})") + return + elif hash_width == 1 and next_hop: + if (next_hop[0] & 0xFF) != (self.local_hash & 0xFF): + logger.info(f"Not our turn (next hop: 0x{next_hop.hex()})") + return + + logger.info("Trace: not forwarded (internal)") def register_ping(self, tag: int, target_hash: int) -> asyncio.Event: """Register a ping request and return an event to wait on. diff --git a/repeater/main.py b/repeater/main.py index a1a3650..de97609 100644 --- a/repeater/main.py +++ b/repeater/main.py @@ -170,6 +170,7 @@ class RepeaterDaemon: repeater_handler=self.repeater_handler, packet_injector=self.router.inject_packet, log_fn=logger.info, + local_identity=self.local_identity, ) logger.info("Trace processing helper initialized") @@ -745,21 +746,28 @@ class RepeaterDaemon: async def _on_trace_complete_for_companions(self, packet, parsed_data) -> None: """Trace completed at this node: push PUSH_CODE_TRACE_DATA (0x89) to companion clients (firmware onTraceRecv).""" - path_len = len(parsed_data.get("trace_path", [])) - if path_len == 0: + path_hashes = parsed_data.get("trace_path_bytes") or b"" + if not path_hashes: return - path_hashes = bytes(parsed_data["trace_path"]) flags = parsed_data.get("flags", 0) + path_sz = flags & 0x03 + hash_len = len(path_hashes) + expected_snr_len = hash_len >> path_sz + if expected_snr_len <= 0: + return tag = parsed_data.get("tag", 0) auth_code = parsed_data.get("auth_code", 0) - # path_snrs: exactly path_len bytes = (path_len-1) from forwarding hops + 1 (our receive SNR) snr_scaled = max(-128, min(127, int(round(packet.get_snr() * 4)))) snr_byte = snr_scaled if snr_scaled >= 0 else (256 + snr_scaled) - path_snrs = bytes(packet.path)[: path_len - 1] + bytes([snr_byte]) + # Firmware: memcpy path_snrs from pkt->path (length hash_len >> path_sz), then final SNR byte + raw = bytes(packet.path)[:expected_snr_len] + if len(raw) < expected_snr_len: + raw = raw + b"\x00" * (expected_snr_len - len(raw)) + path_snrs = raw for fs in getattr(self, "companion_frame_servers", []): try: - fs.push_trace_data( - path_len, flags, tag, auth_code, path_hashes, path_snrs, snr_byte + await fs.push_trace_data_async( + hash_len, flags, tag, auth_code, path_hashes, path_snrs, snr_byte ) except Exception as e: logger.debug("Push trace data to companion: %s", e) diff --git a/repeater/web/api_endpoints.py b/repeater/web/api_endpoints.py index df0ef92..e9ec034 100644 --- a/repeater/web/api_endpoints.py +++ b/repeater/web/api_endpoints.py @@ -2197,17 +2197,20 @@ class APIEndpoints: # Calculate round-trip time rtt_ms = (result["received_at"] - ping_info["sent_at"]) * 1000 - # result["path"] is a flat byte list from _parse_trace_payload. - # For multi-byte hash mode, group into byte_count-sized chunks - # before formatting (e.g. [0xb5, 0xd8] → ["0xb5d8"] for 2-byte mode). - raw_path = result["path"] - if byte_count > 1: + # Prefer structured hops from TraceHelper; else legacy flat list. + if result.get("trace_hops"): grouped_path = [ - int.from_bytes(bytes(raw_path[i:i + byte_count]), "big") - for i in range(0, len(raw_path), byte_count) + int.from_bytes(bytes(h), "big") for h in result["trace_hops"] ] else: - grouped_path = raw_path + raw_path = result["path"] + if byte_count > 1: + grouped_path = [ + int.from_bytes(bytes(raw_path[i : i + byte_count]), "big") + for i in range(0, len(raw_path), byte_count) + ] + else: + grouped_path = raw_path return self._success( { diff --git a/tests/test_path_hash_protocol.py b/tests/test_path_hash_protocol.py index 5c7f3dd..d650ba0 100644 --- a/tests/test_path_hash_protocol.py +++ b/tests/test_path_hash_protocol.py @@ -684,27 +684,25 @@ class TestTracePayloadParsing: assert result["trace_path"] == [0xAA, 0xBB, 0xCC] assert result["path_length"] == 3 - def test_parse_trace_with_multibyte_path_is_flat(self): - """ - Trace path is raw bytes in payload — _parse_trace_payload returns it flat. - Multi-byte grouping is NOT done at the trace parser level. - """ + def test_parse_trace_with_multibyte_path_grouped_by_flags(self): + """flags=1 → 2-byte hashes per hop (Mesh.cpp 1<