diff --git a/pyproject.toml b/pyproject.toml index c4522c7..dffa43f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ keywords = ["mesh", "networking", "lora", "repeater", "daemon", "iot"] dependencies = [ - "pymc_core[hardware] @ git+https://github.com/rightup/pyMC_core.git@dev", + "pymc_core[hardware] @ git+https://github.com/rightup/pyMC_core.git@feat/valid-packets-checks", "pyyaml>=6.0.0", "cherrypy>=18.0.0", "paho-mqtt>=1.6.0", diff --git a/repeater/engine.py b/repeater/engine.py index cec592e..ba17935 100644 --- a/repeater/engine.py +++ b/repeater/engine.py @@ -99,7 +99,6 @@ class RepeaterHandler(BaseHandler): self._transport_keys_cache_ttl = 60 # Cache for 60 seconds self._last_drop_reason = None - self._known_neighbors = set() self._start_background_tasks() @@ -191,10 +190,6 @@ class RepeaterHandler(BaseHandler): if is_dupe and drop_reason is None: drop_reason = "Duplicate" - # Process adverts for neighbor tracking - if payload_type == PAYLOAD_TYPE_ADVERT: - self._process_advert(packet, rssi, snr) - path_hash = None display_path = ( original_path if original_path else (list(packet.path) if packet.path else []) @@ -342,85 +337,6 @@ class RepeaterHandler(BaseHandler): # Default reason return "Unknown" - def _process_advert(self, packet: Packet, rssi: int, snr: float): - - try: - from pymc_core.protocol.constants import ADVERT_FLAG_IS_REPEATER - from pymc_core.protocol.utils import ( - decode_appdata, - get_contact_type_name, - parse_advert_payload, - determine_contact_type_from_flags, - ) - - # Parse advert payload - if not packet.payload or len(packet.payload) < 40: - return - - advert_data = parse_advert_payload(packet.payload) - pubkey = advert_data.get("pubkey", "") - - # Skip our own adverts - if self.dispatcher and hasattr(self.dispatcher, "local_identity"): - local_pubkey = self.dispatcher.local_identity.get_public_key().hex() - if pubkey == local_pubkey: - logger.debug("Ignoring own advert in neighbor tracking") - return - - appdata = advert_data.get("appdata", b"") - if not appdata: - return - - appdata_decoded = decode_appdata(appdata) - flags = appdata_decoded.get("flags", 0) - is_repeater = bool(flags & ADVERT_FLAG_IS_REPEATER) - route_type = packet.header & PH_ROUTE_MASK - contact_type_id = determine_contact_type_from_flags(flags) - contact_type = get_contact_type_name(contact_type_id) - - # Extract neighbor info - node_name = appdata_decoded.get("node_name", "Unknown") - latitude = appdata_decoded.get("latitude") - longitude = appdata_decoded.get("longitude") - - current_time = time.time() - - if pubkey not in self._known_neighbors: - # Only check database if not in cache - current_neighbors = self.storage.get_neighbors() if self.storage else {} - is_new_neighbor = pubkey not in current_neighbors - - if is_new_neighbor: - self._known_neighbors.add(pubkey) - else: - is_new_neighbor = False - - advert_record = { - "timestamp": current_time, - "pubkey": pubkey, - "node_name": node_name, - "is_repeater": is_repeater, - "route_type": route_type, - "contact_type": contact_type, - "latitude": latitude, - "longitude": longitude, - "rssi": rssi, - "snr": snr, - "is_new_neighbor": is_new_neighbor, - } - - # Store to database - if self.storage: - try: - self.storage.record_advert(advert_record) - if is_new_neighbor: - logger.info(f"Discovered new neighbor: {node_name} ({pubkey[:16]}...)") - except Exception as e: - logger.error(f"Failed to store advert record: {e}") - - except Exception as e: - logger.debug(f"Error processing advert for neighbor tracking: {e}") - def is_duplicate(self, packet: Packet) -> bool: pkt_hash = packet.calculate_packet_hash().hex().upper() diff --git a/repeater/handler_helpers/__init__.py b/repeater/handler_helpers/__init__.py new file mode 100644 index 0000000..08f5104 --- /dev/null +++ b/repeater/handler_helpers/__init__.py @@ -0,0 +1,7 @@ +"""Handler helper modules for pyMC Repeater.""" + +from .trace import TraceHelper +from .discovery import DiscoveryHelper +from .advert import AdvertHelper + +__all__ = ["TraceHelper", "DiscoveryHelper", "AdvertHelper"] diff --git a/repeater/handler_helpers/advert.py b/repeater/handler_helpers/advert.py new file mode 100644 index 0000000..377f902 --- /dev/null +++ b/repeater/handler_helpers/advert.py @@ -0,0 +1,112 @@ +""" +Advertisement packet handling helper for pyMC Repeater. + +This module processes advertisement packets for neighbor tracking and discovery. +""" + +import logging +import time + +from pymc_core.node.handlers.advert import AdvertHandler + +logger = logging.getLogger("AdvertHelper") + + +class AdvertHelper: + """Helper class for processing advertisement packets in the repeater.""" + + def __init__(self, local_identity, storage, log_fn=None): + """ + Initialize the advert helper. + + Args: + local_identity: The LocalIdentity instance for this repeater + storage: StorageCollector instance for persisting advert data + log_fn: Optional logging function for AdvertHandler + """ + self.local_identity = local_identity + self.storage = storage + + # Create AdvertHandler internally as a parsing utility + self.advert_handler = AdvertHandler(log_fn=log_fn or logger.info) + + # Cache for tracking known neighbors (avoid repeated database queries) + self._known_neighbors = set() + + async def process_advert_packet(self, packet, rssi: int, snr: float) -> None: + """ + Process an incoming advertisement packet. + + This method uses AdvertHandler to parse the packet, then stores + the neighbor information for tracking and discovery. + + Args: + packet: The advertisement packet to process + rssi: Received signal strength indicator + snr: Signal-to-noise ratio + """ + try: + # Set signal metrics on packet for handler to use + packet._snr = snr + packet._rssi = rssi + + # Use AdvertHandler to parse the packet - it now returns parsed data + advert_data = await self.advert_handler(packet) + + if not advert_data or not advert_data.get("valid"): + logger.debug("Invalid advert packet") + return + + # Extract data from parsed advert + pubkey = advert_data["public_key"] + node_name = advert_data["name"] + contact_type = advert_data["contact_type"] + + # Skip our own adverts + if self.local_identity: + local_pubkey = self.local_identity.get_public_key().hex() + if pubkey == local_pubkey: + logger.debug("Ignoring own advert in neighbor tracking") + return + + # Get route type from packet header + from pymc_core.protocol.constants import PH_ROUTE_MASK + route_type = packet.header & PH_ROUTE_MASK + + # Check if this is a new neighbor + current_time = time.time() + if pubkey not in self._known_neighbors: + # Only check database if not in cache + current_neighbors = self.storage.get_neighbors() if self.storage else {} + is_new_neighbor = pubkey not in current_neighbors + + if is_new_neighbor: + self._known_neighbors.add(pubkey) + logger.info(f"Discovered new neighbor: {node_name} ({pubkey[:16]}...)") + else: + is_new_neighbor = False + + # Build advert record + advert_record = { + "timestamp": current_time, + "pubkey": pubkey, + "node_name": node_name, + "is_repeater": "REPEATER" in contact_type.upper(), + "route_type": route_type, + "contact_type": contact_type, + "latitude": advert_data["latitude"], + "longitude": advert_data["longitude"], + "rssi": rssi, + "snr": snr, + "is_new_neighbor": is_new_neighbor, + } + + # Store to database + if self.storage: + try: + self.storage.record_advert(advert_record) + except Exception as e: + logger.error(f"Failed to store advert record: {e}") + + except Exception as e: + logger.error(f"Error processing advert packet: {e}", exc_info=True) diff --git a/repeater/handler_helpers/discovery.py b/repeater/handler_helpers/discovery.py new file mode 100644 index 0000000..5269e10 --- /dev/null +++ b/repeater/handler_helpers/discovery.py @@ -0,0 +1,132 @@ +""" +Discovery request/response handling helper for pyMC Repeater. + +This module handles the processing and response to discovery requests, +allowing other nodes to discover repeaters on the mesh network. +""" + +import asyncio +import logging +from typing import Optional, Callable, Any + +from pymc_core.node.handlers.control import ControlHandler + +logger = logging.getLogger("DiscoveryHelper") + + +class DiscoveryHelper: + """Helper class for processing discovery requests in the repeater.""" + + def __init__( + self, + local_identity, + dispatcher, + node_type: int = 2, + log_fn=None, + ): + """ + Initialize the discovery helper. + + Args: + local_identity: The LocalIdentity instance for this repeater + dispatcher: The Dispatcher instance for sending packets + node_type: Node type identifier (2 = Repeater) + log_fn: Optional logging function for ControlHandler + """ + self.local_identity = local_identity + self.dispatcher = dispatcher + self.node_type = node_type + + # Create ControlHandler internally as a parsing utility + self.control_handler = ControlHandler(log_fn=log_fn or logger.info) + + # Set up the request callback + self.control_handler.set_request_callback(self._on_discovery_request) + logger.debug("Discovery handler initialized") + + def _on_discovery_request(self, request_data: dict) -> None: + """ + Handle incoming discovery request. + + Args: + request_data: Dictionary containing the parsed discovery request + """ + try: + tag = request_data.get("tag", 0) + filter_byte = request_data.get("filter", 0) + prefix_only = request_data.get("prefix_only", False) + snr = request_data.get("snr", 0.0) + rssi = request_data.get("rssi", 0) + + logger.info( + f"Request: tag=0x{tag:08X}, filter=0x{filter_byte:02X}, " + f"SNR={snr:+.1f}dB, RSSI={rssi}dBm" + ) + + # Check if filter matches our node type (repeater = 2, filter_mask = 0x04) + filter_mask = 1 << self.node_type # 1 << 2 = 0x04 + if (filter_byte & filter_mask) == 0: + logger.debug("Filter doesn't match, ignoring") + return + + logger.info("Sending response...") + + if self.local_identity: + self._send_discovery_response(tag, self.node_type, snr, prefix_only) + else: + logger.warning("No local identity available for response") + + except Exception as e: + logger.error(f"Error handling request: {e}") + + def _send_discovery_response( + self, + tag: int, + node_type: int, + inbound_snr: float, + prefix_only: bool, + ) -> None: + """ + Create and send a discovery response packet. + + Args: + tag: The tag from the discovery request + node_type: Node type identifier + inbound_snr: SNR of the received request + prefix_only: Whether to use prefix-only mode + """ + try: + our_pub_key = self.local_identity.get_public_key() + + from pymc_core.protocol.packet_builder import PacketBuilder + + response_packet = PacketBuilder.create_discovery_response( + tag=tag, + node_type=node_type, + inbound_snr=inbound_snr, + pub_key=our_pub_key, + prefix_only=prefix_only, + ) + + # Send response asynchronously + asyncio.create_task(self._send_packet_async(response_packet, tag)) + + except Exception as e: + logger.error(f"Error creating discovery response: {e}") + + async def _send_packet_async(self, packet, tag: int) -> None: + """ + Send a discovery response packet asynchronously. + + Args: + packet: The packet to send + tag: The tag for logging purposes + """ + try: + success = await self.dispatcher.send_packet(packet, wait_for_ack=False) + if success: + logger.info(f"Response sent for tag 0x{tag:08X}") + else: + logger.warning(f"Failed to send response for tag 0x{tag:08X}") + except Exception as e: + logger.error(f"Error sending response: {e}") diff --git a/repeater/handler_helpers/trace.py b/repeater/handler_helpers/trace.py new file mode 100644 index 0000000..1763033 --- /dev/null +++ b/repeater/handler_helpers/trace.py @@ -0,0 +1,285 @@ +""" +Trace packet handling helper for pyMC Repeater. + +This module handles the processing and forwarding of trace packets, +which are used for network diagnostics to track the path and SNR +of packets through the mesh network. +""" + +import logging +import time +from typing import Optional, Dict, Any + +from pymc_core.node.handlers.trace import TraceHandler +from pymc_core.protocol.constants import MAX_PATH_SIZE, ROUTE_TYPE_DIRECT + +logger = logging.getLogger("TraceHelper") + + +class TraceHelper: + """Helper class for processing trace packets in the repeater.""" + + def __init__(self, local_hash: int, repeater_handler, dispatcher, log_fn=None): + """ + Initialize the trace helper. + + Args: + local_hash: The local node's hash identifier + repeater_handler: The RepeaterHandler instance + dispatcher: The Dispatcher instance for sending packets + log_fn: Optional logging function for TraceHandler + """ + self.local_hash = local_hash + self.repeater_handler = repeater_handler + self.dispatcher = dispatcher + + # Create TraceHandler internally as a parsing utility + self.trace_handler = TraceHandler(log_fn=log_fn or logger.info) + + async def process_trace_packet(self, packet) -> None: + """ + Process an incoming trace packet. + + This method handles trace packet validation, logging, recording, + and forwarding if this node is the next hop in the trace path. + + Args: + packet: The trace packet to process + """ + try: + # Only process direct route trace packets + if packet.get_route_type() != ROUTE_TYPE_DIRECT or packet.path_len >= MAX_PATH_SIZE: + return + + # Parse the trace payload + parsed_data = self.trace_handler._parse_trace_payload(packet.payload) + + if not parsed_data.get("valid", False): + logger.warning( + f"Invalid trace packet: {parsed_data.get('error', 'Unknown error')}" + ) + return + + trace_path = parsed_data["trace_path"] + trace_path_len = len(trace_path) + + # Record the trace packet for dashboard/statistics + if self.repeater_handler: + packet_record = self._create_trace_record(packet, trace_path, parsed_data) + self.repeater_handler.log_trace_record(packet_record) + + # Extract and log path SNRs and hashes + path_snrs, path_hashes = self._extract_path_info(packet, trace_path) + + # Add packet metadata for logging + parsed_data["snr"] = packet.get_snr() + parsed_data["rssi"] = getattr(packet, "rssi", 0) + formatted_response = self.trace_handler._format_trace_response(parsed_data) + + logger.info(f"{formatted_response}") + logger.info(f"Path SNRs: [{', '.join(path_snrs)}], Hashes: [{', '.join(path_hashes)}]") + + # Check if we should forward this trace packet + should_forward = self._should_forward_trace(packet, trace_path, trace_path_len) + + if should_forward: + await self._forward_trace_packet(packet, trace_path_len) + else: + self._log_no_forward_reason(packet, trace_path, trace_path_len) + + except Exception as e: + logger.error(f"Error processing trace packet: {e}") + + def _create_trace_record(self, packet, trace_path: list, parsed_data: dict) -> Dict[str, Any]: + """ + Create a packet record for trace packets to log to statistics. + + Args: + packet: The trace packet + trace_path: The parsed trace path from the payload + parsed_data: The parsed trace data + + Returns: + A dictionary containing the packet record + """ + # Format trace path for display + trace_path_bytes = [f"{h:02X}" for h in trace_path[:8]] + if len(trace_path) > 8: + trace_path_bytes.append("...") + path_hash = "[" + ", ".join(trace_path_bytes) + "]" + + # Extract SNR information from the path + path_snrs = [] + path_snr_details = [] + for i in range(packet.path_len): + if i < len(packet.path): + snr_val = packet.path[i] + # Convert unsigned byte to signed SNR + snr_signed = snr_val if snr_val < 128 else snr_val - 256 + snr_db = snr_signed / 4.0 + path_snrs.append(f"{snr_val}({snr_db:.1f}dB)") + + # Add detailed SNR info if we have the corresponding hash + if i < len(trace_path): + path_snr_details.append({ + "hash": f"{trace_path[i]:02X}", + "snr_raw": snr_val, + "snr_db": snr_db + }) + + return { + "timestamp": time.time(), + "header": f"0x{packet.header:02X}" if hasattr(packet, "header") and packet.header is not None else None, + "payload": packet.payload.hex() if hasattr(packet, "payload") and packet.payload else None, + "payload_length": len(packet.payload) if hasattr(packet, "payload") and packet.payload else 0, + "type": packet.get_payload_type(), # 0x09 for trace + "route": packet.get_route_type(), # Should be direct (1) + "length": len(packet.payload or b""), + "rssi": getattr(packet, "rssi", 0), + "snr": getattr(packet, "snr", 0.0), + "score": self.repeater_handler.calculate_packet_score( + getattr(packet, "snr", 0.0), + len(packet.payload or b""), + self.repeater_handler.radio_config.get("spreading_factor", 8) + ) if self.repeater_handler else 0.0, + "tx_delay_ms": 0, + "transmitted": False, + "is_duplicate": False, + "packet_hash": packet.calculate_packet_hash().hex().upper()[:16], + "drop_reason": "trace_received", + "path_hash": path_hash, + "src_hash": None, + "dst_hash": None, + "original_path": [f"{h:02X}" for h in trace_path], + "forwarded_path": None, + # Add trace-specific SNR path information + "path_snrs": path_snrs, # ["58(14.5dB)", "19(4.8dB)"] + "path_snr_details": path_snr_details, # [{"hash": "29", "snr_raw": 58, "snr_db": 14.5}] + "is_trace": True, + "raw_packet": packet.write_to().hex() if hasattr(packet, "write_to") else None, + } + + def _extract_path_info(self, packet, trace_path: list) -> tuple: + """ + Extract SNR and hash information from the packet path. + + Args: + packet: The trace packet + trace_path: The parsed trace path from the payload + + Returns: + A tuple of (path_snrs, path_hashes) lists + """ + path_snrs = [] + path_hashes = [] + + for i in range(packet.path_len): + if i < len(packet.path): + snr_val = packet.path[i] + snr_signed = snr_val if snr_val < 128 else snr_val - 256 + snr_db = snr_signed / 4.0 + path_snrs.append(f"{snr_val}({snr_db:.1f}dB)") + + if i < len(trace_path): + path_hashes.append(f"0x{trace_path[i]:02x}") + + return path_snrs, path_hashes + + def _should_forward_trace(self, packet, trace_path: list, trace_path_len: int) -> bool: + """ + Determine if this node should forward the trace packet. + + Args: + packet: The trace packet + trace_path: The parsed trace path from the payload + trace_path_len: The length of the trace path + + Returns: + True if the packet should be forwarded, False otherwise + """ + # Check if we've reached the end of the trace path + if packet.path_len >= trace_path_len: + return False + + # Check if path index is valid + if len(trace_path) <= packet.path_len: + return False + + # Check if this node is the next hop + if trace_path[packet.path_len] != self.local_hash: + return False + + # Check for duplicates + if self.repeater_handler and self.repeater_handler.is_duplicate(packet): + return False + + return True + + async def _forward_trace_packet(self, packet, trace_path_len: int) -> None: + """ + Forward a trace packet by appending SNR and sending it. + + Args: + packet: The trace packet to forward + trace_path_len: The length of the trace path + """ + # Update the packet record to show it was transmitted + if self.repeater_handler and hasattr(self.repeater_handler, 'recent_packets'): + packet_hash = packet.calculate_packet_hash().hex().upper()[:16] + for record in reversed(self.repeater_handler.recent_packets): + if record.get("packet_hash") == packet_hash: + record["transmitted"] = True + record["drop_reason"] = "trace_forwarded" + break + + # Get current SNR and scale it for storage (SNR * 4) + current_snr = packet.get_snr() + snr_scaled = int(current_snr * 4) + + # Clamp to signed byte range [-128, 127] + if snr_scaled > 127: + snr_scaled = 127 + elif snr_scaled < -128: + snr_scaled = -128 + + # Convert to unsigned byte representation + snr_byte = snr_scaled if snr_scaled >= 0 else (256 + snr_scaled) + + # Ensure path array is long enough + while len(packet.path) <= packet.path_len: + packet.path.append(0) + + # Store SNR at current position and increment path length + packet.path[packet.path_len] = snr_byte + packet.path_len += 1 + + logger.info( + f"Forwarding trace, stored SNR {current_snr:.1f}dB at position {packet.path_len - 1}" + ) + + # Mark as seen - packet will flow to repeater handler via pipeline + # which will apply all forwarding rules and validation + if self.repeater_handler: + self.repeater_handler.mark_seen(packet) + + # Don't send directly - let packet flow to repeater handler through pipeline + # The pipeline will pass this modified packet to the repeater for validation and forwarding + + def _log_no_forward_reason(self, packet, trace_path: list, trace_path_len: int) -> None: + """ + Log the reason why a trace packet was not forwarded. + + Args: + packet: The trace packet + trace_path: The parsed trace path from the payload + trace_path_len: The length of the trace path + """ + if packet.path_len >= trace_path_len: + logger.info("Trace completed (reached end of path)") + elif len(trace_path) <= packet.path_len: + logger.info("Path index out of bounds") + elif trace_path[packet.path_len] != self.local_hash: + expected_hash = trace_path[packet.path_len] if packet.path_len < len(trace_path) else None + logger.info(f"Not our turn (next hop: 0x{expected_hash:02x})") + elif self.repeater_handler and self.repeater_handler.is_duplicate(packet): + logger.info("Duplicate packet, ignoring") diff --git a/repeater/main.py b/repeater/main.py index 1a5232e..7be8558 100644 --- a/repeater/main.py +++ b/repeater/main.py @@ -6,8 +6,8 @@ import sys from repeater.config import get_radio_for_board, load_config from repeater.engine import RepeaterHandler from repeater.web.http_server import HTTPStatsServer, _log_buffer -from pymc_core.node.handlers.trace import TraceHandler -from pymc_core.protocol.constants import MAX_PATH_SIZE, ROUTE_TYPE_DIRECT +from repeater.handler_helpers import TraceHelper, DiscoveryHelper, AdvertHelper +from repeater.packet_pipeline import PacketPipeline logger = logging.getLogger("RepeaterDaemon") @@ -23,7 +23,10 @@ class RepeaterDaemon: self.local_hash = None self.local_identity = None self.http_server = None - self.trace_handler = None + self.trace_helper = None + self.advert_helper = None + self.discovery_helper = None + self.pipeline = None log_level = config.get("logging", {}).get("level", "INFO") @@ -45,7 +48,6 @@ class RepeaterDaemon: try: self.radio = get_radio_for_board(self.config) - if hasattr(self.radio, 'set_custom_cad_thresholds'): # Load CAD settings from config, with defaults cad_config = self.config.get("radio", {}).get("cad", {}) @@ -91,7 +93,6 @@ class RepeaterDaemon: self.local_identity = local_identity self.dispatcher.local_identity = local_identity - pubkey = local_identity.get_public_key() self.local_hash = pubkey[0] logger.info(f"Local identity set: {local_identity.get_address_bytes().hex()}") @@ -105,266 +106,60 @@ class RepeaterDaemon: self.config, self.dispatcher, self.local_hash, send_advert_func=self.send_advert ) - self.dispatcher.register_fallback_handler(self._repeater_callback) - logger.info("Repeater handler registered (forwarder mode)") - - self.trace_handler = TraceHandler(log_fn=logger.info) + # Create pipeline + self.pipeline = PacketPipeline(self) + await self.pipeline.start() - self.dispatcher.register_handler( - TraceHandler.payload_type(), - self._trace_callback, - ) - logger.info("Trace handler registered for network diagnostics") + # Register pipeline as entry point for ALL packets via fallback handler + # All received packets flow through pipeline → helpers → repeater validation + self.dispatcher.register_fallback_handler(self._pipeline_callback) + logger.info("Pipeline registered as fallback (catches all packets)") + # Create processing helpers (handlers created internally) + self.trace_helper = TraceHelper( + local_hash=self.local_hash, + repeater_handler=self.repeater_handler, + dispatcher=self.dispatcher, + log_fn=logger.info, + ) + logger.info("Trace processing helper initialized") + + # Create advert helper for neighbor tracking + self.advert_helper = AdvertHelper( + local_identity=self.local_identity, + storage=self.repeater_handler.storage if self.repeater_handler else None, + log_fn=logger.info, + ) + logger.info("Advert processing helper initialized") + + # Set up discovery handler if enabled allow_discovery = self.config.get("repeater", {}).get("allow_discovery", True) if allow_discovery: - self._setup_discovery_handler() - logger.info("Discovery response handler enabled") + self.discovery_helper = DiscoveryHelper( + local_identity=self.local_identity, + dispatcher=self.dispatcher, + node_type=2, + log_fn=logger.info, + ) + logger.info("Discovery processing helper initialized") else: logger.info("Discovery response handler disabled") - - except Exception as e: logger.error(f"Failed to initialize dispatcher: {e}") raise - async def _repeater_callback(self, packet): - - if self.repeater_handler: - - metadata = { - "rssi": getattr(packet, "rssi", 0), - "snr": getattr(packet, "snr", 0.0), - "timestamp": getattr(packet, "timestamp", 0), - } - await self.repeater_handler(packet, metadata) - - async def _trace_callback(self, packet): - - try: - # Only process direct route trace packets - if packet.get_route_type() != ROUTE_TYPE_DIRECT or packet.path_len >= MAX_PATH_SIZE: - return - - - parsed_data = self.trace_handler._parse_trace_payload(packet.payload) - - if not parsed_data.get("valid", False): - logger.warning(f"[TraceHandler] Invalid trace packet: {parsed_data.get('error', 'Unknown error')}") - return - - trace_path = parsed_data["trace_path"] - trace_path_len = len(trace_path) - - - if self.repeater_handler: - import time - - trace_path_bytes = [f"{h:02X}" for h in trace_path[:8]] - if len(trace_path) > 8: - trace_path_bytes.append("...") - path_hash = "[" + ", ".join(trace_path_bytes) + "]" - - path_snrs = [] - path_snr_details = [] - for i in range(packet.path_len): - if i < len(packet.path): - snr_val = packet.path[i] - - snr_signed = snr_val if snr_val < 128 else snr_val - 256 - snr_db = snr_signed / 4.0 - path_snrs.append(f"{snr_val}({snr_db:.1f}dB)") - - if i < len(trace_path): - path_snr_details.append({ - "hash": f"{trace_path[i]:02X}", - "snr_raw": snr_val, - "snr_db": snr_db - }) - - packet_record = { - "timestamp": time.time(), - "header": f"0x{packet.header:02X}" if hasattr(packet, "header") and packet.header is not None else None, - "payload": packet.payload.hex() if hasattr(packet, "payload") and packet.payload else None, - "payload_length": len(packet.payload) if hasattr(packet, "payload") and packet.payload else 0, - "type": packet.get_payload_type(), # 0x09 for trace - "route": packet.get_route_type(), # Should be direct (1) - "length": len(packet.payload or b""), - "rssi": getattr(packet, "rssi", 0), - "snr": getattr(packet, "snr", 0.0), - "score": self.repeater_handler.calculate_packet_score( - getattr(packet, "snr", 0.0), - len(packet.payload or b""), - self.repeater_handler.radio_config.get("spreading_factor", 8) - ), - "tx_delay_ms": 0, - "transmitted": False, - "is_duplicate": False, - "packet_hash": packet.calculate_packet_hash().hex().upper()[:16], - "drop_reason": "trace_received", - "path_hash": path_hash, - "src_hash": None, - "dst_hash": None, - "original_path": [f"{h:02X}" for h in trace_path], - "forwarded_path": None, - # Add trace-specific SNR path information - "path_snrs": path_snrs, # ["58(14.5dB)", "19(4.8dB)"] - "path_snr_details": path_snr_details, # [{"hash": "29", "snr_raw": 58, "snr_db": 14.5}] - "is_trace": True, - "raw_packet": packet.write_to().hex() if hasattr(packet, "write_to") else None, - } - self.repeater_handler.log_trace_record(packet_record) - - path_snrs = [] - path_hashes = [] - for i in range(packet.path_len): - if i < len(packet.path): - snr_val = packet.path[i] - snr_signed = snr_val if snr_val < 128 else snr_val - 256 - snr_db = snr_signed / 4.0 - path_snrs.append(f"{snr_val}({snr_db:.1f}dB)") - if i < len(trace_path): - path_hashes.append(f"0x{trace_path[i]:02x}") - - - parsed_data["snr"] = packet.get_snr() - parsed_data["rssi"] = getattr(packet, "rssi", 0) - formatted_response = self.trace_handler._format_trace_response(parsed_data) - - logger.info(f"[TraceHandler] {formatted_response}") - logger.info(f"[TraceHandler] Path SNRs: [{', '.join(path_snrs)}], Hashes: [{', '.join(path_hashes)}]") - - - if (packet.path_len < trace_path_len and - len(trace_path) > packet.path_len and - trace_path[packet.path_len] == self.local_hash and - self.repeater_handler and not self.repeater_handler.is_duplicate(packet)): - - if self.repeater_handler and hasattr(self.repeater_handler, 'recent_packets'): - packet_hash = packet.calculate_packet_hash().hex().upper()[:16] - for record in reversed(self.repeater_handler.recent_packets): - if record.get("packet_hash") == packet_hash: - record["transmitted"] = True - record["drop_reason"] = "trace_forwarded" - break - - current_snr = packet.get_snr() - - - snr_scaled = int(current_snr * 4) - - if snr_scaled > 127: - snr_scaled = 127 - elif snr_scaled < -128: - snr_scaled = -128 - - snr_byte = snr_scaled if snr_scaled >= 0 else (256 + snr_scaled) - - while len(packet.path) <= packet.path_len: - packet.path.append(0) - - packet.path[packet.path_len] = snr_byte - packet.path_len += 1 - - logger.info(f"[TraceHandler] Forwarding trace, stored SNR {current_snr:.1f}dB at position {packet.path_len-1}") - - # Mark as seen and forward directly (bypass normal routing, no ACK required) - self.repeater_handler.mark_seen(packet) - if self.dispatcher: - await self.dispatcher.send_packet(packet, wait_for_ack=False) - else: - # Show why we didn't forward - if packet.path_len >= trace_path_len: - logger.info(f"[TraceHandler] Trace completed (reached end of path)") - elif len(trace_path) <= packet.path_len: - logger.info(f"[TraceHandler] Path index out of bounds") - elif trace_path[packet.path_len] != self.local_hash: - expected_hash = trace_path[packet.path_len] if packet.path_len < len(trace_path) else None - logger.info(f"[TraceHandler] Not our turn (next hop: 0x{expected_hash:02x})") - elif self.repeater_handler and self.repeater_handler.is_duplicate(packet): - logger.info(f"[TraceHandler] Duplicate packet, ignoring") - - except Exception as e: - logger.error(f"[TraceHandler] Error processing trace packet: {e}") - - def _setup_discovery_handler(self): - """Set up discovery request/response handling.""" - try: - from pymc_core.node.handlers.control import ControlHandler - - self.control_handler = ControlHandler(log_fn=logger.info) - self.dispatcher.register_handler( - ControlHandler.payload_type(), - self._control_callback, - ) - - # Node type 2 = Repeater - node_type = 2 - - def on_discovery_request(request_data: dict): - """Handle incoming discovery request.""" - try: - tag = request_data.get("tag", 0) - filter_byte = request_data.get("filter", 0) - prefix_only = request_data.get("prefix_only", False) - snr = request_data.get("snr", 0.0) - rssi = request_data.get("rssi", 0) - - logger.info(f"[Discovery] Request: tag=0x{tag:08X}, filter=0x{filter_byte:02X}, SNR={snr:+.1f}dB, RSSI={rssi}dBm") - - # Check if filter matches our node type (repeater = 2, filter_mask = 0x04) - filter_mask = 1 << node_type # 1 << 2 = 0x04 - if (filter_byte & filter_mask) == 0: - logger.debug("[Discovery] Filter doesn't match, ignoring") - return - - logger.info("[Discovery] Sending response...") - - if self.local_identity: - our_pub_key = self.local_identity.get_public_key() - - from pymc_core.protocol.packet_builder import PacketBuilder - response_packet = PacketBuilder.create_discovery_response( - tag=tag, - node_type=node_type, - inbound_snr=snr, - pub_key=our_pub_key, - prefix_only=prefix_only, - ) - - # Send response asynchronously - asyncio.create_task(self._send_discovery_response(response_packet, tag)) - else: - logger.warning("[Discovery] No local identity available for response") - - except Exception as e: - logger.error(f"[Discovery] Error handling request: {e}") - - self.control_handler.set_request_callback(on_discovery_request) - logger.debug("[Discovery] Handler registered") - - except Exception as e: - logger.error(f"Failed to setup discovery handler: {e}") - - async def _control_callback(self, packet): - if self.control_handler: - await self.control_handler(packet) - - async def _send_discovery_response(self, packet, tag): - try: - success = await self.dispatcher.send_packet(packet, wait_for_ack=False) - if success: - logger.info(f"[Discovery] Response sent for tag 0x{tag:08X}") - else: - logger.warning(f"[Discovery] Failed to send response for tag 0x{tag:08X}") - except Exception as e: - logger.error(f"[Discovery] Error sending response: {e}") - - + async def _pipeline_callback(self, packet): + """ + Single entry point for ALL packets. + Enqueues packets for pipeline processing. + """ + if self.pipeline: + await self.pipeline.enqueue(packet) def get_stats(self) -> dict: - + stats = {} + if self.repeater_handler: stats = self.repeater_handler.get_stats() # Add public key if available @@ -374,8 +169,12 @@ class RepeaterDaemon: stats["public_key"] = pubkey.hex() except Exception: stats["public_key"] = None - return stats - return {} + + # Add pipeline statistics + if self.pipeline: + stats["pipeline"] = self.pipeline.get_stats() + + return stats async def send_advert(self) -> bool: @@ -468,6 +267,8 @@ class RepeaterDaemon: await self.dispatcher.run_forever() except KeyboardInterrupt: logger.info("Shutting down...") + if self.pipeline: + await self.pipeline.stop() if self.http_server: self.http_server.stop() diff --git a/repeater/packet_pipeline.py b/repeater/packet_pipeline.py new file mode 100644 index 0000000..dce175b --- /dev/null +++ b/repeater/packet_pipeline.py @@ -0,0 +1,211 @@ +""" +Packet processing pipeline for pyMC Repeater. + +This module provides a queue-based pipeline that processes packets through handlers +sequentially, tracks statistics, and ensures all packets flow through repeater logic. +""" + +import asyncio +import logging +import time +from collections import deque + +from pymc_core.node.handlers.trace import TraceHandler +from pymc_core.node.handlers.control import ControlHandler +from pymc_core.node.handlers.advert import AdvertHandler +from pymc_core.protocol.utils import get_packet_type_name + +logger = logging.getLogger("PacketPipeline") + + +class PacketPipeline: + """ + Pipeline that processes packets through handlers sequentially. + Tracks queue statistics and ensures all packets flow through repeater logic. + """ + + def __init__(self, daemon_instance): + self.daemon = daemon_instance + self.queue = asyncio.Queue() + self.running = False + self.pipeline_task = None + + # Statistics tracking + self.stats = { + "total_enqueued": 0, + "total_processed": 0, + "total_errors": 0, + "current_queue_size": 0, + "max_queue_size": 0, + "processing_times": deque(maxlen=100), # Last 100 processing times + "packets_by_type": {}, + "packets_marked_no_retransmit": 0, + "packets_forwarded": 0, + } + self.last_stats_log = time.time() + + async def start(self): + """Start the pipeline processing task.""" + self.running = True + self.pipeline_task = asyncio.create_task(self._process_pipeline()) + logger.info("Packet pipeline started") + + async def stop(self): + """Stop the pipeline processing task.""" + self.running = False + if self.pipeline_task: + self.pipeline_task.cancel() + try: + await self.pipeline_task + except asyncio.CancelledError: + pass + logger.info("Packet pipeline stopped") + self._log_final_stats() + + async def enqueue(self, packet): + """Add packet to pipeline queue and track statistics.""" + await self.queue.put(packet) + self.stats["total_enqueued"] += 1 + self.stats["current_queue_size"] = self.queue.qsize() + + # Track max queue size + if self.stats["current_queue_size"] > self.stats["max_queue_size"]: + self.stats["max_queue_size"] = self.stats["current_queue_size"] + + # Log stats periodically (every 30 seconds) + now = time.time() + if now - self.last_stats_log > 30: + self._log_stats() + self.last_stats_log = now + + async def _process_pipeline(self): + """Process packets through the pipeline.""" + while self.running: + try: + packet = await asyncio.wait_for(self.queue.get(), timeout=0.1) + + start_time = time.time() + await self._process_packet(packet) + processing_time = (time.time() - start_time) * 1000 # ms + + self.stats["total_processed"] += 1 + self.stats["current_queue_size"] = self.queue.qsize() + self.stats["processing_times"].append(processing_time) + + except asyncio.TimeoutError: + continue + except Exception as e: + self.stats["total_errors"] += 1 + logger.error(f"Pipeline error: {e}", exc_info=True) + + async def _process_packet(self, packet): + """ + Process a single packet through the handler pipeline. + + Flow: + 1. Route to specific handler based on payload type + 2. Handler processes and may mark do_not_retransmit + 3. If not marked, pass to repeater for forwarding + """ + payload_type = packet.get_payload_type() + + # Track packet type + type_name = get_packet_type_name(payload_type) + self.stats["packets_by_type"][type_name] = self.stats["packets_by_type"].get(type_name, 0) + 1 + + # Stage 1: Route to specific handlers + if payload_type == TraceHandler.payload_type(): + # Process trace packet + if self.daemon.trace_helper: + await self.daemon.trace_helper.process_trace_packet(packet) + + elif payload_type == ControlHandler.payload_type(): + # Process control/discovery packet + if self.daemon.discovery_helper: + await self.daemon.discovery_helper.control_handler(packet) + packet.mark_do_not_retransmit() + + elif payload_type == AdvertHandler.payload_type(): + # Process advertisement packet for neighbor tracking + if self.daemon.advert_helper: + # Extract metadata for advert processing + rssi = getattr(packet, "rssi", 0) + snr = getattr(packet, "snr", 0.0) + await self.daemon.advert_helper.process_advert_packet(packet, rssi, snr) + + if self.daemon.repeater_handler: + metadata = { + "rssi": getattr(packet, "rssi", 0), + "snr": getattr(packet, "snr", 0.0), + "timestamp": getattr(packet, "timestamp", 0), + } + + # Call process_packet to get validation result and delay + snr = metadata.get("snr", 0.0) + result = self.daemon.repeater_handler.process_packet(packet, snr) + + if result: + fwd_pkt, delay = result + + # Calculate airtime for duty cycle tracking + from pymc_core.protocol.packet_utils import PacketTimingUtils + packet_bytes = fwd_pkt.write_to() if hasattr(fwd_pkt, "write_to") else fwd_pkt.payload or b"" + airtime_ms = PacketTimingUtils.estimate_airtime_ms( + len(packet_bytes), + self.daemon.repeater_handler.radio_config + ) + + # Check duty cycle + can_tx, wait_time = self.daemon.repeater_handler.airtime_mgr.can_transmit(airtime_ms) + + if can_tx: + # Schedule transmission with calculated delay + await self.daemon.repeater_handler.schedule_retransmit(fwd_pkt, delay, airtime_ms) + self.stats["packets_forwarded"] += 1 + logger.debug(f"Packet scheduled for forwarding with {delay:.3f}s delay") + else: + logger.warning( + f"Duty cycle limit exceeded. Airtime={airtime_ms:.1f}ms, " + f"wait={wait_time:.1f}s before retry" + ) + else: + logger.debug(f"Packet rejected by repeater handler: {self.daemon.repeater_handler._last_drop_reason}") + + + def _log_stats(self): + """Log pipeline statistics.""" + avg_processing_time = 0 + if self.stats["processing_times"]: + avg_processing_time = sum(self.stats["processing_times"]) / len(self.stats["processing_times"]) + + logger.info( + f"[Pipeline Stats] Enqueued: {self.stats['total_enqueued']}, " + f"Processed: {self.stats['total_processed']}, " + f"Errors: {self.stats['total_errors']}, " + f"Queue: {self.stats['current_queue_size']}/{self.stats['max_queue_size']} (current/max), " + f"Avg Time: {avg_processing_time:.2f}ms, " + f"Forwarded: {self.stats['packets_forwarded']}, " + f"Marked NoRetx: {self.stats['packets_marked_no_retransmit']}" + ) + + # Log packet type breakdown + if self.stats["packets_by_type"]: + type_breakdown = ", ".join([f"{k}: {v}" for k, v in sorted(self.stats["packets_by_type"].items())]) + logger.debug(f"[Pipeline Types] {type_breakdown}") + + def _log_final_stats(self): + """Log final statistics on shutdown.""" + logger.info("=== Final Pipeline Statistics ===") + self._log_stats() + logger.info("================================") + + def get_stats(self): + """Return current pipeline statistics.""" + stats_copy = self.stats.copy() + if self.stats["processing_times"]: + stats_copy["avg_processing_time_ms"] = sum(self.stats["processing_times"]) / len(self.stats["processing_times"]) + else: + stats_copy["avg_processing_time_ms"] = 0 + # Don't include the deque in the return value + del stats_copy["processing_times"] + return stats_copy