From f335fc56cc16b3397725956f22e7210ff6e662a4 Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Mon, 2 Mar 2026 18:02:53 -0800 Subject: [PATCH] Patch up some missing tests and fix+test channel add not clearing on channel submission without add-another checked --- frontend/package-lock.json | 15 + frontend/package.json | 3 +- frontend/src/components/NewMessageModal.tsx | 32 +- .../settings/SettingsRadioSection.tsx | 4 +- frontend/src/test/messageInput.test.tsx | 185 +++++++++ frontend/src/test/newMessageModal.test.tsx | 194 +++++++++ .../src/test/useContactsAndChannels.test.ts | 176 ++++++++ tests/test_event_handlers.py | 55 +++ tests/test_real_crypto.py | 387 ++++++++++++++++++ tests/test_send_messages.py | 84 ++++ tests/test_websocket.py | 47 ++- 11 files changed, 1174 insertions(+), 8 deletions(-) create mode 100644 frontend/src/test/messageInput.test.tsx create mode 100644 frontend/src/test/newMessageModal.test.tsx create mode 100644 frontend/src/test/useContactsAndChannels.test.ts create mode 100644 tests/test_real_crypto.py diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 56c3a22..55825e4 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -38,6 +38,7 @@ "@eslint/js": "^9.17.0", "@testing-library/jest-dom": "^6.6.0", "@testing-library/react": "^16.0.0", + "@testing-library/user-event": "^14.6.1", "@types/d3-force": "^3.0.10", "@types/leaflet": "^1.9.21", "@types/node": "^25.0.3", @@ -1593,6 +1594,20 @@ } } }, + "node_modules/@testing-library/user-event": { + "version": "14.6.1", + "resolved": "https://registry.npmjs.org/@testing-library/user-event/-/user-event-14.6.1.tgz", + "integrity": "sha512-vq7fv0rnt+QTXgPxr5Hjc210p6YKq2kmdziLgnsZGgLJ9e6VAShx1pACLuRjd/AS/sr7phAR58OIIpf0LlmQNw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12", + "npm": ">=6" + }, + "peerDependencies": { + "@testing-library/dom": ">=7.21.4" + } + }, "node_modules/@tweenjs/tween.js": { "version": "23.1.3", "resolved": "https://registry.npmjs.org/@tweenjs/tween.js/-/tween.js-23.1.3.tgz", diff --git a/frontend/package.json b/frontend/package.json index d748c04..d9ded28 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -45,12 +45,13 @@ "@eslint/js": "^9.17.0", "@testing-library/jest-dom": "^6.6.0", "@testing-library/react": "^16.0.0", + "@testing-library/user-event": "^14.6.1", "@types/d3-force": "^3.0.10", "@types/leaflet": "^1.9.21", - "@types/three": "^0.182.0", "@types/node": "^25.0.3", "@types/react": "^18.3.12", "@types/react-dom": "^18.3.1", + "@types/three": "^0.182.0", "@vitejs/plugin-react": "^4.3.3", "autoprefixer": "^10.4.23", "eslint": "^9.17.0", diff --git a/frontend/src/components/NewMessageModal.tsx b/frontend/src/components/NewMessageModal.tsx index 70a1eb6..9f6690a 100644 --- a/frontend/src/components/NewMessageModal.tsx +++ b/frontend/src/components/NewMessageModal.tsx @@ -48,6 +48,15 @@ export function NewMessageModal({ const [loading, setLoading] = useState(false); const hashtagInputRef = useRef(null); + const resetForm = () => { + setName(''); + setContactKey(''); + setRoomKey(''); + setTryHistorical(false); + setPermitCapitals(false); + setError(''); + }; + const handleCreate = async () => { setError(''); setLoading(true); @@ -77,6 +86,7 @@ export function NewMessageModal({ const normalizedName = permitCapitals ? channelName : channelName.toLowerCase(); await onCreateHashtagChannel(`#${normalizedName}`, tryHistorical); } + resetForm(); onClose(); } catch (err) { setError(err instanceof Error ? err.message : 'Failed to create'); @@ -121,7 +131,15 @@ export function NewMessageModal({ const showHistoricalOption = tab !== 'existing' && undecryptedCount > 0; return ( - !isOpen && onClose()}> + { + if (!isOpen) { + resetForm(); + onClose(); + } + }} + > New Conversation @@ -137,8 +155,7 @@ export function NewMessageModal({ value={tab} onValueChange={(v) => { setTab(v as Tab); - setName(''); - setError(''); + resetForm(); }} className="w-full" > @@ -164,6 +181,7 @@ export function NewMessageModal({ id: contact.public_key, name: getContactDisplayName(contact.name, contact.public_key), }); + resetForm(); onClose(); }} > @@ -294,7 +312,13 @@ export function NewMessageModal({ {error &&
{error}
} - {tab === 'hashtag' && ( diff --git a/frontend/src/components/settings/SettingsRadioSection.tsx b/frontend/src/components/settings/SettingsRadioSection.tsx index b04fc38..c0fa566 100644 --- a/frontend/src/components/settings/SettingsRadioSection.tsx +++ b/frontend/src/components/settings/SettingsRadioSection.tsx @@ -113,8 +113,8 @@ export function SettingsRadioSection({ const parsedCr = parseInt(cr, 10); if ( - [parsedLat, parsedLon, parsedTxPower, parsedFreq, parsedBw, parsedSf, parsedCr].some( - (v) => isNaN(v) + [parsedLat, parsedLon, parsedTxPower, parsedFreq, parsedBw, parsedSf, parsedCr].some((v) => + isNaN(v) ) ) { setError('All numeric fields must have valid values'); diff --git a/frontend/src/test/messageInput.test.tsx b/frontend/src/test/messageInput.test.tsx new file mode 100644 index 0000000..8b15648 --- /dev/null +++ b/frontend/src/test/messageInput.test.tsx @@ -0,0 +1,185 @@ +/** + * Tests for MessageInput component. + * + * Verifies character/byte limit calculation, warning states, and send button + * behavior for both DM and channel conversations. + */ + +import { render, screen, fireEvent } from '@testing-library/react'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +import { MessageInput } from '../components/MessageInput'; + +// Mock sonner (toast) +vi.mock('../components/ui/sonner', () => ({ + toast: { success: vi.fn(), error: vi.fn() }, +})); + +const textEncoder = new TextEncoder(); + +function byteLen(s: string): number { + return textEncoder.encode(s).length; +} + +describe('MessageInput', () => { + const onSend = vi.fn().mockResolvedValue(undefined); + + beforeEach(() => { + vi.clearAllMocks(); + }); + + function renderInput(props: { + conversationType?: 'contact' | 'channel' | 'raw'; + senderName?: string; + disabled?: boolean; + }) { + return render( + + ); + } + + function getInput() { + return screen.getByPlaceholderText('Type a message...') as HTMLInputElement; + } + + function getSendButton() { + return screen.getByRole('button', { name: /send/i }) as HTMLButtonElement; + } + + describe('send button state', () => { + it('is disabled when text is empty', () => { + renderInput({ conversationType: 'contact' }); + expect(getSendButton()).toBeDisabled(); + }); + + it('is enabled when text is entered', () => { + renderInput({ conversationType: 'contact' }); + fireEvent.change(getInput(), { target: { value: 'Hello' } }); + expect(getSendButton()).toBeEnabled(); + }); + + it('is disabled when whitespace-only', () => { + renderInput({ conversationType: 'contact' }); + fireEvent.change(getInput(), { target: { value: ' ' } }); + expect(getSendButton()).toBeDisabled(); + }); + + it('is disabled when disabled prop is true', () => { + renderInput({ conversationType: 'contact', disabled: true }); + fireEvent.change(getInput(), { target: { value: 'Hello' } }); + expect(getSendButton()).toBeDisabled(); + }); + }); + + describe('byte counter display', () => { + it('shows byte counter for DM conversations', () => { + renderInput({ conversationType: 'contact' }); + fireEvent.change(getInput(), { target: { value: 'Hello' } }); + + // Should show "5/156" somewhere (DM hard limit = 156) + expect(screen.getByText(/5\/156/)).toBeTruthy(); + }); + + it('shows byte counter for channel conversations', () => { + renderInput({ conversationType: 'channel', senderName: 'MyNode' }); + fireEvent.change(getInput(), { target: { value: 'Hello' } }); + + // Channel hard limit = 156 - byteLen("MyNode") - 2 = 156 - 6 - 2 = 148 + expect(screen.getByText(/5\/148/)).toBeTruthy(); + }); + + it('does not show byte counter for raw conversations', () => { + renderInput({ conversationType: 'raw' }); + fireEvent.change(getInput(), { target: { value: 'Hello' } }); + + // No counter should be visible + expect(screen.queryByText(/\/\d+/)).toBeNull(); + }); + + it('accounts for multi-byte characters in byte count', () => { + renderInput({ conversationType: 'contact' }); + // Emoji: "🥝" is 4 bytes in UTF-8 + fireEvent.change(getInput(), { target: { value: '🥝' } }); + const bytes = byteLen('🥝'); // Should be 4 + expect(bytes).toBe(4); + expect(screen.getByText(new RegExp(`${bytes}/156`))).toBeTruthy(); + }); + }); + + describe('channel limit adjusts for sender name', () => { + it('reduces limit based on sender name byte length', () => { + // Sender name "LongNodeName" = 12 bytes + 2 for ": " = 14 overhead + // Hard limit = 156 - 14 = 142 + renderInput({ conversationType: 'channel', senderName: 'LongNodeName' }); + fireEvent.change(getInput(), { target: { value: 'x' } }); + expect(screen.getByText(/1\/142/)).toBeTruthy(); + }); + + it('uses default 10-byte name when sender name is absent', () => { + // Default: 10 bytes + 2 = 12 overhead. Hard limit = 156 - 12 = 144 + renderInput({ conversationType: 'channel' }); + fireEvent.change(getInput(), { target: { value: 'x' } }); + expect(screen.getByText(/1\/144/)).toBeTruthy(); + }); + + it('handles multi-byte sender names correctly', () => { + // "🥝Node" = 4 + 4 = 8 bytes name + 2 separator = 10 overhead + // Hard limit = 156 - 10 = 146 + const senderName = '🥝Node'; + const nameBytes = byteLen(senderName); + const expectedLimit = 156 - nameBytes - 2; + renderInput({ conversationType: 'channel', senderName }); + fireEvent.change(getInput(), { target: { value: 'x' } }); + expect(screen.getByText(new RegExp(`1/${expectedLimit}`))).toBeTruthy(); + }); + }); + + describe('warning states', () => { + it('shows warning text when exceeding DM warning threshold', () => { + renderInput({ conversationType: 'contact' }); + // DM warning threshold = 140 bytes + const text = 'x'.repeat(141); + fireEvent.change(getInput(), { target: { value: text } }); + // Rendered in both desktop and mobile variants + expect(screen.getAllByText(/may impact multi-repeater hop delivery/).length).toBeGreaterThan( + 0 + ); + }); + + it('shows truncation warning when exceeding DM hard limit', () => { + renderInput({ conversationType: 'contact' }); + // DM hard limit = 156 bytes + const text = 'x'.repeat(157); + fireEvent.change(getInput(), { target: { value: text } }); + // Rendered in both desktop and mobile variants + expect(screen.getAllByText(/likely truncated by radio/).length).toBeGreaterThan(0); + }); + + it('shows no warning for short messages', () => { + renderInput({ conversationType: 'contact' }); + fireEvent.change(getInput(), { target: { value: 'Hello' } }); + expect(screen.queryByText(/truncated/)).toBeNull(); + expect(screen.queryByText(/may impact/)).toBeNull(); + }); + }); + + describe('send button remains enabled past hard limit (current behavior)', () => { + it('does not disable send button when over hard limit', () => { + // NOTE: This documents the current behavior where canSubmit only checks + // text.trim().length > 0, NOT the limit state. This is related to + // hitlist item 1.1 — the send button stays enabled even over the limit. + renderInput({ conversationType: 'contact' }); + const text = 'x'.repeat(200); // Well over 156 byte limit + fireEvent.change(getInput(), { target: { value: text } }); + + // Button is still enabled — canSubmit only checks non-empty text + expect(getSendButton()).toBeEnabled(); + }); + }); +}); diff --git a/frontend/src/test/newMessageModal.test.tsx b/frontend/src/test/newMessageModal.test.tsx new file mode 100644 index 0000000..1b3eb18 --- /dev/null +++ b/frontend/src/test/newMessageModal.test.tsx @@ -0,0 +1,194 @@ +/** + * Tests for NewMessageModal form state reset. + * + * Verifies that form fields are cleared when the modal closes (via Create, + * Cancel, or Dialog dismiss) and when switching tabs. + */ + +import { render, screen, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +import { NewMessageModal } from '../components/NewMessageModal'; +import type { Contact } from '../types'; + +// Mock sonner (toast) +vi.mock('../components/ui/sonner', () => ({ + toast: { success: vi.fn(), error: vi.fn() }, +})); + +const mockContact: Contact = { + public_key: 'aa'.repeat(32), + name: 'Alice', + type: 1, + flags: 0, + last_path: null, + last_path_len: -1, + last_advert: null, + lat: null, + lon: null, + last_seen: null, + on_radio: false, + last_contacted: null, + last_read_at: null, + first_seen: null, +}; + +describe('NewMessageModal form reset', () => { + const onClose = vi.fn(); + const onSelectConversation = vi.fn(); + const onCreateContact = vi.fn().mockResolvedValue(undefined); + const onCreateChannel = vi.fn().mockResolvedValue(undefined); + const onCreateHashtagChannel = vi.fn().mockResolvedValue(undefined); + + beforeEach(() => { + vi.clearAllMocks(); + }); + + function renderModal(open = true) { + return render( + + ); + } + + async function switchToTab(user: ReturnType, name: string) { + await user.click(screen.getByRole('tab', { name })); + } + + describe('hashtag tab', () => { + it('clears name after successful Create', async () => { + const user = userEvent.setup(); + const { unmount } = renderModal(); + await switchToTab(user, 'Hashtag'); + + const input = screen.getByPlaceholderText('channel-name') as HTMLInputElement; + await user.type(input, 'testchan'); + expect(input.value).toBe('testchan'); + + await user.click(screen.getByRole('button', { name: 'Create' })); + + await waitFor(() => { + expect(onCreateHashtagChannel).toHaveBeenCalledWith('#testchan', false); + }); + expect(onClose).toHaveBeenCalled(); + unmount(); + + // Re-render to simulate reopening — state should be reset + renderModal(); + await switchToTab(user, 'Hashtag'); + expect((screen.getByPlaceholderText('channel-name') as HTMLInputElement).value).toBe(''); + }); + + it('clears name when Cancel is clicked', async () => { + const user = userEvent.setup(); + renderModal(); + await switchToTab(user, 'Hashtag'); + + const input = screen.getByPlaceholderText('channel-name') as HTMLInputElement; + await user.type(input, 'mychannel'); + expect(input.value).toBe('mychannel'); + + await user.click(screen.getByRole('button', { name: 'Cancel' })); + expect(onClose).toHaveBeenCalled(); + }); + }); + + describe('new-contact tab', () => { + it('clears name and key after successful Create', async () => { + const user = userEvent.setup(); + renderModal(); + await switchToTab(user, 'Contact'); + + await user.type(screen.getByPlaceholderText('Contact name'), 'Bob'); + await user.type(screen.getByPlaceholderText('64-character hex public key'), 'bb'.repeat(32)); + + await user.click(screen.getByRole('button', { name: 'Create' })); + + await waitFor(() => { + expect(onCreateContact).toHaveBeenCalledWith('Bob', 'bb'.repeat(32), false); + }); + expect(onClose).toHaveBeenCalled(); + }); + }); + + describe('new-room tab', () => { + it('clears name and key after successful Create', async () => { + const user = userEvent.setup(); + renderModal(); + await switchToTab(user, 'Room'); + + await user.type(screen.getByPlaceholderText('Room name'), 'MyRoom'); + await user.type(screen.getByPlaceholderText('Pre-shared key (hex)'), 'cc'.repeat(16)); + + await user.click(screen.getByRole('button', { name: 'Create' })); + + await waitFor(() => { + expect(onCreateChannel).toHaveBeenCalledWith('MyRoom', 'cc'.repeat(16), false); + }); + expect(onClose).toHaveBeenCalled(); + }); + }); + + describe('tab switching resets form', () => { + it('clears contact fields when switching to room tab', async () => { + const user = userEvent.setup(); + renderModal(); + await switchToTab(user, 'Contact'); + + await user.type(screen.getByPlaceholderText('Contact name'), 'Bob'); + await user.type(screen.getByPlaceholderText('64-character hex public key'), 'deadbeef'); + + // Switch to Room tab — fields should reset + await switchToTab(user, 'Room'); + + expect((screen.getByPlaceholderText('Room name') as HTMLInputElement).value).toBe(''); + expect((screen.getByPlaceholderText('Pre-shared key (hex)') as HTMLInputElement).value).toBe( + '' + ); + }); + + it('clears room fields when switching to hashtag tab', async () => { + const user = userEvent.setup(); + renderModal(); + await switchToTab(user, 'Room'); + + await user.type(screen.getByPlaceholderText('Room name'), 'SecretRoom'); + await user.type(screen.getByPlaceholderText('Pre-shared key (hex)'), 'ff'.repeat(16)); + + await switchToTab(user, 'Hashtag'); + + expect((screen.getByPlaceholderText('channel-name') as HTMLInputElement).value).toBe(''); + }); + }); + + describe('tryHistorical checkbox resets', () => { + it('resets tryHistorical when switching tabs', async () => { + const user = userEvent.setup(); + renderModal(); + await switchToTab(user, 'Hashtag'); + + // Check the "Try decrypting" checkbox + const checkbox = screen.getByRole('checkbox', { name: /Try decrypting/ }); + await user.click(checkbox); + + // The streaming message should appear + expect(screen.getByText(/Messages will stream in/)).toBeTruthy(); + + // Switch tab and come back + await switchToTab(user, 'Contact'); + await switchToTab(user, 'Hashtag'); + + // The streaming message should be gone (tryHistorical was reset) + expect(screen.queryByText(/Messages will stream in/)).toBeNull(); + }); + }); +}); diff --git a/frontend/src/test/useContactsAndChannels.test.ts b/frontend/src/test/useContactsAndChannels.test.ts new file mode 100644 index 0000000..c5a2a09 --- /dev/null +++ b/frontend/src/test/useContactsAndChannels.test.ts @@ -0,0 +1,176 @@ +/** + * Tests for useContactsAndChannels hook. + * + * Focuses on pagination logic in fetchAllContacts (which fetches 1000 items + * per page and continues until a page returns fewer than pageSize results). + */ + +import { act, renderHook } from '@testing-library/react'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; + +import { useContactsAndChannels } from '../hooks/useContactsAndChannels'; +import type { Contact } from '../types'; + +// Mock api module +vi.mock('../api', () => ({ + api: { + getContacts: vi.fn(), + getChannels: vi.fn(), + createContact: vi.fn(), + createChannel: vi.fn(), + deleteContact: vi.fn(), + deleteChannel: vi.fn(), + decryptHistoricalPackets: vi.fn(), + getUndecryptedPacketCount: vi.fn(), + }, +})); + +// Mock prefetch — takePrefetchOrFetch calls the fetcher directly +vi.mock('../prefetch', () => ({ + takePrefetchOrFetch: vi.fn((_key: string, fetcher: () => Promise) => fetcher()), +})); + +// Mock sonner +vi.mock('../components/ui/sonner', () => ({ + toast: { success: vi.fn(), error: vi.fn() }, +})); + +// Mock messageCache +vi.mock('../messageCache', () => ({ + remove: vi.fn(), +})); + +function makeContact(suffix: string): Contact { + const key = suffix.padStart(64, '0'); + return { + public_key: key, + name: `Contact-${suffix}`, + type: 1, + flags: 0, + last_path: null, + last_path_len: -1, + last_advert: null, + lat: null, + lon: null, + last_seen: null, + on_radio: false, + last_contacted: null, + last_read_at: null, + first_seen: null, + }; +} + +function makeContacts(count: number, startIndex = 0): Contact[] { + return Array.from({ length: count }, (_, i) => + makeContact(String(startIndex + i).padStart(4, '0')) + ); +} + +describe('useContactsAndChannels', () => { + const setActiveConversation = vi.fn(); + const pendingDeleteFallbackRef = { current: false }; + const hasSetDefaultConversation = { current: false }; + + beforeEach(() => { + vi.clearAllMocks(); + pendingDeleteFallbackRef.current = false; + hasSetDefaultConversation.current = false; + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + function renderUseContactsAndChannels() { + return renderHook(() => + useContactsAndChannels({ + setActiveConversation, + pendingDeleteFallbackRef, + hasSetDefaultConversation, + }) + ); + } + + describe('fetchAllContacts pagination', () => { + it('returns contacts directly when fewer than page size', async () => { + const { api } = await import('../api'); + const contacts = makeContacts(50); + vi.mocked(api.getContacts).mockResolvedValueOnce(contacts); + + const { result } = renderUseContactsAndChannels(); + + let fetched: Contact[] = []; + await act(async () => { + fetched = await result.current.fetchAllContacts(); + }); + + expect(fetched).toHaveLength(50); + // Should only call once (no pagination needed) + expect(api.getContacts).toHaveBeenCalledTimes(1); + expect(api.getContacts).toHaveBeenCalledWith(1000, 0); + }); + + it('paginates when first page returns exactly page size', async () => { + const { api } = await import('../api'); + const page1 = makeContacts(1000, 0); + const page2 = makeContacts(200, 1000); + + vi.mocked(api.getContacts) + .mockResolvedValueOnce(page1) // First page: full + .mockResolvedValueOnce(page2); // Second page: partial (done) + + const { result } = renderUseContactsAndChannels(); + + let fetched: Contact[] = []; + await act(async () => { + fetched = await result.current.fetchAllContacts(); + }); + + expect(fetched).toHaveLength(1200); + expect(api.getContacts).toHaveBeenCalledTimes(2); + expect(api.getContacts).toHaveBeenNthCalledWith(1, 1000, 0); + expect(api.getContacts).toHaveBeenNthCalledWith(2, 1000, 1000); + }); + + it('paginates through multiple full pages', async () => { + const { api } = await import('../api'); + const page1 = makeContacts(1000, 0); + const page2 = makeContacts(1000, 1000); + const page3 = makeContacts(500, 2000); + + vi.mocked(api.getContacts) + .mockResolvedValueOnce(page1) + .mockResolvedValueOnce(page2) + .mockResolvedValueOnce(page3); + + const { result } = renderUseContactsAndChannels(); + + let fetched: Contact[] = []; + await act(async () => { + fetched = await result.current.fetchAllContacts(); + }); + + expect(fetched).toHaveLength(2500); + expect(api.getContacts).toHaveBeenCalledTimes(3); + expect(api.getContacts).toHaveBeenNthCalledWith(3, 1000, 2000); + }); + + it('handles exactly page size total (boundary case)', async () => { + const { api } = await import('../api'); + const page1 = makeContacts(1000, 0); + const page2: Contact[] = []; // Empty second page + + vi.mocked(api.getContacts).mockResolvedValueOnce(page1).mockResolvedValueOnce(page2); + + const { result } = renderUseContactsAndChannels(); + + let fetched: Contact[] = []; + await act(async () => { + fetched = await result.current.fetchAllContacts(); + }); + + expect(fetched).toHaveLength(1000); + expect(api.getContacts).toHaveBeenCalledTimes(2); + }); + }); +}); diff --git a/tests/test_event_handlers.py b/tests/test_event_handlers.py index 6077a9e..8da5a61 100644 --- a/tests/test_event_handlers.py +++ b/tests/test_event_handlers.py @@ -477,6 +477,61 @@ class TestContactMessageCLIFiltering: assert len(messages) == 0 +class TestContactMessageDBErrorResilience: + """Test that DB errors in on_contact_message propagate without crashing silently.""" + + @pytest.mark.asyncio + async def test_db_error_in_create_propagates(self, test_db): + """When MessageRepository.create raises, the exception propagates. + + If this handler silently swallowed DB errors, messages would be lost + without any indication. The exception should propagate so the caller + (MeshCore event dispatcher) can handle it. + """ + from app.event_handlers import on_contact_message + + class MockEvent: + payload = { + "pubkey_prefix": "abc123def456", + "text": "DB will fail", + "txt_type": 0, + "sender_timestamp": 1700000000, + } + + with ( + patch("app.event_handlers.broadcast_event"), + patch.object( + MessageRepository, + "create", + side_effect=Exception("database is locked"), + ), + ): + with pytest.raises(Exception, match="database is locked"): + await on_contact_message(MockEvent()) + + @pytest.mark.asyncio + async def test_db_error_in_contact_lookup_propagates(self, test_db): + """When ContactRepository.get_by_key_or_prefix raises an unexpected error, + it propagates rather than being silently swallowed.""" + from app.event_handlers import on_contact_message + + class MockEvent: + payload = { + "public_key": "ab" * 32, + "text": "Lookup will fail", + "txt_type": 0, + "sender_timestamp": 1700000000, + } + + with patch.object( + ContactRepository, + "get_by_key_or_prefix", + side_effect=RuntimeError("connection pool exhausted"), + ): + with pytest.raises(RuntimeError, match="connection pool exhausted"): + await on_contact_message(MockEvent()) + + class TestEventHandlerRegistration: """Test event handler registration and cleanup.""" diff --git a/tests/test_real_crypto.py b/tests/test_real_crypto.py new file mode 100644 index 0000000..24d26f1 --- /dev/null +++ b/tests/test_real_crypto.py @@ -0,0 +1,387 @@ +"""Tests using real MeshCore packet data and cryptographic keys. + +These tests verify the decryption pipeline end-to-end with actual radio packets +captured from the mesh network. No crypto functions are mocked. + +Test data: + - Client 1 ("a1b2c3d3"): sender of the DM + - Client 2 ("face1233"): receiver of the DM + - Channel: #six77 (hashtag room, key derived from SHA-256 of name) +""" + +from hashlib import sha256 +from unittest.mock import patch + +import pytest + +from app.decoder import ( + DecryptedDirectMessage, + PayloadType, + RouteType, + decrypt_direct_message, + derive_public_key, + derive_shared_secret, + parse_packet, + try_decrypt_dm, + try_decrypt_packet_with_channel_key, +) +from app.repository import ContactRepository, MessageRepository, RawPacketRepository + +# --------------------------------------------------------------------------- +# Real test data captured from a MeshCore mesh network +# --------------------------------------------------------------------------- + +# Client 1 (sender of the DM) +CLIENT1_PUBLIC_HEX = "a1b2c3d3ba9f5fa8705b9845fe11cc6f01d1d49caaf4d122ac7121663c5beec7" +CLIENT1_PRIVATE_HEX = ( + "1808C3512F063796E492B9FA101A7A6239F14E71F8D1D5AD086E8E228ED0A076" + "D5ED26C82C6E64ABF1954336E42CF68E4AB288A4D38E40ED0F5870FED95C1DEB" +) +CLIENT1_PUBLIC = bytes.fromhex(CLIENT1_PUBLIC_HEX) +CLIENT1_PRIVATE = bytes.fromhex(CLIENT1_PRIVATE_HEX) + +# Client 2 (receiver of the DM) +CLIENT2_PUBLIC_HEX = "face123334789e2b81519afdbc39a3c9eb7ea3457ad367d3243597a484847e46" +CLIENT2_PRIVATE_HEX = ( + "58BA1940E97099CBB4357C62CE9C7F4B245C94C90D722E67201B989F9FEACF7B" + "77ACADDB84438514022BDB0FC3140C2501859BE1772AC7B8C7E41DC0F40490A1" +) +CLIENT2_PUBLIC = bytes.fromhex(CLIENT2_PUBLIC_HEX) +CLIENT2_PRIVATE = bytes.fromhex(CLIENT2_PRIVATE_HEX) + +# DM packet: client 1 -> client 2 +DM_PACKET_HEX = "0900FAA1295471ADB44A98B13CA528A4B5C4FBC29B4DA3CED477519B2FBD8FD5467C31E5D58B" +DM_PACKET = bytes.fromhex(DM_PACKET_HEX) +DM_PLAINTEXT = "Hello there, Mr. Face!" + +# Channel message in #six77 +CHANNEL_PACKET_HEX = ( + "1500E69C7A89DD0AF6A2D69F5823B88F9720731E4B887C56932BF889255D8D926D" + "99195927144323A42DD8A158F878B518B8304DF55E80501C7D02A9FFD578D35182" + "83156BBA257BF8413E80A237393B2E4149BBBC864371140A9BBC4E23EB9BF203EF" + "0D029214B3E3AAC3C0295690ACDB89A28619E7E5F22C83E16073AD679D25FA904D" + "07E5ACF1DB5A7C77D7E1719FB9AE5BF55541EE0D7F59ED890E12CF0FEED6700818" +) +CHANNEL_PACKET = bytes.fromhex(CHANNEL_PACKET_HEX) +CHANNEL_NAME = "#six77" +CHANNEL_KEY = sha256(CHANNEL_NAME.encode("utf-8")).digest()[:16] +CHANNEL_PLAINTEXT_FULL = ( + "Flightless🥝: hello there; this hashtag room is essentially public. " + "MeshCore has great crypto; use private rooms or DMs for private comms instead!" +) +CHANNEL_SENDER = "Flightless🥝" +CHANNEL_MESSAGE_BODY = ( + "hello there; this hashtag room is essentially public. " + "MeshCore has great crypto; use private rooms or DMs for private comms instead!" +) + + +# ============================================================================ +# Direct Message Decryption +# ============================================================================ + + +class TestDMDecryption: + """Test DM decryption using real captured packet data.""" + + def test_derive_public_key_from_private(self): + """derive_public_key reproduces known public keys from private keys.""" + assert derive_public_key(CLIENT1_PRIVATE) == CLIENT1_PUBLIC + assert derive_public_key(CLIENT2_PRIVATE) == CLIENT2_PUBLIC + + def test_shared_secret_is_symmetric(self): + """Both parties derive the same ECDH shared secret.""" + secret_1to2 = derive_shared_secret(CLIENT1_PRIVATE, CLIENT2_PUBLIC) + secret_2to1 = derive_shared_secret(CLIENT2_PRIVATE, CLIENT1_PUBLIC) + assert secret_1to2 == secret_2to1 + + def test_parse_dm_packet_header(self): + """Raw DM packet parses to the expected header fields.""" + info = parse_packet(DM_PACKET) + assert info is not None + assert info.route_type == RouteType.FLOOD + assert info.payload_type == PayloadType.TEXT_MESSAGE + assert info.path_length == 0 + + def test_decrypt_dm_as_receiver(self): + """Receiver (face1233) decrypts the DM with correct plaintext.""" + result = try_decrypt_dm( + DM_PACKET, + our_private_key=CLIENT2_PRIVATE, + their_public_key=CLIENT1_PUBLIC, + our_public_key=CLIENT2_PUBLIC, + ) + assert result is not None + assert isinstance(result, DecryptedDirectMessage) + assert result.message == DM_PLAINTEXT + + def test_decrypt_dm_as_sender(self): + """Sender (a1b2c3d3) decrypts the DM too (outgoing echo scenario).""" + result = try_decrypt_dm( + DM_PACKET, + our_private_key=CLIENT1_PRIVATE, + their_public_key=CLIENT2_PUBLIC, + our_public_key=CLIENT1_PUBLIC, + ) + assert result is not None + assert result.message == DM_PLAINTEXT + + def test_direction_hashes_match_key_prefixes(self): + """dest_hash and src_hash correspond to first bytes of public keys.""" + result = try_decrypt_dm( + DM_PACKET, + our_private_key=CLIENT2_PRIVATE, + their_public_key=CLIENT1_PUBLIC, + our_public_key=CLIENT2_PUBLIC, + ) + assert result is not None + # Packet was sent FROM client1 TO client2 + assert result.src_hash == format(CLIENT1_PUBLIC[0], "02x") # a1 + assert result.dest_hash == format(CLIENT2_PUBLIC[0], "02x") # fa + + def test_wrong_key_fails_mac(self): + """Decryption with an unrelated key fails (MAC mismatch).""" + wrong_private = b"\x01" * 64 + result = try_decrypt_dm( + DM_PACKET, + our_private_key=wrong_private, + their_public_key=CLIENT1_PUBLIC, + ) + assert result is None + + def test_decrypt_dm_payload_directly(self): + """decrypt_direct_message works with just the payload and shared secret.""" + info = parse_packet(DM_PACKET) + assert info is not None + + shared = derive_shared_secret(CLIENT2_PRIVATE, CLIENT1_PUBLIC) + result = decrypt_direct_message(info.payload, shared) + assert result is not None + assert result.message == DM_PLAINTEXT + assert result.timestamp > 0 + + +# ============================================================================ +# Channel Message Decryption +# ============================================================================ + + +class TestChannelDecryption: + """Test channel message decryption using real captured packet data.""" + + def test_parse_channel_packet_header(self): + """Raw channel packet parses to GROUP_TEXT.""" + info = parse_packet(CHANNEL_PACKET) + assert info is not None + assert info.payload_type == PayloadType.GROUP_TEXT + + def test_decrypt_channel_message(self): + """Channel message decrypts to expected sender and body.""" + result = try_decrypt_packet_with_channel_key(CHANNEL_PACKET, CHANNEL_KEY) + assert result is not None + assert result.sender == CHANNEL_SENDER + assert result.message == CHANNEL_MESSAGE_BODY + + def test_full_text_reconstructed(self): + """Reconstructed 'sender: message' matches the original plaintext.""" + result = try_decrypt_packet_with_channel_key(CHANNEL_PACKET, CHANNEL_KEY) + assert result is not None + full = f"{result.sender}: {result.message}" + assert full == CHANNEL_PLAINTEXT_FULL + + def test_channel_hash_matches_packet(self): + """Channel hash in packet matches hash computed from key.""" + from app.decoder import calculate_channel_hash + + info = parse_packet(CHANNEL_PACKET) + assert info is not None + packet_hash = format(info.payload[0], "02x") + expected_hash = calculate_channel_hash(CHANNEL_KEY) + assert packet_hash == expected_hash + + def test_wrong_channel_key_fails(self): + """Decryption with a different channel key returns None.""" + wrong_key = b"\x00" * 16 + result = try_decrypt_packet_with_channel_key(CHANNEL_PACKET, wrong_key) + assert result is None + + def test_hashtag_key_derivation(self): + """Hashtag channel key is SHA-256(name)[:16], matching radio firmware.""" + key = sha256(b"#six77").digest()[:16] + assert len(key) == 16 + # Key should decrypt our packet + result = try_decrypt_packet_with_channel_key(CHANNEL_PACKET, key) + assert result is not None + + +# ============================================================================ +# Historical DM Decryption Pipeline (Integration) +# ============================================================================ + + +class TestHistoricalDMDecryptionPipeline: + """Integration test: store a real DM packet, run historical decryption, + verify correct message and direction end up in the DB.""" + + @pytest.mark.asyncio + async def test_historical_decrypt_stores_incoming_dm(self, test_db, captured_broadcasts): + """run_historical_dm_decryption decrypts a real packet and stores it + with the correct direction (incoming from client1 to client2).""" + from app.packet_processor import run_historical_dm_decryption + + # Store the undecrypted raw packet (message_id=NULL means undecrypted) + pkt_id, _ = await RawPacketRepository.create(DM_PACKET, 1700000000) + + # Add client1 as a known contact + await ContactRepository.upsert( + { + "public_key": CLIENT1_PUBLIC_HEX, + "name": "Client1", + "type": 1, + } + ) + + broadcasts, mock_broadcast = captured_broadcasts + + with patch("app.packet_processor.broadcast_event", mock_broadcast): + # Decrypt as client2 (the receiver) + await run_historical_dm_decryption( + private_key_bytes=CLIENT2_PRIVATE, + contact_public_key_bytes=CLIENT1_PUBLIC, + contact_public_key_hex=CLIENT1_PUBLIC_HEX, + display_name="Client1", + ) + + # Verify the message was stored + messages = await MessageRepository.get_all( + msg_type="PRIV", conversation_key=CLIENT1_PUBLIC_HEX.lower(), limit=10 + ) + assert len(messages) == 1 + + msg = messages[0] + assert msg.text == DM_PLAINTEXT + assert msg.outgoing is False # We are client2, message is FROM client1 + assert msg.type == "PRIV" + + # Verify a message broadcast was sent + msg_broadcasts = [b for b in broadcasts if b["type"] == "message"] + assert len(msg_broadcasts) == 1 + assert msg_broadcasts[0]["data"]["text"] == DM_PLAINTEXT + assert msg_broadcasts[0]["data"]["outgoing"] is False + + @pytest.mark.asyncio + async def test_historical_decrypt_skips_outgoing_by_design(self, test_db, captured_broadcasts): + """Historical decryption skips outgoing DMs (they're stored by the send endpoint). + + run_historical_dm_decryption passes our_public_key=None, which disables + the outbound hash check. When our first byte differs from the contact's + (255/256 cases), outgoing packets fail the inbound src_hash check and + are skipped — this is correct behavior. + """ + from app.packet_processor import run_historical_dm_decryption + + await RawPacketRepository.create(DM_PACKET, 1700000000) + + await ContactRepository.upsert( + { + "public_key": CLIENT2_PUBLIC_HEX, + "name": "Client2", + "type": 1, + } + ) + + broadcasts, mock_broadcast = captured_broadcasts + + with patch("app.packet_processor.broadcast_event", mock_broadcast): + # Decrypt as client1 (the sender) — first bytes differ (a1 != fa) + # so historical decryption correctly skips this outgoing packet + await run_historical_dm_decryption( + private_key_bytes=CLIENT1_PRIVATE, + contact_public_key_bytes=CLIENT2_PUBLIC, + contact_public_key_hex=CLIENT2_PUBLIC_HEX, + display_name="Client2", + ) + + # No messages stored — outgoing DMs are handled by the send endpoint + messages = await MessageRepository.get_all( + msg_type="PRIV", conversation_key=CLIENT2_PUBLIC_HEX.lower(), limit=10 + ) + assert len(messages) == 0 + + @pytest.mark.asyncio + async def test_historical_decrypt_broadcasts_success(self, test_db, captured_broadcasts): + """Successful decryption broadcasts a success notification.""" + from app.packet_processor import run_historical_dm_decryption + + await RawPacketRepository.create(DM_PACKET, 1700000000) + + await ContactRepository.upsert( + { + "public_key": CLIENT1_PUBLIC_HEX, + "name": "Client1", + "type": 1, + } + ) + + broadcasts, mock_broadcast = captured_broadcasts + + from unittest.mock import MagicMock + + mock_success = MagicMock() + + with ( + patch("app.packet_processor.broadcast_event", mock_broadcast), + patch("app.websocket.broadcast_success", mock_success), + ): + await run_historical_dm_decryption( + private_key_bytes=CLIENT2_PRIVATE, + contact_public_key_bytes=CLIENT1_PUBLIC, + contact_public_key_hex=CLIENT1_PUBLIC_HEX, + display_name="Client1", + ) + + mock_success.assert_called_once() + args = mock_success.call_args.args + assert "Client1" in args[0] + assert "1 message" in args[1] + + +class TestHistoricalChannelDecryptionPipeline: + """Integration test: store a real channel packet, process it through + the channel message pipeline, verify correct message in DB.""" + + @pytest.mark.asyncio + async def test_process_channel_packet_end_to_end(self, test_db, captured_broadcasts): + """process_raw_packet decrypts a real channel packet and stores + the message with correct sender and text.""" + from app.repository import ChannelRepository + + # Register the #six77 channel + channel_key_hex = CHANNEL_KEY.hex().upper() + await ChannelRepository.upsert(key=channel_key_hex, name=CHANNEL_NAME, is_hashtag=True) + + # Store the raw packet and process it + broadcasts, mock_broadcast = captured_broadcasts + + with patch("app.packet_processor.broadcast_event", mock_broadcast): + from app.packet_processor import process_raw_packet + + result = await process_raw_packet(raw_bytes=CHANNEL_PACKET) + + # Verify it was decrypted + assert result is not None + assert result["decrypted"] is True + assert result["channel_name"] == CHANNEL_NAME + assert result["sender"] == CHANNEL_SENDER + + # Verify message in DB + messages = await MessageRepository.get_all( + msg_type="CHAN", conversation_key=channel_key_hex, limit=10 + ) + assert len(messages) == 1 + assert messages[0].text == CHANNEL_PLAINTEXT_FULL + + # Verify a "message" broadcast was sent + msg_broadcasts = [b for b in broadcasts if b["type"] == "message"] + assert len(msg_broadcasts) == 1 + assert msg_broadcasts[0]["data"]["text"] == CHANNEL_PLAINTEXT_FULL diff --git a/tests/test_send_messages.py b/tests/test_send_messages.py index 265b5f4..8591d8c 100644 --- a/tests/test_send_messages.py +++ b/tests/test_send_messages.py @@ -571,6 +571,90 @@ class TestResendChannelMessage: assert "expired" in exc_info.value.detail.lower() +class TestRadioExceptionMidSend: + """Test that radio exceptions during send don't leave orphaned DB state.""" + + @pytest.mark.asyncio + async def test_dm_send_radio_exception_no_orphan_message(self, test_db): + """When mc.commands.send_msg() raises, no message should be stored in DB.""" + mc = _make_mc() + pub_key = "ab" * 32 + await _insert_contact(pub_key, "Alice") + + # Make the radio command raise (simulates serial timeout / connection drop) + mc.commands.send_msg = AsyncMock(side_effect=ConnectionError("Serial port disconnected")) + + with ( + patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): + with pytest.raises(ConnectionError): + await send_direct_message( + SendDirectMessageRequest(destination=pub_key, text="This will fail") + ) + + # No message should be stored — the exception prevented reaching MessageRepository.create + messages = await MessageRepository.get_all( + msg_type="PRIV", conversation_key=pub_key, limit=10 + ) + assert len(messages) == 0 + + @pytest.mark.asyncio + async def test_channel_send_radio_exception_no_orphan_message(self, test_db): + """When mc.commands.send_chan_msg() raises, no message should be stored in DB.""" + from app.repository import ChannelRepository + + mc = _make_mc(name="TestNode") + chan_key = "ab" * 16 + await ChannelRepository.upsert(key=chan_key, name="#test") + + mc.commands.send_chan_msg = AsyncMock( + side_effect=ConnectionError("Serial port disconnected") + ) + + with ( + patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): + with pytest.raises(ConnectionError): + await send_channel_message( + SendChannelMessageRequest(channel_key=chan_key, text="This will fail") + ) + + messages = await MessageRepository.get_all( + msg_type="CHAN", conversation_key=chan_key.upper(), limit=10 + ) + assert len(messages) == 0 + + @pytest.mark.asyncio + async def test_channel_send_set_channel_exception_no_orphan(self, test_db): + """When mc.commands.set_channel() raises, send is not attempted and no message stored.""" + from app.repository import ChannelRepository + + mc = _make_mc(name="TestNode") + chan_key = "cd" * 16 + await ChannelRepository.upsert(key=chan_key, name="#broken") + + mc.commands.set_channel = AsyncMock(side_effect=TimeoutError("Radio not responding")) + + with ( + patch("app.routers.messages.require_connected", return_value=mc), + patch.object(radio_manager, "_meshcore", mc), + ): + with pytest.raises(TimeoutError): + await send_channel_message( + SendChannelMessageRequest(channel_key=chan_key, text="Never sent") + ) + + # send_chan_msg should never have been called + mc.commands.send_chan_msg.assert_not_called() + + messages = await MessageRepository.get_all( + msg_type="CHAN", conversation_key=chan_key.upper(), limit=10 + ) + assert len(messages) == 0 + + class TestConcurrentChannelSends: """Test that concurrent channel sends are serialized by the radio operation lock. diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 5741423..ab816c0 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,7 +1,7 @@ """Tests for WebSocket manager functionality.""" import asyncio -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest @@ -203,3 +203,48 @@ class TestWebSocketConnectionManagement: # Should not raise await ws_manager.disconnect(mock_websocket) assert len(ws_manager.active_connections) == 0 + + +class TestBroadcastEventFanout: + """Test that broadcast_event dispatches to WS, private MQTT, and community MQTT.""" + + @pytest.mark.asyncio + async def test_broadcast_event_dispatches_to_all_three_sinks(self): + """broadcast_event creates a WS task, calls mqtt_broadcast, and + calls community_mqtt_broadcast.""" + from app.websocket import broadcast_event + + with ( + patch("app.websocket.ws_manager") as mock_ws, + patch("app.mqtt.mqtt_broadcast") as mock_mqtt, + patch("app.community_mqtt.community_mqtt_broadcast") as mock_community, + ): + mock_ws.broadcast = AsyncMock() + + broadcast_event("message", {"id": 1, "text": "hello"}) + + # Let the asyncio task (ws_manager.broadcast) run + await asyncio.sleep(0) + + mock_ws.broadcast.assert_called_once_with("message", {"id": 1, "text": "hello"}) + mock_mqtt.assert_called_once_with("message", {"id": 1, "text": "hello"}) + mock_community.assert_called_once_with("message", {"id": 1, "text": "hello"}) + + @pytest.mark.asyncio + async def test_broadcast_event_passes_event_type_to_mqtt_filters(self): + """MQTT sinks receive the event_type so they can filter by message vs raw_packet.""" + from app.websocket import broadcast_event + + with ( + patch("app.websocket.ws_manager") as mock_ws, + patch("app.mqtt.mqtt_broadcast") as mock_mqtt, + patch("app.community_mqtt.community_mqtt_broadcast") as mock_community, + ): + mock_ws.broadcast = AsyncMock() + + broadcast_event("raw_packet", {"data": "ff00"}) + await asyncio.sleep(0) + + # Both MQTT sinks receive the event type for filtering + assert mock_mqtt.call_args.args[0] == "raw_packet" + assert mock_community.call_args.args[0] == "raw_packet"