Add sender_key to outgoing and make unread counts respect block list

This commit is contained in:
Jack Kingsman
2026-03-05 10:43:16 -08:00
parent 01a5dc8d93
commit 7715732e69
5 changed files with 244 additions and 8 deletions

View File

@@ -384,11 +384,17 @@ class MessageRepository:
return MessageRepository._row_to_message(row)
@staticmethod
async def get_unread_counts(name: str | None = None) -> dict:
async def get_unread_counts(
name: str | None = None,
blocked_keys: list[str] | None = None,
blocked_names: list[str] | None = None,
) -> dict:
"""Get unread message counts, mention flags, and last message times for all conversations.
Args:
name: User's display name for @[name] mention detection. If None, mentions are skipped.
blocked_keys: Public keys whose messages should be excluded from counts.
blocked_names: Display names whose messages should be excluded from counts.
Returns:
Dict with 'counts', 'mentions', and 'last_message_times' keys.
@@ -399,9 +405,25 @@ class MessageRepository:
mention_token = f"@[{name}]" if name else None
# Build optional block-list WHERE fragments for channel messages
chan_block_sql = ""
chan_block_params: list[Any] = []
if blocked_keys:
placeholders = ",".join("?" for _ in blocked_keys)
chan_block_sql += (
f" AND NOT (m.sender_key IS NOT NULL AND LOWER(m.sender_key) IN ({placeholders}))"
)
chan_block_params.extend(blocked_keys)
if blocked_names:
placeholders = ",".join("?" for _ in blocked_names)
chan_block_sql += (
f" AND NOT (m.sender_name IS NOT NULL AND m.sender_name IN ({placeholders}))"
)
chan_block_params.extend(blocked_names)
# Channel unreads
cursor = await db.conn.execute(
"""
f"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
SUM(CASE
@@ -412,9 +434,10 @@ class MessageRepository:
JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.outgoing = 0
AND m.received_at > COALESCE(c.last_read_at, 0)
{chan_block_sql}
GROUP BY m.conversation_key
""",
(mention_token or "", mention_token or ""),
(mention_token or "", mention_token or "", *chan_block_params),
)
rows = await cursor.fetchall()
for row in rows:
@@ -423,9 +446,17 @@ class MessageRepository:
if mention_token and row["has_mention"]:
mention_flags[state_key] = True
# Build block-list exclusion for contact (DM) unreads
contact_block_sql = ""
contact_block_params: list[Any] = []
if blocked_keys:
placeholders = ",".join("?" for _ in blocked_keys)
contact_block_sql += f" AND LOWER(m.conversation_key) NOT IN ({placeholders})"
contact_block_params.extend(blocked_keys)
# Contact unreads
cursor = await db.conn.execute(
"""
f"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
SUM(CASE
@@ -436,9 +467,10 @@ class MessageRepository:
JOIN contacts ct ON m.conversation_key = ct.public_key
WHERE m.type = 'PRIV' AND m.outgoing = 0
AND m.received_at > COALESCE(ct.last_read_at, 0)
{contact_block_sql}
GROUP BY m.conversation_key
""",
(mention_token or "", mention_token or ""),
(mention_token or "", mention_token or "", *contact_block_params),
)
rows = await cursor.fetchall()
for row in rows:

View File

@@ -239,8 +239,11 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
radio_name: str = ""
text_with_sender: str = request.text
our_public_key: str | None = None
async with radio_manager.radio_operation("send_channel_message") as mc:
radio_name = mc.self_info.get("name", "") if mc.self_info else ""
our_public_key = (mc.self_info.get("public_key") or None) if mc.self_info else None
text_with_sender = f"{radio_name}: {request.text}" if radio_name else request.text
# Load the channel to a temporary radio slot before sending
set_result = await mc.commands.set_channel(
@@ -286,6 +289,7 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
received_at=now,
outgoing=True,
sender_name=radio_name or None,
sender_key=our_public_key,
)
if message_id is None:
raise HTTPException(
@@ -307,6 +311,8 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
received_at=now,
outgoing=True,
acked=0,
sender_name=radio_name or None,
sender_key=our_public_key,
).model_dump(),
)
@@ -325,6 +331,8 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
outgoing=True,
acked=acked_count,
paths=paths,
sender_name=radio_name or None,
sender_key=our_public_key,
)
# Trigger bots for outgoing channel messages (runs in background, doesn't block response)
@@ -404,9 +412,12 @@ async def resend_channel_message(
status_code=400, detail=f"Invalid channel key format: {msg.conversation_key}"
) from None
resend_public_key: str | None = None
async with radio_manager.radio_operation("resend_channel_message") as mc:
# Strip sender prefix: DB stores "RadioName: message" but radio needs "message"
radio_name = mc.self_info.get("name", "") if mc.self_info else ""
resend_public_key = (mc.self_info.get("public_key") or None) if mc.self_info else None
text_to_send = msg.text
if radio_name and text_to_send.startswith(f"{radio_name}: "):
text_to_send = text_to_send[len(f"{radio_name}: ") :]
@@ -442,6 +453,7 @@ async def resend_channel_message(
received_at=now,
outgoing=True,
sender_name=radio_name or None,
sender_key=resend_public_key,
)
if new_msg_id is None:
# Timestamp-second collision (same text+channel within the same second).
@@ -464,6 +476,8 @@ async def resend_channel_message(
received_at=now,
outgoing=True,
acked=0,
sender_name=radio_name or None,
sender_key=resend_public_key,
).model_dump(),
)

View File

@@ -7,7 +7,12 @@ from fastapi import APIRouter
from app.models import UnreadCounts
from app.radio import radio_manager
from app.repository import ChannelRepository, ContactRepository, MessageRepository
from app.repository import (
AppSettingsRepository,
ChannelRepository,
ContactRepository,
MessageRepository,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/read-state", tags=["read-state"])
@@ -26,7 +31,12 @@ async def get_unreads() -> UnreadCounts:
mc = radio_manager.meshcore
if mc and mc.self_info:
name = mc.self_info.get("name") or None
data = await MessageRepository.get_unread_counts(name)
settings = await AppSettingsRepository.get()
blocked_keys = settings.blocked_keys or None
blocked_names = settings.blocked_names or None
data = await MessageRepository.get_unread_counts(
name, blocked_keys=blocked_keys, blocked_names=blocked_names
)
return UnreadCounts(**data)

View File

@@ -4,7 +4,12 @@ import time
import pytest
from app.repository import AppSettingsRepository, MessageRepository
from app.repository import (
AppSettingsRepository,
ChannelRepository,
ContactRepository,
MessageRepository,
)
from app.routers.settings import (
BlockKeyRequest,
BlockNameRequest,
@@ -168,3 +173,140 @@ class TestMessageBlockFiltering:
assert "blocked dm" not in texts
assert "normal dm" in texts
assert "outgoing to blocked" in texts
class TestUnreadCountsBlockFiltering:
"""Unread counts should exclude messages from blocked keys/names."""
@pytest.mark.asyncio
async def test_unread_counts_exclude_blocked_key_dms(self, test_db):
"""Blocked key DMs should not contribute to unread counts."""
blocked_key = "aa" * 32
normal_key = "bb" * 32
now = int(time.time())
# Set up contacts with last_read_at in the past
await ContactRepository.upsert({"public_key": blocked_key, "name": "Blocked"})
await ContactRepository.upsert({"public_key": normal_key, "name": "Normal"})
# Incoming DMs
await MessageRepository.create(
msg_type="PRIV",
text="blocked msg",
received_at=now,
conversation_key=blocked_key,
sender_timestamp=now,
)
await MessageRepository.create(
msg_type="PRIV",
text="normal msg",
received_at=now + 1,
conversation_key=normal_key,
sender_timestamp=now + 1,
)
result = await MessageRepository.get_unread_counts(
blocked_keys=[blocked_key],
)
assert f"contact-{blocked_key}" not in result["counts"]
assert result["counts"][f"contact-{normal_key}"] == 1
@pytest.mark.asyncio
async def test_unread_counts_exclude_blocked_key_channel_msgs(self, test_db):
"""Blocked key channel messages should not contribute to unread counts."""
blocked_key = "aa" * 32
normal_key = "bb" * 32
chan_key = "CC" * 16
now = int(time.time())
await ChannelRepository.upsert(key=chan_key, name="#test")
await ChannelRepository.update_last_read_at(chan_key, 0)
await MessageRepository.create(
msg_type="CHAN",
text="Blocked: spam",
received_at=now,
conversation_key=chan_key,
sender_timestamp=now,
sender_name="Blocked",
sender_key=blocked_key,
)
await MessageRepository.create(
msg_type="CHAN",
text="Normal: hi",
received_at=now + 1,
conversation_key=chan_key,
sender_timestamp=now + 1,
sender_name="Normal",
sender_key=normal_key,
)
result = await MessageRepository.get_unread_counts(
blocked_keys=[blocked_key],
)
assert result["counts"][f"channel-{chan_key}"] == 1
@pytest.mark.asyncio
async def test_unread_counts_exclude_blocked_name_channel_msgs(self, test_db):
"""Blocked name channel messages should not contribute to unread counts."""
chan_key = "DD" * 16
now = int(time.time())
await ChannelRepository.upsert(key=chan_key, name="#test2")
await ChannelRepository.update_last_read_at(chan_key, 0)
await MessageRepository.create(
msg_type="CHAN",
text="Spammer: buy stuff",
received_at=now,
conversation_key=chan_key,
sender_timestamp=now,
sender_name="Spammer",
sender_key="ee" * 32,
)
await MessageRepository.create(
msg_type="CHAN",
text="Friend: hello",
received_at=now + 1,
conversation_key=chan_key,
sender_timestamp=now + 1,
sender_name="Friend",
sender_key="ff" * 32,
)
result = await MessageRepository.get_unread_counts(
blocked_names=["Spammer"],
)
assert result["counts"][f"channel-{chan_key}"] == 1
@pytest.mark.asyncio
async def test_unread_counts_no_block_lists_returns_all(self, test_db):
"""Without block lists, all messages count toward unreads."""
blocked_key = "aa" * 32
chan_key = "CC" * 16
now = int(time.time())
await ContactRepository.upsert({"public_key": blocked_key, "name": "Someone"})
await ChannelRepository.upsert(key=chan_key, name="#all")
await ChannelRepository.update_last_read_at(chan_key, 0)
await MessageRepository.create(
msg_type="PRIV",
text="dm",
received_at=now,
conversation_key=blocked_key,
sender_timestamp=now,
)
await MessageRepository.create(
msg_type="CHAN",
text="Someone: hi",
received_at=now + 1,
conversation_key=chan_key,
sender_timestamp=now + 1,
sender_name="Someone",
sender_key=blocked_key,
)
result = await MessageRepository.get_unread_counts()
assert result["counts"][f"contact-{blocked_key}"] == 1
assert result["counts"][f"channel-{chan_key}"] == 1

View File

@@ -259,6 +259,44 @@ class TestOutgoingChannelBotTrigger:
assert message.id is not None
assert message.acked == 0
@pytest.mark.asyncio
async def test_send_channel_msg_includes_sender_key(self, test_db):
"""Outgoing channel message includes our public key as sender_key."""
our_pubkey = "ab" * 32
mc = _make_mc(name="MyNode")
mc.self_info["public_key"] = our_pubkey
chan_key = "ee" * 16
await ChannelRepository.upsert(key=chan_key, name="#test")
broadcasts = []
def capture_broadcast(event_type, data):
broadcasts.append({"type": event_type, "data": data})
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.decoder.calculate_channel_hash", return_value="abcd"),
patch("app.bot.run_bot_for_message", new=AsyncMock()),
patch("app.routers.messages.broadcast_event", side_effect=capture_broadcast),
):
request = SendChannelMessageRequest(channel_key=chan_key, text="hello")
message = await send_channel_message(request)
# Response message includes sender_key
assert message.sender_key == our_pubkey
assert message.sender_name == "MyNode"
# Broadcast also includes sender_key
msg_broadcasts = [b for b in broadcasts if b["type"] == "message"]
assert len(msg_broadcasts) == 1
assert msg_broadcasts[0]["data"]["sender_key"] == our_pubkey
# DB row also has sender_key
db_msg = await MessageRepository.get_by_id(message.id)
assert db_msg is not None
assert db_msg.sender_key == our_pubkey
class TestResendChannelMessage:
"""Test the user-triggered resend endpoint."""