diff --git a/tests/test_api.py b/tests/test_api.py index 4ed58e7..9211d81 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -670,6 +670,163 @@ class TestReadStateEndpoints: db._connection = original_conn await conn.close() + @pytest.mark.asyncio + async def test_unreads_reset_after_mark_read(self): + """Marking a conversation as read zeroes its unread count; new messages after count again.""" + import aiosqlite + + from app.database import db + from app.repository import MessageRepository + + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + + await conn.execute(""" + CREATE TABLE channels ( + key TEXT PRIMARY KEY, name TEXT NOT NULL, + is_hashtag INTEGER DEFAULT 0, on_radio INTEGER DEFAULT 0, last_read_at INTEGER + ) + """) + await conn.execute(""" + CREATE TABLE contacts ( + public_key TEXT PRIMARY KEY, name TEXT, + type INTEGER DEFAULT 0, flags INTEGER DEFAULT 0, + last_path TEXT, last_path_len INTEGER DEFAULT -1, + last_advert INTEGER, lat REAL, lon REAL, last_seen INTEGER, + on_radio INTEGER DEFAULT 0, last_contacted INTEGER, last_read_at INTEGER + ) + """) + await conn.execute(""" + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, type TEXT NOT NULL, + conversation_key TEXT NOT NULL, text TEXT NOT NULL, + sender_timestamp INTEGER, received_at INTEGER NOT NULL, + paths TEXT, txt_type INTEGER DEFAULT 0, signature TEXT, + outgoing INTEGER DEFAULT 0, acked INTEGER DEFAULT 0, + UNIQUE(type, conversation_key, text, sender_timestamp) + ) + """) + + chan_key = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1" + await conn.execute( + "INSERT INTO channels (key, name, last_read_at) VALUES (?, ?, ?)", + (chan_key, "Public", 1000), + ) + # 2 unread messages (received_at > last_read_at=1000) + await conn.execute( + "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", + ("CHAN", chan_key, "msg1", 1001, 0), + ) + await conn.execute( + "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", + ("CHAN", chan_key, "msg2", 1002, 0), + ) + await conn.commit() + + original_conn = db._connection + db._connection = conn + + try: + # Verify 2 unread + result = await MessageRepository.get_unread_counts(None) + assert result["counts"][f"channel-{chan_key}"] == 2 + + # Simulate mark-read by updating last_read_at to after all messages + await conn.execute( + "UPDATE channels SET last_read_at = ? WHERE key = ?", (1002, chan_key) + ) + await conn.commit() + + # Verify 0 unread + result = await MessageRepository.get_unread_counts(None) + assert result["counts"].get(f"channel-{chan_key}", 0) == 0 + + # New message arrives after the read point + await conn.execute( + "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", + ("CHAN", chan_key, "msg3", 1003, 0), + ) + await conn.commit() + + # Verify exactly 1 unread + result = await MessageRepository.get_unread_counts(None) + assert result["counts"][f"channel-{chan_key}"] == 1 + finally: + db._connection = original_conn + await conn.close() + + @pytest.mark.asyncio + async def test_unreads_exclude_outgoing_messages(self): + """Outgoing messages should never count as unread, even when received_at > last_read_at. + + This is critical: without the outgoing filter, every message we send would + show as an unread badge in the sidebar. + """ + import aiosqlite + + from app.database import db + from app.repository import MessageRepository + + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + + await conn.execute(""" + CREATE TABLE channels ( + key TEXT PRIMARY KEY, name TEXT NOT NULL, + is_hashtag INTEGER DEFAULT 0, on_radio INTEGER DEFAULT 0, last_read_at INTEGER + ) + """) + await conn.execute(""" + CREATE TABLE contacts ( + public_key TEXT PRIMARY KEY, name TEXT, + type INTEGER DEFAULT 0, flags INTEGER DEFAULT 0, + last_path TEXT, last_path_len INTEGER DEFAULT -1, + last_advert INTEGER, lat REAL, lon REAL, last_seen INTEGER, + on_radio INTEGER DEFAULT 0, last_contacted INTEGER, last_read_at INTEGER + ) + """) + await conn.execute(""" + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, type TEXT NOT NULL, + conversation_key TEXT NOT NULL, text TEXT NOT NULL, + sender_timestamp INTEGER, received_at INTEGER NOT NULL, + paths TEXT, txt_type INTEGER DEFAULT 0, signature TEXT, + outgoing INTEGER DEFAULT 0, acked INTEGER DEFAULT 0, + UNIQUE(type, conversation_key, text, sender_timestamp) + ) + """) + + contact_key = "abcd" * 16 + await conn.execute( + "INSERT INTO contacts (public_key, name, last_read_at) VALUES (?, ?, ?)", + (contact_key, "Bob", 1000), + ) + # 1 incoming (should count) + 2 outgoing (should NOT count) + await conn.execute( + "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", + ("PRIV", contact_key, "incoming msg", 1001, 0), + ) + await conn.execute( + "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", + ("PRIV", contact_key, "my reply", 1002, 1), + ) + await conn.execute( + "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", + ("PRIV", contact_key, "another reply", 1003, 1), + ) + await conn.commit() + + original_conn = db._connection + db._connection = conn + + try: + result = await MessageRepository.get_unread_counts(None) + # Only the 1 incoming message should count as unread + assert result["counts"][f"contact-{contact_key}"] == 1 + finally: + db._connection = original_conn + await conn.close() + @pytest.mark.asyncio async def test_mark_all_read_updates_all_conversations(self): """Bulk mark-all-read updates all contacts and channels.""" diff --git a/tests/test_packet_pipeline.py b/tests/test_packet_pipeline.py index 7b72af1..498f45f 100644 --- a/tests/test_packet_pipeline.py +++ b/tests/test_packet_pipeline.py @@ -386,6 +386,58 @@ class TestAdvertisementPipeline: contact = await ContactRepository.get_by_key(test_pubkey) assert contact.last_path_len == 1 # Still the shorter path + @pytest.mark.asyncio + async def test_advertisement_replaces_stale_path_outside_window( + self, test_db, captured_broadcasts + ): + """When existing path is stale (>60s), a new longer path should replace it. + + In a mesh network, a stale short path may no longer be valid (node moved, repeater + went offline). Accepting the new longer path ensures we have a working route. + """ + from app.packet_processor import _process_advertisement + + test_pubkey = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + await ContactRepository.upsert( + { + "public_key": test_pubkey, + "name": "TestNode", + "type": 1, + "last_seen": 1000, + "last_path_len": 1, # Short path + "last_path": "aa", + } + ) + + from unittest.mock import MagicMock + + from app.decoder import ParsedAdvertisement + + broadcasts, mock_broadcast = captured_broadcasts + + # New longer path arriving AFTER 60s window (timestamp 1000 + 61 = 1061) + long_packet_info = MagicMock() + long_packet_info.path_length = 4 + long_packet_info.path = bytes.fromhex("aabbccdd") + long_packet_info.payload = b"" + + with patch("app.packet_processor.broadcast_event", mock_broadcast): + with patch("app.packet_processor.parse_advertisement") as mock_parse: + mock_parse.return_value = ParsedAdvertisement( + public_key=test_pubkey, + name="TestNode", + timestamp=1061, + lat=None, + lon=None, + device_role=1, + ) + await _process_advertisement(b"", timestamp=1061, packet_info=long_packet_info) + + # Verify the longer path replaced the stale shorter one + contact = await ContactRepository.get_by_key(test_pubkey) + assert contact.last_path_len == 4 + assert contact.last_path == "aabbccdd" + class TestAckPipeline: """Test ACK flow: outgoing message → ACK received → broadcast update."""