mirror of
https://github.com/rightup/pyMC_Repeater.git
synced 2026-03-28 17:43:06 +01:00
feat: enhance trace processing and path handling in RepeaterDaemon and TraceHelper
- Added local_identity parameter to RepeaterDaemon for improved trace path matching. - Refactored trace path handling in TraceHelper to support multi-byte hashes and structured hops. - Updated methods to ensure compatibility with new trace data formats and improved logging. - Enhanced tests to validate new trace processing logic and path handling.
This commit is contained in:
@@ -9,11 +9,12 @@ of packets through the mesh network.
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pymc_core.hardware.signal_utils import snr_register_to_db
|
||||
from pymc_core.node.handlers.trace import TraceHandler
|
||||
from pymc_core.protocol.constants import MAX_PATH_SIZE, ROUTE_TYPE_DIRECT
|
||||
from pymc_core.protocol.packet_utils import PathUtils
|
||||
|
||||
logger = logging.getLogger("TraceHelper")
|
||||
|
||||
@@ -21,17 +22,33 @@ logger = logging.getLogger("TraceHelper")
|
||||
class TraceHelper:
|
||||
"""Helper class for processing trace packets in the repeater."""
|
||||
|
||||
def __init__(self, local_hash: int, repeater_handler, packet_injector=None, log_fn=None):
|
||||
def __init__(
|
||||
self,
|
||||
local_hash: int,
|
||||
repeater_handler,
|
||||
packet_injector=None,
|
||||
log_fn=None,
|
||||
local_identity=None,
|
||||
):
|
||||
"""
|
||||
Initialize the trace helper.
|
||||
|
||||
Args:
|
||||
local_hash: The local node's hash identifier
|
||||
local_hash: The local node's 1-byte hash (first byte of pubkey); legacy
|
||||
repeater_handler: The RepeaterHandler instance
|
||||
packet_injector: Callable to inject new packets into the router for sending
|
||||
log_fn: Optional logging function for TraceHandler
|
||||
local_identity: LocalIdentity (or any object with get_public_key()) for
|
||||
multibyte TRACE path matching (Mesh.cpp isHashMatch with 1<<path_sz bytes)
|
||||
"""
|
||||
self.local_hash = local_hash
|
||||
self.local_identity = local_identity
|
||||
self._pubkey_bytes: bytes = b""
|
||||
if local_identity is not None and hasattr(local_identity, "get_public_key"):
|
||||
try:
|
||||
self._pubkey_bytes = bytes(local_identity.get_public_key())
|
||||
except Exception:
|
||||
self._pubkey_bytes = b""
|
||||
self.repeater_handler = repeater_handler
|
||||
self.packet_injector = packet_injector # Function to inject packets into router
|
||||
|
||||
@@ -46,6 +63,11 @@ class TraceHelper:
|
||||
# Create TraceHandler internally as a parsing utility
|
||||
self.trace_handler = TraceHandler(log_fn=log_fn or logger.info)
|
||||
|
||||
def _pubkey_prefix(self, width: int) -> bytes:
|
||||
if width <= 0 or not self._pubkey_bytes:
|
||||
return b""
|
||||
return self._pubkey_bytes[:width]
|
||||
|
||||
async def process_trace_packet(self, packet) -> None:
|
||||
"""
|
||||
Process an incoming trace packet.
|
||||
@@ -57,8 +79,8 @@ class TraceHelper:
|
||||
packet: The trace packet to process
|
||||
"""
|
||||
try:
|
||||
# Only process direct route trace packets
|
||||
if packet.get_route_type() != ROUTE_TYPE_DIRECT or packet.path_len >= MAX_PATH_SIZE:
|
||||
# Only process direct route trace packets (SNR path uses len(packet.path))
|
||||
if packet.get_route_type() != ROUTE_TYPE_DIRECT or len(packet.path) >= MAX_PATH_SIZE:
|
||||
return
|
||||
|
||||
# Parse the trace payload
|
||||
@@ -68,8 +90,12 @@ class TraceHelper:
|
||||
logger.warning(f"Invalid trace packet: {parsed_data.get('error', 'Unknown error')}")
|
||||
return
|
||||
|
||||
trace_path = parsed_data["trace_path"]
|
||||
trace_path_len = len(trace_path)
|
||||
trace_bytes: bytes = parsed_data.get("trace_path_bytes") or b""
|
||||
flags = parsed_data.get("flags", 0)
|
||||
hash_width = PathUtils.trace_payload_hash_width(flags)
|
||||
trace_hops: List[bytes] = parsed_data.get("trace_hops") or []
|
||||
num_hops = len(trace_hops)
|
||||
legacy_trace_path = parsed_data.get("trace_path") or []
|
||||
|
||||
# Check if this is a response to one of our pings
|
||||
trace_tag = parsed_data.get("tag")
|
||||
@@ -82,9 +108,11 @@ class TraceHelper:
|
||||
)
|
||||
return # wait for a valid response or let timeout handle it
|
||||
ping_info = self.pending_pings[trace_tag]
|
||||
# Store response data
|
||||
# Store response data (legacy path list + structured hops)
|
||||
ping_info["result"] = {
|
||||
"path": trace_path,
|
||||
"path": legacy_trace_path,
|
||||
"trace_hops": trace_hops,
|
||||
"trace_path_bytes": trace_bytes,
|
||||
"snr": packet.get_snr(),
|
||||
"rssi": rssi_val,
|
||||
"received_at": time.time(),
|
||||
@@ -95,11 +123,11 @@ class TraceHelper:
|
||||
|
||||
# Record the trace packet for dashboard/statistics
|
||||
if self.repeater_handler:
|
||||
packet_record = self._create_trace_record(packet, trace_path, parsed_data)
|
||||
packet_record = self._create_trace_record(packet, parsed_data)
|
||||
self.repeater_handler.log_trace_record(packet_record)
|
||||
|
||||
# Extract and log path SNRs and hashes
|
||||
path_snrs, path_hashes = self._extract_path_info(packet, trace_path)
|
||||
path_snrs, path_hashes = self._extract_path_info(packet, parsed_data)
|
||||
|
||||
# Add packet metadata for logging
|
||||
parsed_data["snr"] = packet.get_snr()
|
||||
@@ -109,16 +137,18 @@ class TraceHelper:
|
||||
logger.info(f"{formatted_response}")
|
||||
logger.info(f"Path SNRs: [{', '.join(path_snrs)}], Hashes: [{', '.join(path_hashes)}]")
|
||||
|
||||
# Check if we should forward this trace packet
|
||||
should_forward = self._should_forward_trace(packet, trace_path, trace_path_len)
|
||||
should_forward = self._should_forward_trace(packet, trace_bytes, flags, hash_width)
|
||||
|
||||
if should_forward:
|
||||
await self._forward_trace_packet(packet, trace_path_len)
|
||||
await self._forward_trace_packet(packet, num_hops)
|
||||
else:
|
||||
# This is the final destination or can't forward - just log and record
|
||||
self._log_no_forward_reason(packet, trace_path, trace_path_len)
|
||||
# When trace completed (reached end of path), push PUSH_CODE_TRACE_DATA (0x89) to companions (firmware onTraceRecv)
|
||||
if packet.path_len >= trace_path_len and self.on_trace_complete:
|
||||
self._log_no_forward_reason(packet, trace_bytes, hash_width)
|
||||
if (
|
||||
self.on_trace_complete
|
||||
and self._is_trace_complete(packet, trace_bytes, hash_width)
|
||||
and self.repeater_handler
|
||||
and not self.repeater_handler.is_duplicate(packet)
|
||||
):
|
||||
try:
|
||||
await self.on_trace_complete(packet, parsed_data)
|
||||
except Exception as e:
|
||||
@@ -127,38 +157,56 @@ class TraceHelper:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing trace packet: {e}")
|
||||
|
||||
def _create_trace_record(self, packet, trace_path: list, parsed_data: dict) -> Dict[str, Any]:
|
||||
def _is_trace_complete(self, packet, trace_bytes: bytes, hash_width: int) -> bool:
|
||||
"""Mirror Mesh.cpp: offset = path_len<<path_sz >= len(trace hash bytes)."""
|
||||
if not trace_bytes or hash_width <= 0:
|
||||
return False
|
||||
snr_count = len(packet.path)
|
||||
return snr_count * hash_width >= len(trace_bytes)
|
||||
|
||||
def _create_trace_record(self, packet, parsed_data: dict) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a packet record for trace packets to log to statistics.
|
||||
|
||||
Args:
|
||||
packet: The trace packet
|
||||
trace_path: The parsed trace path from the payload
|
||||
parsed_data: The parsed trace data
|
||||
parsed_data: Full parse result from TraceHandler
|
||||
|
||||
Returns:
|
||||
A dictionary containing the packet record
|
||||
"""
|
||||
# Format trace path for display
|
||||
trace_path_bytes = [f"{h:02X}" for h in trace_path[:8]]
|
||||
if len(trace_path) > 8:
|
||||
trace_hops: List[bytes] = parsed_data.get("trace_hops") or []
|
||||
legacy = parsed_data.get("trace_path") or []
|
||||
|
||||
trace_path_bytes = [h.hex().upper() for h in trace_hops[:8]]
|
||||
if len(trace_hops) > 8:
|
||||
trace_path_bytes.append("...")
|
||||
path_hash = "[" + ", ".join(trace_path_bytes) + "]"
|
||||
|
||||
# Extract SNR information from the path
|
||||
# Extract SNR information from the path (one SNR byte per hop along trace)
|
||||
path_snrs = []
|
||||
path_snr_details = []
|
||||
for i in range(packet.path_len):
|
||||
if i < len(packet.path):
|
||||
snr_val = packet.path[i]
|
||||
snr_db = snr_register_to_db(snr_val)
|
||||
path_snrs.append(f"{snr_val}({snr_db:.1f}dB)")
|
||||
for i in range(len(packet.path)):
|
||||
snr_val = packet.path[i]
|
||||
snr_db = snr_register_to_db(snr_val)
|
||||
path_snrs.append(f"{snr_val}({snr_db:.1f}dB)")
|
||||
|
||||
# Add detailed SNR info if we have the corresponding hash
|
||||
if i < len(trace_path):
|
||||
path_snr_details.append(
|
||||
{"hash": f"{trace_path[i]:02X}", "snr_raw": snr_val, "snr_db": snr_db}
|
||||
)
|
||||
if i < len(trace_hops):
|
||||
path_snr_details.append(
|
||||
{
|
||||
"hash": trace_hops[i].hex().upper(),
|
||||
"snr_raw": snr_val,
|
||||
"snr_db": snr_db,
|
||||
}
|
||||
)
|
||||
elif i < len(legacy):
|
||||
path_snr_details.append(
|
||||
{
|
||||
"hash": f"{legacy[i]:02X}",
|
||||
"snr_raw": snr_val,
|
||||
"snr_db": snr_db,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"timestamp": time.time(),
|
||||
@@ -195,69 +243,74 @@ class TraceHelper:
|
||||
"path_hash": path_hash,
|
||||
"src_hash": None,
|
||||
"dst_hash": None,
|
||||
"original_path": [f"{h:02X}" for h in trace_path],
|
||||
"original_path": [h.hex() for h in trace_hops],
|
||||
"forwarded_path": None,
|
||||
# Add trace-specific SNR path information
|
||||
"path_snrs": path_snrs, # ["58(14.5dB)", "19(4.8dB)"]
|
||||
"path_snr_details": path_snr_details, # [{"hash": "29", "snr_raw": 58, "snr_db": 14.5}]
|
||||
"path_snr_details": path_snr_details,
|
||||
"is_trace": True,
|
||||
"raw_packet": packet.write_to().hex() if hasattr(packet, "write_to") else None,
|
||||
}
|
||||
|
||||
def _extract_path_info(self, packet, trace_path: list) -> tuple:
|
||||
def _extract_path_info(self, packet, parsed_data: dict) -> tuple:
|
||||
"""
|
||||
Extract SNR and hash information from the packet path.
|
||||
|
||||
Args:
|
||||
packet: The trace packet
|
||||
trace_path: The parsed trace path from the payload
|
||||
|
||||
Returns:
|
||||
A tuple of (path_snrs, path_hashes) lists
|
||||
A tuple of (path_snrs, path_hashes) display lists
|
||||
"""
|
||||
trace_hops: List[bytes] = parsed_data.get("trace_hops") or []
|
||||
path_snrs = []
|
||||
path_hashes = []
|
||||
|
||||
for i in range(packet.path_len):
|
||||
for i in range(len(packet.path)):
|
||||
if i < len(packet.path):
|
||||
snr_val = packet.path[i]
|
||||
snr_db = snr_register_to_db(snr_val)
|
||||
path_snrs.append(f"{snr_val}({snr_db:.1f}dB)")
|
||||
|
||||
if i < len(trace_path):
|
||||
path_hashes.append(f"0x{trace_path[i]:02x}")
|
||||
if i < len(trace_hops):
|
||||
path_hashes.append(f"0x{trace_hops[i].hex()}")
|
||||
|
||||
return path_snrs, path_hashes
|
||||
|
||||
def _should_forward_trace(self, packet, trace_path: list, trace_path_len: int) -> bool:
|
||||
def _should_forward_trace(
|
||||
self, packet, trace_bytes: bytes, flags: int, hash_width: int
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if this node should forward the trace packet.
|
||||
Uses the same logic as the original working implementation.
|
||||
|
||||
Args:
|
||||
packet: The trace packet
|
||||
trace_path: The parsed trace path from the payload
|
||||
trace_path_len: The length of the trace path
|
||||
|
||||
Returns:
|
||||
True if the packet should be forwarded, False otherwise
|
||||
Mesh.cpp TRACE branch: forward if offset < len and next hash matches identity.
|
||||
offset = pkt->path_len<<path_sz uses SNR count in packet.path (len(packet.path)).
|
||||
"""
|
||||
# Use the exact logic from the original working code
|
||||
return (
|
||||
packet.path_len < trace_path_len
|
||||
and len(trace_path) > packet.path_len
|
||||
and trace_path[packet.path_len] == self.local_hash
|
||||
and self.repeater_handler
|
||||
and not self.repeater_handler.is_duplicate(packet)
|
||||
)
|
||||
if not trace_bytes or hash_width <= 0:
|
||||
return False
|
||||
snr_count = len(packet.path)
|
||||
byte_off = snr_count * hash_width
|
||||
if byte_off >= len(trace_bytes):
|
||||
return False
|
||||
|
||||
async def _forward_trace_packet(self, packet, trace_path_len: int) -> None:
|
||||
next_hop = trace_bytes[byte_off : byte_off + hash_width]
|
||||
if len(next_hop) != hash_width:
|
||||
return False
|
||||
|
||||
pubkey_pfx = self._pubkey_prefix(hash_width)
|
||||
if len(pubkey_pfx) >= hash_width:
|
||||
match = next_hop == pubkey_pfx[:hash_width]
|
||||
else:
|
||||
match = hash_width == 1 and next_hop[0] == (self.local_hash & 0xFF)
|
||||
|
||||
if not match:
|
||||
return False
|
||||
if not self.repeater_handler:
|
||||
return False
|
||||
return not self.repeater_handler.is_duplicate(packet)
|
||||
|
||||
async def _forward_trace_packet(self, packet, num_hops: int) -> None:
|
||||
"""
|
||||
Forward a trace packet by appending SNR and sending via injection.
|
||||
|
||||
Args:
|
||||
packet: The trace packet to forward
|
||||
trace_path_len: The length of the trace path
|
||||
num_hops: Total hops in trace path (for logging)
|
||||
"""
|
||||
# Update the packet record to show it will be transmitted
|
||||
if self.repeater_handler and hasattr(self.repeater_handler, "recent_packets"):
|
||||
@@ -290,7 +343,8 @@ class TraceHelper:
|
||||
packet.path_len += 1
|
||||
|
||||
logger.info(
|
||||
f"Forwarding trace, stored SNR {current_snr:.1f}dB at position {packet.path_len - 1}"
|
||||
f"Forwarding trace ({num_hops} hop path), stored SNR {current_snr:.1f}dB "
|
||||
f"at SNR index {packet.path_len - 1}"
|
||||
)
|
||||
|
||||
# Inject packet into router for proper routing and transmission
|
||||
@@ -299,26 +353,34 @@ class TraceHelper:
|
||||
else:
|
||||
logger.warning("No packet injector available - trace packet not forwarded")
|
||||
|
||||
def _log_no_forward_reason(self, packet, trace_path: list, trace_path_len: int) -> None:
|
||||
"""
|
||||
Log the reason why a trace packet was not forwarded.
|
||||
|
||||
Args:
|
||||
packet: The trace packet
|
||||
trace_path: The parsed trace path from the payload
|
||||
trace_path_len: The length of the trace path
|
||||
"""
|
||||
if packet.path_len >= trace_path_len:
|
||||
logger.info("Trace completed (reached end of path)")
|
||||
elif len(trace_path) <= packet.path_len:
|
||||
logger.info("Path index out of bounds")
|
||||
elif trace_path[packet.path_len] != self.local_hash:
|
||||
expected_hash = (
|
||||
trace_path[packet.path_len] if packet.path_len < len(trace_path) else None
|
||||
)
|
||||
logger.info(f"Not our turn (next hop: 0x{expected_hash:02x})")
|
||||
elif self.repeater_handler and self.repeater_handler.is_duplicate(packet):
|
||||
def _log_no_forward_reason(self, packet, trace_bytes: bytes, hash_width: int) -> None:
|
||||
"""Log the reason why this node did not forward the trace."""
|
||||
if self.repeater_handler and self.repeater_handler.is_duplicate(packet):
|
||||
logger.info("Duplicate packet, ignoring")
|
||||
return
|
||||
|
||||
snr_count = len(packet.path)
|
||||
if not trace_bytes or hash_width <= 0:
|
||||
logger.info("Trace: empty path or invalid hash width")
|
||||
return
|
||||
|
||||
if snr_count * hash_width >= len(trace_bytes):
|
||||
logger.info("Trace completed (reached end of path)")
|
||||
return
|
||||
|
||||
byte_off = snr_count * hash_width
|
||||
next_hop = trace_bytes[byte_off : byte_off + hash_width]
|
||||
pubkey_pfx = self._pubkey_prefix(hash_width)
|
||||
if len(next_hop) == hash_width and len(pubkey_pfx) >= hash_width:
|
||||
if next_hop != pubkey_pfx[:hash_width]:
|
||||
logger.info(f"Not our turn (next hop: 0x{next_hop.hex()})")
|
||||
return
|
||||
elif hash_width == 1 and next_hop:
|
||||
if (next_hop[0] & 0xFF) != (self.local_hash & 0xFF):
|
||||
logger.info(f"Not our turn (next hop: 0x{next_hop.hex()})")
|
||||
return
|
||||
|
||||
logger.info("Trace: not forwarded (internal)")
|
||||
|
||||
def register_ping(self, tag: int, target_hash: int) -> asyncio.Event:
|
||||
"""Register a ping request and return an event to wait on.
|
||||
|
||||
@@ -170,6 +170,7 @@ class RepeaterDaemon:
|
||||
repeater_handler=self.repeater_handler,
|
||||
packet_injector=self.router.inject_packet,
|
||||
log_fn=logger.info,
|
||||
local_identity=self.local_identity,
|
||||
)
|
||||
logger.info("Trace processing helper initialized")
|
||||
|
||||
@@ -745,21 +746,28 @@ class RepeaterDaemon:
|
||||
|
||||
async def _on_trace_complete_for_companions(self, packet, parsed_data) -> None:
|
||||
"""Trace completed at this node: push PUSH_CODE_TRACE_DATA (0x89) to companion clients (firmware onTraceRecv)."""
|
||||
path_len = len(parsed_data.get("trace_path", []))
|
||||
if path_len == 0:
|
||||
path_hashes = parsed_data.get("trace_path_bytes") or b""
|
||||
if not path_hashes:
|
||||
return
|
||||
path_hashes = bytes(parsed_data["trace_path"])
|
||||
flags = parsed_data.get("flags", 0)
|
||||
path_sz = flags & 0x03
|
||||
hash_len = len(path_hashes)
|
||||
expected_snr_len = hash_len >> path_sz
|
||||
if expected_snr_len <= 0:
|
||||
return
|
||||
tag = parsed_data.get("tag", 0)
|
||||
auth_code = parsed_data.get("auth_code", 0)
|
||||
# path_snrs: exactly path_len bytes = (path_len-1) from forwarding hops + 1 (our receive SNR)
|
||||
snr_scaled = max(-128, min(127, int(round(packet.get_snr() * 4))))
|
||||
snr_byte = snr_scaled if snr_scaled >= 0 else (256 + snr_scaled)
|
||||
path_snrs = bytes(packet.path)[: path_len - 1] + bytes([snr_byte])
|
||||
# Firmware: memcpy path_snrs from pkt->path (length hash_len >> path_sz), then final SNR byte
|
||||
raw = bytes(packet.path)[:expected_snr_len]
|
||||
if len(raw) < expected_snr_len:
|
||||
raw = raw + b"\x00" * (expected_snr_len - len(raw))
|
||||
path_snrs = raw
|
||||
for fs in getattr(self, "companion_frame_servers", []):
|
||||
try:
|
||||
fs.push_trace_data(
|
||||
path_len, flags, tag, auth_code, path_hashes, path_snrs, snr_byte
|
||||
await fs.push_trace_data_async(
|
||||
hash_len, flags, tag, auth_code, path_hashes, path_snrs, snr_byte
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Push trace data to companion: %s", e)
|
||||
|
||||
@@ -2197,17 +2197,20 @@ class APIEndpoints:
|
||||
# Calculate round-trip time
|
||||
rtt_ms = (result["received_at"] - ping_info["sent_at"]) * 1000
|
||||
|
||||
# result["path"] is a flat byte list from _parse_trace_payload.
|
||||
# For multi-byte hash mode, group into byte_count-sized chunks
|
||||
# before formatting (e.g. [0xb5, 0xd8] → ["0xb5d8"] for 2-byte mode).
|
||||
raw_path = result["path"]
|
||||
if byte_count > 1:
|
||||
# Prefer structured hops from TraceHelper; else legacy flat list.
|
||||
if result.get("trace_hops"):
|
||||
grouped_path = [
|
||||
int.from_bytes(bytes(raw_path[i:i + byte_count]), "big")
|
||||
for i in range(0, len(raw_path), byte_count)
|
||||
int.from_bytes(bytes(h), "big") for h in result["trace_hops"]
|
||||
]
|
||||
else:
|
||||
grouped_path = raw_path
|
||||
raw_path = result["path"]
|
||||
if byte_count > 1:
|
||||
grouped_path = [
|
||||
int.from_bytes(bytes(raw_path[i : i + byte_count]), "big")
|
||||
for i in range(0, len(raw_path), byte_count)
|
||||
]
|
||||
else:
|
||||
grouped_path = raw_path
|
||||
|
||||
return self._success(
|
||||
{
|
||||
|
||||
@@ -684,27 +684,25 @@ class TestTracePayloadParsing:
|
||||
assert result["trace_path"] == [0xAA, 0xBB, 0xCC]
|
||||
assert result["path_length"] == 3
|
||||
|
||||
def test_parse_trace_with_multibyte_path_is_flat(self):
|
||||
"""
|
||||
Trace path is raw bytes in payload — _parse_trace_payload returns it flat.
|
||||
Multi-byte grouping is NOT done at the trace parser level.
|
||||
"""
|
||||
def test_parse_trace_with_multibyte_path_grouped_by_flags(self):
|
||||
"""flags=1 → 2-byte hashes per hop (Mesh.cpp 1<<path_sz)."""
|
||||
th = self._make_trace_handler()
|
||||
# 2 hops of 2-byte hashes → 4 flat bytes in the payload
|
||||
path = bytes([0xAA, 0xBB, 0xCC, 0xDD])
|
||||
payload = struct.pack("<IIB", 10, 20, 0) + path
|
||||
payload = struct.pack("<IIB", 10, 20, 1) + path
|
||||
result = th._parse_trace_payload(payload)
|
||||
assert result["valid"]
|
||||
# Returns flat list, not grouped
|
||||
assert result["trace_path"] == [0xAA, 0xBB, 0xCC, 0xDD]
|
||||
assert result["path_hash_width"] == 2
|
||||
assert result["trace_hops"] == [b"\xaa\xbb", b"\xcc\xdd"]
|
||||
assert result["trace_path"] == [0xAA, 0xCC] # legacy: first byte per hop
|
||||
assert result["path_length"] == 4
|
||||
assert result["path_hop_count"] == 2
|
||||
|
||||
def test_parse_from_real_packet(self):
|
||||
"""Create a trace with PacketBuilder, serialize, deserialize, then parse."""
|
||||
th = self._make_trace_handler()
|
||||
trace_path = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66]
|
||||
pkt = PacketBuilder.create_trace(
|
||||
tag=100, auth_code=200, flags=7, path=trace_path
|
||||
tag=100, auth_code=200, flags=0, path=trace_path
|
||||
)
|
||||
wire = pkt.write_to()
|
||||
pkt2 = Packet()
|
||||
@@ -713,8 +711,10 @@ class TestTracePayloadParsing:
|
||||
assert result["valid"]
|
||||
assert result["tag"] == 100
|
||||
assert result["auth_code"] == 200
|
||||
assert result["flags"] == 7
|
||||
assert result["flags"] == 0
|
||||
assert result["trace_path"] == trace_path
|
||||
assert result["path_hop_count"] == 6
|
||||
assert result["trace_path_bytes"] == bytes(trace_path)
|
||||
|
||||
def test_parse_too_short_payload(self):
|
||||
th = self._make_trace_handler()
|
||||
@@ -723,6 +723,54 @@ class TestTracePayloadParsing:
|
||||
assert "too short" in result["error"].lower()
|
||||
|
||||
|
||||
class TestTraceHelperMultibyte:
|
||||
"""TraceHelper._should_forward_trace with 2-byte TRACE payload hashes."""
|
||||
|
||||
def test_should_forward_when_next_hop_matches_pubkey_prefix(self):
|
||||
from repeater.handler_helpers.trace import TraceHelper
|
||||
|
||||
from pymc_core.protocol import LocalIdentity
|
||||
|
||||
identity = LocalIdentity()
|
||||
pub = bytes(identity.get_public_key())
|
||||
rh = MagicMock()
|
||||
rh.is_duplicate = MagicMock(return_value=False)
|
||||
th = TraceHelper(
|
||||
local_hash=pub[0],
|
||||
repeater_handler=rh,
|
||||
local_identity=identity,
|
||||
)
|
||||
trace_bytes = pub[:2] + b"\x11\x22"
|
||||
flags = 1
|
||||
hash_width = 2
|
||||
pkt = Packet()
|
||||
pkt.path = bytearray()
|
||||
pkt.path_len = 0
|
||||
assert th._should_forward_trace(pkt, trace_bytes, flags, hash_width)
|
||||
|
||||
def test_should_not_forward_when_next_hop_mismatch(self):
|
||||
from repeater.handler_helpers.trace import TraceHelper
|
||||
|
||||
from pymc_core.protocol import LocalIdentity
|
||||
|
||||
identity = LocalIdentity()
|
||||
pub = bytes(identity.get_public_key())
|
||||
rh = MagicMock()
|
||||
rh.is_duplicate = MagicMock(return_value=False)
|
||||
th = TraceHelper(
|
||||
local_hash=pub[0],
|
||||
repeater_handler=rh,
|
||||
local_identity=identity,
|
||||
)
|
||||
trace_bytes = bytes([pub[0] ^ 0xFF, pub[1] ^ 0xFF])
|
||||
flags = 1
|
||||
hash_width = 2
|
||||
pkt = Packet()
|
||||
pkt.path = bytearray()
|
||||
pkt.path_len = 0
|
||||
assert not th._should_forward_trace(pkt, trace_bytes, flags, hash_width)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# 8. Wire-level verification — manual byte inspection
|
||||
# ===================================================================
|
||||
|
||||
Reference in New Issue
Block a user