From 7715732e69a6a13878ae70c263d040e70c5f40ef Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Thu, 5 Mar 2026 10:43:16 -0800 Subject: [PATCH] Add sender_key to outgoing and make unread counts respect block list --- app/repository/messages.py | 42 +++++++++-- app/routers/messages.py | 14 ++++ app/routers/read_state.py | 14 +++- tests/test_block_lists.py | 144 +++++++++++++++++++++++++++++++++++- tests/test_send_messages.py | 38 ++++++++++ 5 files changed, 244 insertions(+), 8 deletions(-) diff --git a/app/repository/messages.py b/app/repository/messages.py index 31364ea..59deb47 100644 --- a/app/repository/messages.py +++ b/app/repository/messages.py @@ -384,11 +384,17 @@ class MessageRepository: return MessageRepository._row_to_message(row) @staticmethod - async def get_unread_counts(name: str | None = None) -> dict: + async def get_unread_counts( + name: str | None = None, + blocked_keys: list[str] | None = None, + blocked_names: list[str] | None = None, + ) -> dict: """Get unread message counts, mention flags, and last message times for all conversations. Args: name: User's display name for @[name] mention detection. If None, mentions are skipped. + blocked_keys: Public keys whose messages should be excluded from counts. + blocked_names: Display names whose messages should be excluded from counts. Returns: Dict with 'counts', 'mentions', and 'last_message_times' keys. @@ -399,9 +405,25 @@ class MessageRepository: mention_token = f"@[{name}]" if name else None + # Build optional block-list WHERE fragments for channel messages + chan_block_sql = "" + chan_block_params: list[Any] = [] + if blocked_keys: + placeholders = ",".join("?" for _ in blocked_keys) + chan_block_sql += ( + f" AND NOT (m.sender_key IS NOT NULL AND LOWER(m.sender_key) IN ({placeholders}))" + ) + chan_block_params.extend(blocked_keys) + if blocked_names: + placeholders = ",".join("?" for _ in blocked_names) + chan_block_sql += ( + f" AND NOT (m.sender_name IS NOT NULL AND m.sender_name IN ({placeholders}))" + ) + chan_block_params.extend(blocked_names) + # Channel unreads cursor = await db.conn.execute( - """ + f""" SELECT m.conversation_key, COUNT(*) as unread_count, SUM(CASE @@ -412,9 +434,10 @@ class MessageRepository: JOIN channels c ON m.conversation_key = c.key WHERE m.type = 'CHAN' AND m.outgoing = 0 AND m.received_at > COALESCE(c.last_read_at, 0) + {chan_block_sql} GROUP BY m.conversation_key """, - (mention_token or "", mention_token or ""), + (mention_token or "", mention_token or "", *chan_block_params), ) rows = await cursor.fetchall() for row in rows: @@ -423,9 +446,17 @@ class MessageRepository: if mention_token and row["has_mention"]: mention_flags[state_key] = True + # Build block-list exclusion for contact (DM) unreads + contact_block_sql = "" + contact_block_params: list[Any] = [] + if blocked_keys: + placeholders = ",".join("?" for _ in blocked_keys) + contact_block_sql += f" AND LOWER(m.conversation_key) NOT IN ({placeholders})" + contact_block_params.extend(blocked_keys) + # Contact unreads cursor = await db.conn.execute( - """ + f""" SELECT m.conversation_key, COUNT(*) as unread_count, SUM(CASE @@ -436,9 +467,10 @@ class MessageRepository: JOIN contacts ct ON m.conversation_key = ct.public_key WHERE m.type = 'PRIV' AND m.outgoing = 0 AND m.received_at > COALESCE(ct.last_read_at, 0) + {contact_block_sql} GROUP BY m.conversation_key """, - (mention_token or "", mention_token or ""), + (mention_token or "", mention_token or "", *contact_block_params), ) rows = await cursor.fetchall() for row in rows: diff --git a/app/routers/messages.py b/app/routers/messages.py index 776ad1c..757b0ba 100644 --- a/app/routers/messages.py +++ b/app/routers/messages.py @@ -239,8 +239,11 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message: radio_name: str = "" text_with_sender: str = request.text + our_public_key: str | None = None + async with radio_manager.radio_operation("send_channel_message") as mc: radio_name = mc.self_info.get("name", "") if mc.self_info else "" + our_public_key = (mc.self_info.get("public_key") or None) if mc.self_info else None text_with_sender = f"{radio_name}: {request.text}" if radio_name else request.text # Load the channel to a temporary radio slot before sending set_result = await mc.commands.set_channel( @@ -286,6 +289,7 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message: received_at=now, outgoing=True, sender_name=radio_name or None, + sender_key=our_public_key, ) if message_id is None: raise HTTPException( @@ -307,6 +311,8 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message: received_at=now, outgoing=True, acked=0, + sender_name=radio_name or None, + sender_key=our_public_key, ).model_dump(), ) @@ -325,6 +331,8 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message: outgoing=True, acked=acked_count, paths=paths, + sender_name=radio_name or None, + sender_key=our_public_key, ) # Trigger bots for outgoing channel messages (runs in background, doesn't block response) @@ -404,9 +412,12 @@ async def resend_channel_message( status_code=400, detail=f"Invalid channel key format: {msg.conversation_key}" ) from None + resend_public_key: str | None = None + async with radio_manager.radio_operation("resend_channel_message") as mc: # Strip sender prefix: DB stores "RadioName: message" but radio needs "message" radio_name = mc.self_info.get("name", "") if mc.self_info else "" + resend_public_key = (mc.self_info.get("public_key") or None) if mc.self_info else None text_to_send = msg.text if radio_name and text_to_send.startswith(f"{radio_name}: "): text_to_send = text_to_send[len(f"{radio_name}: ") :] @@ -442,6 +453,7 @@ async def resend_channel_message( received_at=now, outgoing=True, sender_name=radio_name or None, + sender_key=resend_public_key, ) if new_msg_id is None: # Timestamp-second collision (same text+channel within the same second). @@ -464,6 +476,8 @@ async def resend_channel_message( received_at=now, outgoing=True, acked=0, + sender_name=radio_name or None, + sender_key=resend_public_key, ).model_dump(), ) diff --git a/app/routers/read_state.py b/app/routers/read_state.py index e0b3066..28af10c 100644 --- a/app/routers/read_state.py +++ b/app/routers/read_state.py @@ -7,7 +7,12 @@ from fastapi import APIRouter from app.models import UnreadCounts from app.radio import radio_manager -from app.repository import ChannelRepository, ContactRepository, MessageRepository +from app.repository import ( + AppSettingsRepository, + ChannelRepository, + ContactRepository, + MessageRepository, +) logger = logging.getLogger(__name__) router = APIRouter(prefix="/read-state", tags=["read-state"]) @@ -26,7 +31,12 @@ async def get_unreads() -> UnreadCounts: mc = radio_manager.meshcore if mc and mc.self_info: name = mc.self_info.get("name") or None - data = await MessageRepository.get_unread_counts(name) + settings = await AppSettingsRepository.get() + blocked_keys = settings.blocked_keys or None + blocked_names = settings.blocked_names or None + data = await MessageRepository.get_unread_counts( + name, blocked_keys=blocked_keys, blocked_names=blocked_names + ) return UnreadCounts(**data) diff --git a/tests/test_block_lists.py b/tests/test_block_lists.py index 3931711..7968a02 100644 --- a/tests/test_block_lists.py +++ b/tests/test_block_lists.py @@ -4,7 +4,12 @@ import time import pytest -from app.repository import AppSettingsRepository, MessageRepository +from app.repository import ( + AppSettingsRepository, + ChannelRepository, + ContactRepository, + MessageRepository, +) from app.routers.settings import ( BlockKeyRequest, BlockNameRequest, @@ -168,3 +173,140 @@ class TestMessageBlockFiltering: assert "blocked dm" not in texts assert "normal dm" in texts assert "outgoing to blocked" in texts + + +class TestUnreadCountsBlockFiltering: + """Unread counts should exclude messages from blocked keys/names.""" + + @pytest.mark.asyncio + async def test_unread_counts_exclude_blocked_key_dms(self, test_db): + """Blocked key DMs should not contribute to unread counts.""" + blocked_key = "aa" * 32 + normal_key = "bb" * 32 + now = int(time.time()) + + # Set up contacts with last_read_at in the past + await ContactRepository.upsert({"public_key": blocked_key, "name": "Blocked"}) + await ContactRepository.upsert({"public_key": normal_key, "name": "Normal"}) + + # Incoming DMs + await MessageRepository.create( + msg_type="PRIV", + text="blocked msg", + received_at=now, + conversation_key=blocked_key, + sender_timestamp=now, + ) + await MessageRepository.create( + msg_type="PRIV", + text="normal msg", + received_at=now + 1, + conversation_key=normal_key, + sender_timestamp=now + 1, + ) + + result = await MessageRepository.get_unread_counts( + blocked_keys=[blocked_key], + ) + assert f"contact-{blocked_key}" not in result["counts"] + assert result["counts"][f"contact-{normal_key}"] == 1 + + @pytest.mark.asyncio + async def test_unread_counts_exclude_blocked_key_channel_msgs(self, test_db): + """Blocked key channel messages should not contribute to unread counts.""" + blocked_key = "aa" * 32 + normal_key = "bb" * 32 + chan_key = "CC" * 16 + now = int(time.time()) + + await ChannelRepository.upsert(key=chan_key, name="#test") + await ChannelRepository.update_last_read_at(chan_key, 0) + + await MessageRepository.create( + msg_type="CHAN", + text="Blocked: spam", + received_at=now, + conversation_key=chan_key, + sender_timestamp=now, + sender_name="Blocked", + sender_key=blocked_key, + ) + await MessageRepository.create( + msg_type="CHAN", + text="Normal: hi", + received_at=now + 1, + conversation_key=chan_key, + sender_timestamp=now + 1, + sender_name="Normal", + sender_key=normal_key, + ) + + result = await MessageRepository.get_unread_counts( + blocked_keys=[blocked_key], + ) + assert result["counts"][f"channel-{chan_key}"] == 1 + + @pytest.mark.asyncio + async def test_unread_counts_exclude_blocked_name_channel_msgs(self, test_db): + """Blocked name channel messages should not contribute to unread counts.""" + chan_key = "DD" * 16 + now = int(time.time()) + + await ChannelRepository.upsert(key=chan_key, name="#test2") + await ChannelRepository.update_last_read_at(chan_key, 0) + + await MessageRepository.create( + msg_type="CHAN", + text="Spammer: buy stuff", + received_at=now, + conversation_key=chan_key, + sender_timestamp=now, + sender_name="Spammer", + sender_key="ee" * 32, + ) + await MessageRepository.create( + msg_type="CHAN", + text="Friend: hello", + received_at=now + 1, + conversation_key=chan_key, + sender_timestamp=now + 1, + sender_name="Friend", + sender_key="ff" * 32, + ) + + result = await MessageRepository.get_unread_counts( + blocked_names=["Spammer"], + ) + assert result["counts"][f"channel-{chan_key}"] == 1 + + @pytest.mark.asyncio + async def test_unread_counts_no_block_lists_returns_all(self, test_db): + """Without block lists, all messages count toward unreads.""" + blocked_key = "aa" * 32 + chan_key = "CC" * 16 + now = int(time.time()) + + await ContactRepository.upsert({"public_key": blocked_key, "name": "Someone"}) + await ChannelRepository.upsert(key=chan_key, name="#all") + await ChannelRepository.update_last_read_at(chan_key, 0) + + await MessageRepository.create( + msg_type="PRIV", + text="dm", + received_at=now, + conversation_key=blocked_key, + sender_timestamp=now, + ) + await MessageRepository.create( + msg_type="CHAN", + text="Someone: hi", + received_at=now + 1, + conversation_key=chan_key, + sender_timestamp=now + 1, + sender_name="Someone", + sender_key=blocked_key, + ) + + result = await MessageRepository.get_unread_counts() + assert result["counts"][f"contact-{blocked_key}"] == 1 + assert result["counts"][f"channel-{chan_key}"] == 1 diff --git a/tests/test_send_messages.py b/tests/test_send_messages.py index 8591d8c..dabcf7c 100644 --- a/tests/test_send_messages.py +++ b/tests/test_send_messages.py @@ -259,6 +259,44 @@ class TestOutgoingChannelBotTrigger: assert message.id is not None assert message.acked == 0 + @pytest.mark.asyncio + async def test_send_channel_msg_includes_sender_key(self, test_db): + """Outgoing channel message includes our public key as sender_key.""" + our_pubkey = "ab" * 32 + mc = _make_mc(name="MyNode") + mc.self_info["public_key"] = our_pubkey + chan_key = "ee" * 16 + await ChannelRepository.upsert(key=chan_key, name="#test") + + broadcasts = [] + + def capture_broadcast(event_type, data): + broadcasts.append({"type": event_type, "data": data}) + + with ( + patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + patch("app.decoder.calculate_channel_hash", return_value="abcd"), + patch("app.bot.run_bot_for_message", new=AsyncMock()), + patch("app.routers.messages.broadcast_event", side_effect=capture_broadcast), + ): + request = SendChannelMessageRequest(channel_key=chan_key, text="hello") + message = await send_channel_message(request) + + # Response message includes sender_key + assert message.sender_key == our_pubkey + assert message.sender_name == "MyNode" + + # Broadcast also includes sender_key + msg_broadcasts = [b for b in broadcasts if b["type"] == "message"] + assert len(msg_broadcasts) == 1 + assert msg_broadcasts[0]["data"]["sender_key"] == our_pubkey + + # DB row also has sender_key + db_msg = await MessageRepository.get_by_id(message.id) + assert db_msg is not None + assert db_msg.sender_key == our_pubkey + class TestResendChannelMessage: """Test the user-triggered resend endpoint."""