diff --git a/app/event_handlers.py b/app/event_handlers.py index e0cd6e6..5376b2c 100644 --- a/app/event_handlers.py +++ b/app/event_handlers.py @@ -106,13 +106,16 @@ async def on_contact_message(event: "Event") -> None: ts = payload.get("sender_timestamp") sender_timestamp = ts if ts is not None else received_at sender_name = contact.name if contact else None + path = payload.get("path") + path_len = payload.get("path_len") msg_id = await MessageRepository.create( msg_type="PRIV", text=payload.get("text", ""), conversation_key=sender_pubkey, sender_timestamp=sender_timestamp, received_at=received_at, - path=payload.get("path"), + path=path, + path_len=path_len, txt_type=txt_type, signature=payload.get("signature"), sender_key=sender_pubkey, @@ -129,8 +132,7 @@ async def on_contact_message(event: "Event") -> None: logger.debug("DM from %s handled by event handler (fallback path)", sender_pubkey[:12]) # Build paths array for broadcast - path = payload.get("path") - paths = [MessagePath(path=path or "", received_at=received_at)] if path is not None else None + paths = [MessagePath(path=path or "", received_at=received_at, path_len=path_len)] if path is not None else None # Broadcast the new message broadcast_event( diff --git a/app/models.py b/app/models.py index 2fb62a8..b0f85ad 100644 --- a/app/models.py +++ b/app/models.py @@ -176,8 +176,12 @@ class ChannelDetail(BaseModel): class MessagePath(BaseModel): """A single path that a message took to reach us.""" - path: str = Field(description="Hex-encoded routing path (2 chars per hop)") + path: str = Field(description="Hex-encoded routing path") received_at: int = Field(description="Unix timestamp when this path was received") + path_len: int | None = Field( + default=None, + description="Hop count. None = legacy (infer as len(path)//2, i.e. 1-byte hops)", + ) class Message(BaseModel): diff --git a/app/packet_processor.py b/app/packet_processor.py index 649bdbe..0ea7feb 100644 --- a/app/packet_processor.py +++ b/app/packet_processor.py @@ -58,6 +58,7 @@ async def _handle_duplicate_message( sender_timestamp: int, path: str | None, received: int, + path_len: int | None = None, ) -> None: """Handle a duplicate message by updating paths/acks on the existing record. @@ -90,7 +91,7 @@ async def _handle_duplicate_message( # Add path if provided if path is not None: - paths = await MessageRepository.add_path(existing_msg.id, path, received) + paths = await MessageRepository.add_path(existing_msg.id, path, received, path_len) else: # Get current paths for broadcast paths = existing_msg.paths or [] @@ -128,6 +129,7 @@ async def create_message_from_decrypted( timestamp: int, received_at: int | None = None, path: str | None = None, + path_len: int | None = None, channel_name: str | None = None, realtime: bool = True, ) -> int | None: @@ -172,6 +174,7 @@ async def create_message_from_decrypted( sender_timestamp=timestamp, received_at=received, path=path, + path_len=path_len, sender_name=sender, sender_key=resolved_sender_key, ) @@ -182,7 +185,7 @@ async def create_message_from_decrypted( # 2. Same message arrives via multiple paths before first is committed # In either case, add the path to the existing message. await _handle_duplicate_message( - packet_id, "CHAN", channel_key_normalized, text, timestamp, path, received + packet_id, "CHAN", channel_key_normalized, text, timestamp, path, received, path_len ) return None @@ -193,7 +196,7 @@ async def create_message_from_decrypted( # Build paths array for broadcast # Use "is not None" to include empty string (direct/0-hop messages) - paths = [MessagePath(path=path or "", received_at=received)] if path is not None else None + paths = [MessagePath(path=path or "", received_at=received, path_len=path_len)] if path is not None else None # Broadcast new message to connected clients (and fanout modules when realtime) broadcast_event( @@ -223,6 +226,7 @@ async def create_dm_message_from_decrypted( our_public_key: str | None, received_at: int | None = None, path: str | None = None, + path_len: int | None = None, outgoing: bool = False, realtime: bool = True, ) -> int | None: @@ -270,6 +274,7 @@ async def create_dm_message_from_decrypted( sender_timestamp=decrypted.timestamp, received_at=received, path=path, + path_len=path_len, outgoing=outgoing, sender_key=conversation_key if not outgoing else None, sender_name=sender_name, @@ -285,6 +290,7 @@ async def create_dm_message_from_decrypted( decrypted.timestamp, path, received, + path_len, ) return None @@ -299,7 +305,7 @@ async def create_dm_message_from_decrypted( await RawPacketRepository.mark_decrypted(packet_id, msg_id) # Build paths array for broadcast - paths = [MessagePath(path=path or "", received_at=received)] if path is not None else None + paths = [MessagePath(path=path or "", received_at=received, path_len=path_len)] if path is not None else None # Broadcast new message to connected clients (and fanout modules when realtime) sender_name = contact.name if contact and not outgoing else None @@ -383,6 +389,7 @@ async def run_historical_dm_decryption( # Extract path from the raw packet for storage packet_info = parse_packet(packet_data) path_hex = packet_info.path.hex() if packet_info else None + path_len = packet_info.path_length if packet_info else None msg_id = await create_dm_message_from_decrypted( packet_id=packet_id, @@ -391,6 +398,7 @@ async def run_historical_dm_decryption( our_public_key=our_public_key_bytes.hex(), received_at=packet_timestamp, path=path_hex, + path_len=path_len, outgoing=outgoing, realtime=False, # Historical decryption should not trigger fanout ) @@ -606,6 +614,7 @@ async def _process_group_text( timestamp=decrypted.timestamp, received_at=timestamp, path=packet_info.path.hex() if packet_info else None, + path_len=packet_info.path_length if packet_info else None, ) return { @@ -872,6 +881,7 @@ async def _process_direct_message( our_public_key=our_public_key.hex(), received_at=timestamp, path=packet_info.path.hex() if packet_info else None, + path_len=packet_info.path_length if packet_info else None, outgoing=is_outgoing, ) diff --git a/app/repository/messages.py b/app/repository/messages.py index d3cb977..d0c88d4 100644 --- a/app/repository/messages.py +++ b/app/repository/messages.py @@ -26,6 +26,7 @@ class MessageRepository: conversation_key: str, sender_timestamp: int | None = None, path: str | None = None, + path_len: int | None = None, txt_type: int = 0, signature: str | None = None, outgoing: bool = False, @@ -43,7 +44,10 @@ class MessageRepository: # Convert single path to paths array format paths_json = None if path is not None: - paths_json = json.dumps([{"path": path, "received_at": received_at}]) + entry: dict = {"path": path, "received_at": received_at} + if path_len is not None: + entry["path_len"] = path_len + paths_json = json.dumps([entry]) cursor = await db.conn.execute( """ @@ -74,7 +78,10 @@ class MessageRepository: @staticmethod async def add_path( - message_id: int, path: str, received_at: int | None = None + message_id: int, + path: str, + received_at: int | None = None, + path_len: int | None = None, ) -> list[MessagePath]: """Add a new path to an existing message. @@ -85,7 +92,10 @@ class MessageRepository: # Atomic append: use json_insert to avoid read-modify-write race when # multiple duplicate packets arrive concurrently for the same message. - new_entry = json.dumps({"path": path, "received_at": ts}) + entry: dict = {"path": path, "received_at": ts} + if path_len is not None: + entry["path_len"] = path_len + new_entry = json.dumps(entry) await db.conn.execute( """UPDATE messages SET paths = json_insert( COALESCE(paths, '[]'), '$[#]', json(?) diff --git a/tests/test_echo_dedup.py b/tests/test_echo_dedup.py index 7c94fd3..d4ff0f6 100644 --- a/tests/test_echo_dedup.py +++ b/tests/test_echo_dedup.py @@ -586,6 +586,7 @@ class TestDirectMessageDirectionDetection: packet_info = MagicMock() packet_info.payload = bytes([0xFA, 0xA1, 0x00, 0x00]) + b"\x00" * 20 packet_info.path = b"" + packet_info.path_length = 0 # Create the contact so decryption can find a candidate await ContactRepository.upsert( @@ -637,6 +638,7 @@ class TestDirectMessageDirectionDetection: packet_info = MagicMock() packet_info.payload = bytes([0xFA, 0xA1, 0x00, 0x00]) + b"\x00" * 20 packet_info.path = b"" + packet_info.path_length = 0 await ContactRepository.upsert( { @@ -682,7 +684,7 @@ class TestDirectMessageDirectionDetection: message_broadcasts = [b for b in broadcasts if b["type"] == "message"] assert len(message_broadcasts) == 1 assert message_broadcasts[0]["data"]["paths"] == [ - {"path": "", "received_at": SENDER_TIMESTAMP} + {"path": "", "received_at": SENDER_TIMESTAMP, "path_len": 0} ] @pytest.mark.asyncio @@ -694,6 +696,7 @@ class TestDirectMessageDirectionDetection: # dest_hash=a1 (contact), src_hash=fa (us) packet_info.payload = bytes([0xA1, 0xFA, 0x00, 0x00]) + b"\x00" * 20 packet_info.path = b"" + packet_info.path_length = 0 await ContactRepository.upsert( { @@ -743,6 +746,7 @@ class TestDirectMessageDirectionDetection: # Both dest_hash and src_hash are 0xFA (our first byte) packet_info.payload = bytes([0xFA, 0xFA, 0x00, 0x00]) + b"\x00" * 20 packet_info.path = b"" + packet_info.path_length = 0 # Contact whose first byte also starts with "fa" await ContactRepository.upsert( @@ -793,6 +797,7 @@ class TestDirectMessageDirectionDetection: # Neither byte matches our first byte (0xFA) packet_info.payload = bytes([0x11, 0x22, 0x00, 0x00]) + b"\x00" * 20 packet_info.path = b"" + packet_info.path_length = 0 pkt_id, _ = await RawPacketRepository.create(b"dir_test_none", SENDER_TIMESTAMP)