mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-03-28 17:43:05 +01:00
Add sender_key to outgoing and make unread counts respect block list
This commit is contained in:
@@ -384,11 +384,17 @@ class MessageRepository:
|
||||
return MessageRepository._row_to_message(row)
|
||||
|
||||
@staticmethod
|
||||
async def get_unread_counts(name: str | None = None) -> dict:
|
||||
async def get_unread_counts(
|
||||
name: str | None = None,
|
||||
blocked_keys: list[str] | None = None,
|
||||
blocked_names: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Get unread message counts, mention flags, and last message times for all conversations.
|
||||
|
||||
Args:
|
||||
name: User's display name for @[name] mention detection. If None, mentions are skipped.
|
||||
blocked_keys: Public keys whose messages should be excluded from counts.
|
||||
blocked_names: Display names whose messages should be excluded from counts.
|
||||
|
||||
Returns:
|
||||
Dict with 'counts', 'mentions', and 'last_message_times' keys.
|
||||
@@ -399,9 +405,25 @@ class MessageRepository:
|
||||
|
||||
mention_token = f"@[{name}]" if name else None
|
||||
|
||||
# Build optional block-list WHERE fragments for channel messages
|
||||
chan_block_sql = ""
|
||||
chan_block_params: list[Any] = []
|
||||
if blocked_keys:
|
||||
placeholders = ",".join("?" for _ in blocked_keys)
|
||||
chan_block_sql += (
|
||||
f" AND NOT (m.sender_key IS NOT NULL AND LOWER(m.sender_key) IN ({placeholders}))"
|
||||
)
|
||||
chan_block_params.extend(blocked_keys)
|
||||
if blocked_names:
|
||||
placeholders = ",".join("?" for _ in blocked_names)
|
||||
chan_block_sql += (
|
||||
f" AND NOT (m.sender_name IS NOT NULL AND m.sender_name IN ({placeholders}))"
|
||||
)
|
||||
chan_block_params.extend(blocked_names)
|
||||
|
||||
# Channel unreads
|
||||
cursor = await db.conn.execute(
|
||||
"""
|
||||
f"""
|
||||
SELECT m.conversation_key,
|
||||
COUNT(*) as unread_count,
|
||||
SUM(CASE
|
||||
@@ -412,9 +434,10 @@ class MessageRepository:
|
||||
JOIN channels c ON m.conversation_key = c.key
|
||||
WHERE m.type = 'CHAN' AND m.outgoing = 0
|
||||
AND m.received_at > COALESCE(c.last_read_at, 0)
|
||||
{chan_block_sql}
|
||||
GROUP BY m.conversation_key
|
||||
""",
|
||||
(mention_token or "", mention_token or ""),
|
||||
(mention_token or "", mention_token or "", *chan_block_params),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
for row in rows:
|
||||
@@ -423,9 +446,17 @@ class MessageRepository:
|
||||
if mention_token and row["has_mention"]:
|
||||
mention_flags[state_key] = True
|
||||
|
||||
# Build block-list exclusion for contact (DM) unreads
|
||||
contact_block_sql = ""
|
||||
contact_block_params: list[Any] = []
|
||||
if blocked_keys:
|
||||
placeholders = ",".join("?" for _ in blocked_keys)
|
||||
contact_block_sql += f" AND LOWER(m.conversation_key) NOT IN ({placeholders})"
|
||||
contact_block_params.extend(blocked_keys)
|
||||
|
||||
# Contact unreads
|
||||
cursor = await db.conn.execute(
|
||||
"""
|
||||
f"""
|
||||
SELECT m.conversation_key,
|
||||
COUNT(*) as unread_count,
|
||||
SUM(CASE
|
||||
@@ -436,9 +467,10 @@ class MessageRepository:
|
||||
JOIN contacts ct ON m.conversation_key = ct.public_key
|
||||
WHERE m.type = 'PRIV' AND m.outgoing = 0
|
||||
AND m.received_at > COALESCE(ct.last_read_at, 0)
|
||||
{contact_block_sql}
|
||||
GROUP BY m.conversation_key
|
||||
""",
|
||||
(mention_token or "", mention_token or ""),
|
||||
(mention_token or "", mention_token or "", *contact_block_params),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
for row in rows:
|
||||
|
||||
@@ -239,8 +239,11 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
|
||||
radio_name: str = ""
|
||||
text_with_sender: str = request.text
|
||||
|
||||
our_public_key: str | None = None
|
||||
|
||||
async with radio_manager.radio_operation("send_channel_message") as mc:
|
||||
radio_name = mc.self_info.get("name", "") if mc.self_info else ""
|
||||
our_public_key = (mc.self_info.get("public_key") or None) if mc.self_info else None
|
||||
text_with_sender = f"{radio_name}: {request.text}" if radio_name else request.text
|
||||
# Load the channel to a temporary radio slot before sending
|
||||
set_result = await mc.commands.set_channel(
|
||||
@@ -286,6 +289,7 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
|
||||
received_at=now,
|
||||
outgoing=True,
|
||||
sender_name=radio_name or None,
|
||||
sender_key=our_public_key,
|
||||
)
|
||||
if message_id is None:
|
||||
raise HTTPException(
|
||||
@@ -307,6 +311,8 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
|
||||
received_at=now,
|
||||
outgoing=True,
|
||||
acked=0,
|
||||
sender_name=radio_name or None,
|
||||
sender_key=our_public_key,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
@@ -325,6 +331,8 @@ async def send_channel_message(request: SendChannelMessageRequest) -> Message:
|
||||
outgoing=True,
|
||||
acked=acked_count,
|
||||
paths=paths,
|
||||
sender_name=radio_name or None,
|
||||
sender_key=our_public_key,
|
||||
)
|
||||
|
||||
# Trigger bots for outgoing channel messages (runs in background, doesn't block response)
|
||||
@@ -404,9 +412,12 @@ async def resend_channel_message(
|
||||
status_code=400, detail=f"Invalid channel key format: {msg.conversation_key}"
|
||||
) from None
|
||||
|
||||
resend_public_key: str | None = None
|
||||
|
||||
async with radio_manager.radio_operation("resend_channel_message") as mc:
|
||||
# Strip sender prefix: DB stores "RadioName: message" but radio needs "message"
|
||||
radio_name = mc.self_info.get("name", "") if mc.self_info else ""
|
||||
resend_public_key = (mc.self_info.get("public_key") or None) if mc.self_info else None
|
||||
text_to_send = msg.text
|
||||
if radio_name and text_to_send.startswith(f"{radio_name}: "):
|
||||
text_to_send = text_to_send[len(f"{radio_name}: ") :]
|
||||
@@ -442,6 +453,7 @@ async def resend_channel_message(
|
||||
received_at=now,
|
||||
outgoing=True,
|
||||
sender_name=radio_name or None,
|
||||
sender_key=resend_public_key,
|
||||
)
|
||||
if new_msg_id is None:
|
||||
# Timestamp-second collision (same text+channel within the same second).
|
||||
@@ -464,6 +476,8 @@ async def resend_channel_message(
|
||||
received_at=now,
|
||||
outgoing=True,
|
||||
acked=0,
|
||||
sender_name=radio_name or None,
|
||||
sender_key=resend_public_key,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,12 @@ from fastapi import APIRouter
|
||||
|
||||
from app.models import UnreadCounts
|
||||
from app.radio import radio_manager
|
||||
from app.repository import ChannelRepository, ContactRepository, MessageRepository
|
||||
from app.repository import (
|
||||
AppSettingsRepository,
|
||||
ChannelRepository,
|
||||
ContactRepository,
|
||||
MessageRepository,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/read-state", tags=["read-state"])
|
||||
@@ -26,7 +31,12 @@ async def get_unreads() -> UnreadCounts:
|
||||
mc = radio_manager.meshcore
|
||||
if mc and mc.self_info:
|
||||
name = mc.self_info.get("name") or None
|
||||
data = await MessageRepository.get_unread_counts(name)
|
||||
settings = await AppSettingsRepository.get()
|
||||
blocked_keys = settings.blocked_keys or None
|
||||
blocked_names = settings.blocked_names or None
|
||||
data = await MessageRepository.get_unread_counts(
|
||||
name, blocked_keys=blocked_keys, blocked_names=blocked_names
|
||||
)
|
||||
return UnreadCounts(**data)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,12 @@ import time
|
||||
|
||||
import pytest
|
||||
|
||||
from app.repository import AppSettingsRepository, MessageRepository
|
||||
from app.repository import (
|
||||
AppSettingsRepository,
|
||||
ChannelRepository,
|
||||
ContactRepository,
|
||||
MessageRepository,
|
||||
)
|
||||
from app.routers.settings import (
|
||||
BlockKeyRequest,
|
||||
BlockNameRequest,
|
||||
@@ -168,3 +173,140 @@ class TestMessageBlockFiltering:
|
||||
assert "blocked dm" not in texts
|
||||
assert "normal dm" in texts
|
||||
assert "outgoing to blocked" in texts
|
||||
|
||||
|
||||
class TestUnreadCountsBlockFiltering:
|
||||
"""Unread counts should exclude messages from blocked keys/names."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unread_counts_exclude_blocked_key_dms(self, test_db):
|
||||
"""Blocked key DMs should not contribute to unread counts."""
|
||||
blocked_key = "aa" * 32
|
||||
normal_key = "bb" * 32
|
||||
now = int(time.time())
|
||||
|
||||
# Set up contacts with last_read_at in the past
|
||||
await ContactRepository.upsert({"public_key": blocked_key, "name": "Blocked"})
|
||||
await ContactRepository.upsert({"public_key": normal_key, "name": "Normal"})
|
||||
|
||||
# Incoming DMs
|
||||
await MessageRepository.create(
|
||||
msg_type="PRIV",
|
||||
text="blocked msg",
|
||||
received_at=now,
|
||||
conversation_key=blocked_key,
|
||||
sender_timestamp=now,
|
||||
)
|
||||
await MessageRepository.create(
|
||||
msg_type="PRIV",
|
||||
text="normal msg",
|
||||
received_at=now + 1,
|
||||
conversation_key=normal_key,
|
||||
sender_timestamp=now + 1,
|
||||
)
|
||||
|
||||
result = await MessageRepository.get_unread_counts(
|
||||
blocked_keys=[blocked_key],
|
||||
)
|
||||
assert f"contact-{blocked_key}" not in result["counts"]
|
||||
assert result["counts"][f"contact-{normal_key}"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unread_counts_exclude_blocked_key_channel_msgs(self, test_db):
|
||||
"""Blocked key channel messages should not contribute to unread counts."""
|
||||
blocked_key = "aa" * 32
|
||||
normal_key = "bb" * 32
|
||||
chan_key = "CC" * 16
|
||||
now = int(time.time())
|
||||
|
||||
await ChannelRepository.upsert(key=chan_key, name="#test")
|
||||
await ChannelRepository.update_last_read_at(chan_key, 0)
|
||||
|
||||
await MessageRepository.create(
|
||||
msg_type="CHAN",
|
||||
text="Blocked: spam",
|
||||
received_at=now,
|
||||
conversation_key=chan_key,
|
||||
sender_timestamp=now,
|
||||
sender_name="Blocked",
|
||||
sender_key=blocked_key,
|
||||
)
|
||||
await MessageRepository.create(
|
||||
msg_type="CHAN",
|
||||
text="Normal: hi",
|
||||
received_at=now + 1,
|
||||
conversation_key=chan_key,
|
||||
sender_timestamp=now + 1,
|
||||
sender_name="Normal",
|
||||
sender_key=normal_key,
|
||||
)
|
||||
|
||||
result = await MessageRepository.get_unread_counts(
|
||||
blocked_keys=[blocked_key],
|
||||
)
|
||||
assert result["counts"][f"channel-{chan_key}"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unread_counts_exclude_blocked_name_channel_msgs(self, test_db):
|
||||
"""Blocked name channel messages should not contribute to unread counts."""
|
||||
chan_key = "DD" * 16
|
||||
now = int(time.time())
|
||||
|
||||
await ChannelRepository.upsert(key=chan_key, name="#test2")
|
||||
await ChannelRepository.update_last_read_at(chan_key, 0)
|
||||
|
||||
await MessageRepository.create(
|
||||
msg_type="CHAN",
|
||||
text="Spammer: buy stuff",
|
||||
received_at=now,
|
||||
conversation_key=chan_key,
|
||||
sender_timestamp=now,
|
||||
sender_name="Spammer",
|
||||
sender_key="ee" * 32,
|
||||
)
|
||||
await MessageRepository.create(
|
||||
msg_type="CHAN",
|
||||
text="Friend: hello",
|
||||
received_at=now + 1,
|
||||
conversation_key=chan_key,
|
||||
sender_timestamp=now + 1,
|
||||
sender_name="Friend",
|
||||
sender_key="ff" * 32,
|
||||
)
|
||||
|
||||
result = await MessageRepository.get_unread_counts(
|
||||
blocked_names=["Spammer"],
|
||||
)
|
||||
assert result["counts"][f"channel-{chan_key}"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unread_counts_no_block_lists_returns_all(self, test_db):
|
||||
"""Without block lists, all messages count toward unreads."""
|
||||
blocked_key = "aa" * 32
|
||||
chan_key = "CC" * 16
|
||||
now = int(time.time())
|
||||
|
||||
await ContactRepository.upsert({"public_key": blocked_key, "name": "Someone"})
|
||||
await ChannelRepository.upsert(key=chan_key, name="#all")
|
||||
await ChannelRepository.update_last_read_at(chan_key, 0)
|
||||
|
||||
await MessageRepository.create(
|
||||
msg_type="PRIV",
|
||||
text="dm",
|
||||
received_at=now,
|
||||
conversation_key=blocked_key,
|
||||
sender_timestamp=now,
|
||||
)
|
||||
await MessageRepository.create(
|
||||
msg_type="CHAN",
|
||||
text="Someone: hi",
|
||||
received_at=now + 1,
|
||||
conversation_key=chan_key,
|
||||
sender_timestamp=now + 1,
|
||||
sender_name="Someone",
|
||||
sender_key=blocked_key,
|
||||
)
|
||||
|
||||
result = await MessageRepository.get_unread_counts()
|
||||
assert result["counts"][f"contact-{blocked_key}"] == 1
|
||||
assert result["counts"][f"channel-{chan_key}"] == 1
|
||||
|
||||
@@ -259,6 +259,44 @@ class TestOutgoingChannelBotTrigger:
|
||||
assert message.id is not None
|
||||
assert message.acked == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_channel_msg_includes_sender_key(self, test_db):
|
||||
"""Outgoing channel message includes our public key as sender_key."""
|
||||
our_pubkey = "ab" * 32
|
||||
mc = _make_mc(name="MyNode")
|
||||
mc.self_info["public_key"] = our_pubkey
|
||||
chan_key = "ee" * 16
|
||||
await ChannelRepository.upsert(key=chan_key, name="#test")
|
||||
|
||||
broadcasts = []
|
||||
|
||||
def capture_broadcast(event_type, data):
|
||||
broadcasts.append({"type": event_type, "data": data})
|
||||
|
||||
with (
|
||||
patch("app.routers.messages.require_connected", return_value=mc),
|
||||
patch.object(radio_manager, "_meshcore", mc),
|
||||
patch("app.decoder.calculate_channel_hash", return_value="abcd"),
|
||||
patch("app.bot.run_bot_for_message", new=AsyncMock()),
|
||||
patch("app.routers.messages.broadcast_event", side_effect=capture_broadcast),
|
||||
):
|
||||
request = SendChannelMessageRequest(channel_key=chan_key, text="hello")
|
||||
message = await send_channel_message(request)
|
||||
|
||||
# Response message includes sender_key
|
||||
assert message.sender_key == our_pubkey
|
||||
assert message.sender_name == "MyNode"
|
||||
|
||||
# Broadcast also includes sender_key
|
||||
msg_broadcasts = [b for b in broadcasts if b["type"] == "message"]
|
||||
assert len(msg_broadcasts) == 1
|
||||
assert msg_broadcasts[0]["data"]["sender_key"] == our_pubkey
|
||||
|
||||
# DB row also has sender_key
|
||||
db_msg = await MessageRepository.get_by_id(message.id)
|
||||
assert db_msg is not None
|
||||
assert db_msg.sender_key == our_pubkey
|
||||
|
||||
|
||||
class TestResendChannelMessage:
|
||||
"""Test the user-triggered resend endpoint."""
|
||||
|
||||
Reference in New Issue
Block a user