diff --git a/app/repository/messages.py b/app/repository/messages.py index 6949dc1..ab6c160 100644 --- a/app/repository/messages.py +++ b/app/repository/messages.py @@ -293,6 +293,40 @@ class MessageRepository: clause += ")" return clause, params + @staticmethod + def _build_blocked_incoming_clause( + message_alias: str = "", + blocked_keys: list[str] | None = None, + blocked_names: list[str] | None = None, + ) -> tuple[str, list[Any]]: + prefix = f"{message_alias}." if message_alias else "" + blocked_matchers: list[str] = [] + params: list[Any] = [] + + if blocked_keys: + placeholders = ",".join("?" for _ in blocked_keys) + blocked_matchers.append( + f"({prefix}type = 'PRIV' AND LOWER({prefix}conversation_key) IN ({placeholders}))" + ) + params.extend(blocked_keys) + blocked_matchers.append( + f"({prefix}type = 'CHAN' AND {prefix}sender_key IS NOT NULL" + f" AND LOWER({prefix}sender_key) IN ({placeholders}))" + ) + params.extend(blocked_keys) + + if blocked_names: + placeholders = ",".join("?" for _ in blocked_names) + blocked_matchers.append( + f"({prefix}sender_name IS NOT NULL AND {prefix}sender_name IN ({placeholders}))" + ) + params.extend(blocked_names) + + if not blocked_matchers: + return "", [] + + return f"NOT ({prefix}outgoing = 0 AND ({' OR '.join(blocked_matchers)}))", params + @staticmethod def _row_to_message(row: Any) -> Message: """Convert a database row to a Message model.""" @@ -337,25 +371,12 @@ class MessageRepository: ) params: list[Any] = [] - if blocked_keys: - placeholders = ",".join("?" for _ in blocked_keys) - query += ( - f" AND NOT (messages.outgoing=0 AND (" - f"(messages.type='PRIV' AND LOWER(messages.conversation_key) IN ({placeholders}))" - f" OR (messages.type='CHAN' AND messages.sender_key IS NOT NULL" - f" AND LOWER(messages.sender_key) IN ({placeholders}))" - f"))" - ) - params.extend(blocked_keys) - params.extend(blocked_keys) - - if blocked_names: - placeholders = ",".join("?" for _ in blocked_names) - query += ( - f" AND NOT (messages.outgoing=0 AND messages.sender_name IS NOT NULL" - f" AND messages.sender_name IN ({placeholders}))" - ) - params.extend(blocked_names) + blocked_clause, blocked_params = MessageRepository._build_blocked_incoming_clause( + "messages", blocked_keys, blocked_names + ) + if blocked_clause: + query += f" AND {blocked_clause}" + params.extend(blocked_params) if msg_type: query += " AND messages.type = ?" @@ -437,23 +458,12 @@ class MessageRepository: where_parts.append(clause.removeprefix("AND ")) base_params.append(norm_key) - if blocked_keys: - placeholders = ",".join("?" for _ in blocked_keys) - where_parts.append( - f"NOT (outgoing=0 AND (" - f"(type='PRIV' AND LOWER(conversation_key) IN ({placeholders}))" - f" OR (type='CHAN' AND sender_key IS NOT NULL AND LOWER(sender_key) IN ({placeholders}))" - f"))" - ) - base_params.extend(blocked_keys) - base_params.extend(blocked_keys) - - if blocked_names: - placeholders = ",".join("?" for _ in blocked_names) - where_parts.append( - f"NOT (outgoing=0 AND sender_name IS NOT NULL AND sender_name IN ({placeholders}))" - ) - base_params.extend(blocked_names) + blocked_clause, blocked_params = MessageRepository._build_blocked_incoming_clause( + blocked_keys=blocked_keys, blocked_names=blocked_names + ) + if blocked_clause: + where_parts.append(blocked_clause) + base_params.extend(blocked_params) where_sql = " AND ".join(["1=1", *where_parts]) @@ -588,21 +598,10 @@ class MessageRepository: mention_token = f"@[{name}]" if name else None - # Build optional block-list WHERE fragments for channel messages - chan_block_sql = "" - chan_block_params: list[Any] = [] - if blocked_keys: - placeholders = ",".join("?" for _ in blocked_keys) - chan_block_sql += ( - f" AND NOT (m.sender_key IS NOT NULL AND LOWER(m.sender_key) IN ({placeholders}))" - ) - chan_block_params.extend(blocked_keys) - if blocked_names: - placeholders = ",".join("?" for _ in blocked_names) - chan_block_sql += ( - f" AND NOT (m.sender_name IS NOT NULL AND m.sender_name IN ({placeholders}))" - ) - chan_block_params.extend(blocked_names) + blocked_clause, blocked_params = MessageRepository._build_blocked_incoming_clause( + "m", blocked_keys, blocked_names + ) + blocked_sql = f" AND {blocked_clause}" if blocked_clause else "" # Channel unreads cursor = await db.conn.execute( @@ -617,10 +616,10 @@ class MessageRepository: JOIN channels c ON m.conversation_key = c.key WHERE m.type = 'CHAN' AND m.outgoing = 0 AND m.received_at > COALESCE(c.last_read_at, 0) - {chan_block_sql} + {blocked_sql} GROUP BY m.conversation_key """, - (mention_token or "", mention_token or "", *chan_block_params), + (mention_token or "", mention_token or "", *blocked_params), ) rows = await cursor.fetchall() for row in rows: @@ -629,14 +628,6 @@ class MessageRepository: if mention_token and row["has_mention"]: mention_flags[state_key] = True - # Build block-list exclusion for contact (DM) unreads - contact_block_sql = "" - contact_block_params: list[Any] = [] - if blocked_keys: - placeholders = ",".join("?" for _ in blocked_keys) - contact_block_sql += f" AND LOWER(m.conversation_key) NOT IN ({placeholders})" - contact_block_params.extend(blocked_keys) - # Contact unreads cursor = await db.conn.execute( f""" @@ -650,10 +641,10 @@ class MessageRepository: JOIN contacts ct ON m.conversation_key = ct.public_key WHERE m.type = 'PRIV' AND m.outgoing = 0 AND m.received_at > COALESCE(ct.last_read_at, 0) - {contact_block_sql} + {blocked_sql} GROUP BY m.conversation_key """, - (mention_token or "", mention_token or "", *contact_block_params), + (mention_token or "", mention_token or "", *blocked_params), ) rows = await cursor.fetchall() for row in rows: @@ -684,50 +675,10 @@ class MessageRepository: # Last message times for all conversations (including read ones), # excluding blocked incoming traffic so refresh matches live WS behavior. - last_time_filters: list[str] = [] - last_time_params: list[Any] = [] - - if blocked_keys: - placeholders = ",".join("?" for _ in blocked_keys) - last_time_filters.append( - f""" - NOT ( - type = 'PRIV' - AND outgoing = 0 - AND LOWER(conversation_key) IN ({placeholders}) - ) - """ - ) - last_time_params.extend(blocked_keys) - last_time_filters.append( - f""" - NOT ( - type = 'CHAN' - AND outgoing = 0 - AND sender_key IS NOT NULL - AND LOWER(sender_key) IN ({placeholders}) - ) - """ - ) - last_time_params.extend(blocked_keys) - - if blocked_names: - placeholders = ",".join("?" for _ in blocked_names) - last_time_filters.append( - f""" - NOT ( - type = 'CHAN' - AND outgoing = 0 - AND sender_name IS NOT NULL - AND sender_name IN ({placeholders}) - ) - """ - ) - last_time_params.extend(blocked_names) - - last_time_where_sql = ( - f"WHERE {' AND '.join(last_time_filters)}" if last_time_filters else "" + last_time_clause, last_time_params = MessageRepository._build_blocked_incoming_clause( + blocked_keys=blocked_keys, blocked_names=blocked_names ) + last_time_where_sql = f"WHERE {last_time_clause}" if last_time_clause else "" cursor = await db.conn.execute( f""" diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index b91715d..a41d93f 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,5 +1,6 @@ import { useEffect, useCallback, useRef, useState } from 'react'; import { api } from './api'; +import * as messageCache from './messageCache'; import { takePrefetchOrFetch } from './prefetch'; import { useWebSocket } from './useWebSocket'; import { @@ -72,6 +73,7 @@ export function App() { const messageInputRef = useRef(null); const [rawPackets, setRawPackets] = useState([]); const [channelUnreadMarker, setChannelUnreadMarker] = useState(null); + const [visibilityVersion, setVisibilityVersion] = useState(0); const lastUnreadBackfillAttemptRef = useRef(null); const { notificationsSupported, @@ -231,6 +233,7 @@ export function App() { fetchOlderMessages, fetchNewerMessages, jumpToBottom, + reloadCurrentConversation, addMessageIfNew, updateMessageAck, triggerReconcile, @@ -325,6 +328,28 @@ export function App() { updateMessageAck, notifyIncomingMessage, }); + const handleVisibilityPolicyChanged = useCallback(() => { + messageCache.clear(); + reloadCurrentConversation(); + void refreshUnreads(); + setVisibilityVersion((current) => current + 1); + }, [refreshUnreads, reloadCurrentConversation]); + + const handleBlockKey = useCallback( + async (key: string) => { + await handleToggleBlockedKey(key); + handleVisibilityPolicyChanged(); + }, + [handleToggleBlockedKey, handleVisibilityPolicyChanged] + ); + + const handleBlockName = useCallback( + async (name: string) => { + await handleToggleBlockedName(name); + handleVisibilityPolicyChanged(); + }, + [handleToggleBlockedName, handleVisibilityPolicyChanged] + ); const { handleSendMessage, handleResendChannelMessage, @@ -332,17 +357,12 @@ export function App() { handleSenderClick, handleTrace, handlePathDiscovery, - handleBlockKey, - handleBlockName, } = useConversationActions({ activeConversation, activeConversationRef, setContacts, setChannels, addMessageIfNew, - jumpToBottom, - handleToggleBlockedKey, - handleToggleBlockedName, messageInputRef, }); const handleCreateCrackedChannel = useCallback( @@ -443,6 +463,7 @@ export function App() { const searchProps = { contacts, channels, + visibilityVersion, onNavigateToMessage: handleNavigateToMessage, prefillRequest: searchPrefillRequest, }; diff --git a/frontend/src/components/SearchView.tsx b/frontend/src/components/SearchView.tsx index 8b0e50d..e297a51 100644 --- a/frontend/src/components/SearchView.tsx +++ b/frontend/src/components/SearchView.tsx @@ -31,6 +31,7 @@ export interface SearchNavigateTarget { export interface SearchViewProps { contacts: Contact[]; channels: Channel[]; + visibilityVersion?: number; onNavigateToMessage: (target: SearchNavigateTarget) => void; prefillRequest?: { query: string; @@ -84,6 +85,7 @@ function getHighlightQuery(query: string): string { export function SearchView({ contacts, channels, + visibilityVersion = 0, onNavigateToMessage, prefillRequest = null, }: SearchViewProps) { @@ -110,7 +112,7 @@ export function SearchView({ setResults([]); setOffset(0); setHasMore(false); - }, [debouncedQuery]); + }, [debouncedQuery, visibilityVersion]); useEffect(() => { if (!prefillRequest) { @@ -159,7 +161,7 @@ export function SearchView({ }); return () => controller.abort(); - }, [debouncedQuery]); + }, [debouncedQuery, visibilityVersion]); const loadMore = useCallback(() => { if (!debouncedQuery || loading) return; diff --git a/frontend/src/hooks/useConversationActions.ts b/frontend/src/hooks/useConversationActions.ts index fc74105..17cff93 100644 --- a/frontend/src/hooks/useConversationActions.ts +++ b/frontend/src/hooks/useConversationActions.ts @@ -1,6 +1,5 @@ import { useCallback, type MutableRefObject, type RefObject } from 'react'; import { api } from '../api'; -import * as messageCache from '../messageCache'; import { toast } from '../components/ui/sonner'; import type { MessageInputHandle } from '../components/MessageInput'; import type { Channel, Contact, Conversation, Message, PathDiscoveryResponse } from '../types'; @@ -12,9 +11,6 @@ interface UseConversationActionsArgs { setContacts: React.Dispatch>; setChannels: React.Dispatch>; addMessageIfNew: (msg: Message) => boolean; - jumpToBottom: () => void; - handleToggleBlockedKey: (key: string) => Promise; - handleToggleBlockedName: (name: string) => Promise; messageInputRef: RefObject; } @@ -28,8 +24,6 @@ interface UseConversationActionsResult { handleSenderClick: (sender: string) => void; handleTrace: () => Promise; handlePathDiscovery: (publicKey: string) => Promise; - handleBlockKey: (key: string) => Promise; - handleBlockName: (name: string) => Promise; } export function useConversationActions({ @@ -38,9 +32,6 @@ export function useConversationActions({ setContacts, setChannels, addMessageIfNew, - jumpToBottom, - handleToggleBlockedKey, - handleToggleBlockedName, messageInputRef, }: UseConversationActionsArgs): UseConversationActionsResult { const mergeChannelIntoList = useCallback( @@ -139,24 +130,6 @@ export function useConversationActions({ [setContacts] ); - const handleBlockKey = useCallback( - async (key: string) => { - await handleToggleBlockedKey(key); - messageCache.clear(); - jumpToBottom(); - }, - [handleToggleBlockedKey, jumpToBottom] - ); - - const handleBlockName = useCallback( - async (name: string) => { - await handleToggleBlockedName(name); - messageCache.clear(); - jumpToBottom(); - }, - [handleToggleBlockedName, jumpToBottom] - ); - return { handleSendMessage, handleResendChannelMessage, @@ -164,7 +137,5 @@ export function useConversationActions({ handleSenderClick, handleTrace, handlePathDiscovery, - handleBlockKey, - handleBlockName, }; } diff --git a/frontend/src/hooks/useConversationMessages.ts b/frontend/src/hooks/useConversationMessages.ts index 85e6e98..d097ea8 100644 --- a/frontend/src/hooks/useConversationMessages.ts +++ b/frontend/src/hooks/useConversationMessages.ts @@ -77,6 +77,7 @@ interface UseConversationMessagesResult { fetchOlderMessages: () => Promise; fetchNewerMessages: () => Promise; jumpToBottom: () => void; + reloadCurrentConversation: () => void; addMessageIfNew: (msg: Message) => boolean; updateMessageAck: (messageId: number, ackCount: number, paths?: MessagePath[]) => void; triggerReconcile: () => void; @@ -167,6 +168,8 @@ export function useConversationMessages( const hasOlderMessagesRef = useRef(false); const hasNewerMessagesRef = useRef(false); const prevConversationIdRef = useRef(null); + const prevReloadVersionRef = useRef(0); + const [reloadVersion, setReloadVersion] = useState(0); useEffect(() => { messagesRef.current = messages; @@ -398,6 +401,13 @@ export function useConversationMessages( void fetchLatestMessages(true); }, [activeConversation, fetchLatestMessages]); + const reloadCurrentConversation = useCallback(() => { + if (!isMessageConversation(activeConversation)) return; + setHasNewerMessages(false); + messageCache.remove(activeConversation.id); + setReloadVersion((current) => current + 1); + }, [activeConversation]); + const triggerReconcile = useCallback(() => { if (!isMessageConversation(activeConversation)) return; const controller = new AbortController(); @@ -414,12 +424,14 @@ export function useConversationMessages( const prevId = prevConversationIdRef.current; const newId = activeConversation?.id ?? null; const conversationChanged = prevId !== newId; + const reloadRequested = prevReloadVersionRef.current !== reloadVersion; fetchingConversationIdRef.current = newId; prevConversationIdRef.current = newId; + prevReloadVersionRef.current = reloadVersion; latestReconcileRequestIdRef.current = 0; // Preserve around-loaded context on the same conversation when search clears targetMessageId. - if (!conversationChanged && !targetMessageId) { + if (!conversationChanged && !targetMessageId && !reloadRequested) { return; } @@ -498,7 +510,7 @@ export function useConversationMessages( controller.abort(); }; // eslint-disable-next-line react-hooks/exhaustive-deps - }, [activeConversation?.id, activeConversation?.type, targetMessageId]); + }, [activeConversation?.id, activeConversation?.type, targetMessageId, reloadVersion]); // Add a message if it's new (deduplication) // Returns true if the message was added, false if it was a duplicate @@ -584,6 +596,7 @@ export function useConversationMessages( fetchOlderMessages, fetchNewerMessages, jumpToBottom, + reloadCurrentConversation, addMessageIfNew, updateMessageAck, triggerReconcile, diff --git a/frontend/src/test/appFavorites.test.tsx b/frontend/src/test/appFavorites.test.tsx index 804f0cf..420b1c8 100644 --- a/frontend/src/test/appFavorites.test.tsx +++ b/frontend/src/test/appFavorites.test.tsx @@ -69,6 +69,7 @@ vi.mock('../hooks', async (importOriginal) => { fetchOlderMessages: mocks.hookFns.fetchOlderMessages, fetchNewerMessages: vi.fn(async () => {}), jumpToBottom: vi.fn(), + reloadCurrentConversation: vi.fn(), addMessageIfNew: mocks.hookFns.addMessageIfNew, updateMessageAck: mocks.hookFns.updateMessageAck, triggerReconcile: mocks.hookFns.triggerReconcile, diff --git a/frontend/src/test/appSearchJump.test.tsx b/frontend/src/test/appSearchJump.test.tsx index 7d0a42a..6caf69a 100644 --- a/frontend/src/test/appSearchJump.test.tsx +++ b/frontend/src/test/appSearchJump.test.tsx @@ -42,6 +42,7 @@ vi.mock('../hooks', async (importOriginal) => { fetchOlderMessages: vi.fn(async () => {}), fetchNewerMessages: vi.fn(async () => {}), jumpToBottom: vi.fn(), + reloadCurrentConversation: vi.fn(), addMessageIfNew: vi.fn(), updateMessageAck: vi.fn(), triggerReconcile: vi.fn(), diff --git a/frontend/src/test/appStartupHash.test.tsx b/frontend/src/test/appStartupHash.test.tsx index 9e9a932..797a2fd 100644 --- a/frontend/src/test/appStartupHash.test.tsx +++ b/frontend/src/test/appStartupHash.test.tsx @@ -30,11 +30,18 @@ vi.mock('../hooks', async (importOriginal) => { messagesLoading: false, loadingOlder: false, hasOlderMessages: false, + hasNewerMessages: false, + loadingNewer: false, + hasNewerMessagesRef: { current: false }, setMessages: vi.fn(), fetchMessages: vi.fn(async () => {}), fetchOlderMessages: vi.fn(async () => {}), + fetchNewerMessages: vi.fn(async () => {}), + jumpToBottom: vi.fn(), + reloadCurrentConversation: vi.fn(), addMessageIfNew: vi.fn(), updateMessageAck: vi.fn(), + triggerReconcile: vi.fn(), }), useUnreadCounts: () => ({ unreadCounts: {}, @@ -45,6 +52,7 @@ vi.mock('../hooks', async (importOriginal) => { renameConversationState: vi.fn(), markAllRead: vi.fn(), trackNewMessage: vi.fn(), + refreshUnreads: vi.fn(async () => {}), }), getMessageContentKey: () => 'content-key', }; diff --git a/frontend/src/test/searchView.test.tsx b/frontend/src/test/searchView.test.tsx index 7ffb239..540dede 100644 --- a/frontend/src/test/searchView.test.tsx +++ b/frontend/src/test/searchView.test.tsx @@ -60,6 +60,7 @@ async function typeAndWaitForResults(query: string) { describe('SearchView', () => { beforeEach(() => { vi.clearAllMocks(); + mockGetMessages.mockReset(); }); afterEach(() => { @@ -284,6 +285,33 @@ describe('SearchView', () => { ); }); + it('refetches current results when visibility policy changes', async () => { + mockGetMessages + .mockResolvedValueOnce([createSearchResult({ id: 1, text: 'visible result' })]) + .mockResolvedValueOnce([]); + + const { rerender } = render(); + + await typeAndWaitForResults('visible'); + expect(mockGetMessages).toHaveBeenCalledTimes(1); + expect( + screen.getAllByRole('button').some((button) => button.textContent?.includes('visible result')) + ).toBe(true); + + rerender(); + + await act(async () => { + await new Promise((resolve) => setTimeout(resolve, 0)); + }); + + expect(mockGetMessages).toHaveBeenCalledTimes(2); + expect(mockGetMessages).toHaveBeenLastCalledWith( + expect.objectContaining({ q: 'visible' }), + expect.any(AbortSignal) + ); + expect(screen.getByText(/No messages found/)).toBeInTheDocument(); + }); + it('aborts the load-more request on unmount', async () => { const pageResults = Array.from({ length: 50 }, (_, i) => createSearchResult({ id: i + 1, text: `result ${i}` }) diff --git a/frontend/src/test/useConversationActions.test.ts b/frontend/src/test/useConversationActions.test.ts index 98e927e..5201dbe 100644 --- a/frontend/src/test/useConversationActions.test.ts +++ b/frontend/src/test/useConversationActions.test.ts @@ -13,9 +13,6 @@ const mocks = vi.hoisted(() => ({ sendDirectMessage: vi.fn(), setChannelFloodScopeOverride: vi.fn(), }, - messageCache: { - clear: vi.fn(), - }, toast: { success: vi.fn(), error: vi.fn(), @@ -26,8 +23,6 @@ vi.mock('../api', () => ({ api: mocks.api, })); -vi.mock('../messageCache', () => mocks.messageCache); - vi.mock('../components/ui/sonner', () => ({ toast: mocks.toast, })); @@ -69,9 +64,6 @@ function createArgs(overrides: Partial setContacts: vi.fn(), setChannels: vi.fn(), addMessageIfNew: vi.fn(() => true), - jumpToBottom: vi.fn(), - handleToggleBlockedKey: vi.fn(async () => {}), - handleToggleBlockedName: vi.fn(async () => {}), messageInputRef: { current: { appendText: vi.fn() } }, ...overrides, }; @@ -122,19 +114,6 @@ describe('useConversationActions', () => { expect(args.addMessageIfNew).not.toHaveBeenCalled(); }); - it('clears cached messages and jumps to the latest page after blocking a key', async () => { - const args = createArgs(); - const { result } = renderHook(() => useConversationActions(args)); - - await act(async () => { - await result.current.handleBlockKey('cc'.repeat(32)); - }); - - expect(args.handleToggleBlockedKey).toHaveBeenCalledWith('cc'.repeat(32)); - expect(mocks.messageCache.clear).toHaveBeenCalledTimes(1); - expect(args.jumpToBottom).toHaveBeenCalledTimes(1); - }); - it('appends sender mentions into the message input', () => { const args = createArgs(); const { result } = renderHook(() => useConversationActions(args)); diff --git a/frontend/src/test/useConversationMessages.race.test.ts b/frontend/src/test/useConversationMessages.race.test.ts index ef0accd..dbda153 100644 --- a/frontend/src/test/useConversationMessages.race.test.ts +++ b/frontend/src/test/useConversationMessages.race.test.ts @@ -225,6 +225,36 @@ describe('useConversationMessages conversation switch', () => { expect(result.current.messages[0].conversation_key).toBe('conv_b'); }); + it('reloads the active conversation from source when requested', async () => { + const conv = createConversation(); + mockGetMessages + .mockResolvedValueOnce([ + createMessage({ id: 1, text: 'keep me', sender_timestamp: 1700000000, received_at: 1 }), + createMessage({ + id: 2, + text: 'blocked later', + sender_timestamp: 1700000001, + received_at: 2, + }), + ]) + .mockResolvedValueOnce([ + createMessage({ id: 1, text: 'keep me', sender_timestamp: 1700000000, received_at: 1 }), + ]); + + const { result } = renderHook(() => useConversationMessages(conv)); + + await waitFor(() => expect(result.current.messagesLoading).toBe(false)); + expect(result.current.messages.map((msg) => msg.text)).toEqual(['keep me', 'blocked later']); + + act(() => { + result.current.reloadCurrentConversation(); + }); + + await waitFor(() => expect(mockGetMessages).toHaveBeenCalledTimes(2)); + await waitFor(() => expect(result.current.messagesLoading).toBe(false)); + expect(result.current.messages.map((msg) => msg.text)).toEqual(['keep me']); + }); + it('aborts in-flight fetch when switching conversations', async () => { const convA: Conversation = { type: 'contact', id: 'conv_a', name: 'Contact A' }; const convB: Conversation = { type: 'contact', id: 'conv_b', name: 'Contact B' }; diff --git a/tests/test_block_lists.py b/tests/test_block_lists.py index ee6f12e..371aed2 100644 --- a/tests/test_block_lists.py +++ b/tests/test_block_lists.py @@ -279,6 +279,38 @@ class TestUnreadCountsBlockFiltering: ) assert result["counts"][f"channel-{chan_key}"] == 1 + @pytest.mark.asyncio + async def test_unread_counts_exclude_blocked_name_dms(self, test_db): + """Blocked-name DMs should not contribute to unread counts.""" + blocked_key = "aa" * 32 + normal_key = "bb" * 32 + now = int(time.time()) + + await ContactRepository.upsert({"public_key": blocked_key, "name": "Spammer"}) + await ContactRepository.upsert({"public_key": normal_key, "name": "Friend"}) + + await MessageRepository.create( + msg_type="PRIV", + text="blocked dm", + received_at=now, + conversation_key=blocked_key, + sender_timestamp=now, + sender_name="Spammer", + ) + await MessageRepository.create( + msg_type="PRIV", + text="allowed dm", + received_at=now + 1, + conversation_key=normal_key, + sender_timestamp=now + 1, + sender_name="Friend", + ) + + result = await MessageRepository.get_unread_counts(blocked_names=["Spammer"]) + + assert f"contact-{blocked_key}" not in result["counts"] + assert result["counts"][f"contact-{normal_key}"] == 1 + @pytest.mark.asyncio async def test_unread_counts_no_block_lists_returns_all(self, test_db): """Without block lists, all messages count toward unreads.""" @@ -389,3 +421,34 @@ class TestUnreadCountsBlockFiltering: result = await MessageRepository.get_unread_counts(blocked_names=["Spammer"]) assert result["last_message_times"][f"channel-{chan_key}"] == 1999 + + @pytest.mark.asyncio + async def test_last_message_times_exclude_blocked_name_dms(self, test_db): + """Blocked incoming DM names should not reseed recent-sort timestamps.""" + blocked_key = "aa" * 32 + normal_key = "bb" * 32 + + await ContactRepository.upsert({"public_key": blocked_key, "name": "Spammer"}) + await ContactRepository.upsert({"public_key": normal_key, "name": "Friend"}) + + await MessageRepository.create( + msg_type="PRIV", + text="blocked dm", + received_at=3000, + conversation_key=blocked_key, + sender_timestamp=3000, + sender_name="Spammer", + ) + await MessageRepository.create( + msg_type="PRIV", + text="allowed dm", + received_at=2999, + conversation_key=normal_key, + sender_timestamp=2999, + sender_name="Friend", + ) + + result = await MessageRepository.get_unread_counts(blocked_names=["Spammer"]) + + assert f"contact-{blocked_key}" not in result["last_message_times"] + assert result["last_message_times"][f"contact-{normal_key}"] == 2999