diff --git a/tests/test_api.py b/tests/test_api.py index 402707e..2e0e8e4 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,13 +1,71 @@ """Tests for API endpoints. These tests verify the REST API behavior for critical operations. -Uses FastAPI's TestClient for synchronous testing. +Uses httpx.AsyncClient or direct function calls with real in-memory SQLite. """ +import hashlib +import time from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest +from app.database import Database +from app.repository import ( + ChannelRepository, + ContactRepository, + MessageRepository, + RawPacketRepository, +) + + +@pytest.fixture +async def test_db(): + """Create an in-memory test database with schema + migrations.""" + import app.repository as repo_module + + db = Database(":memory:") + await db.connect() + + original_db = repo_module.db + repo_module.db = db + + try: + yield db + finally: + repo_module.db = original_db + await db.disconnect() + + +@pytest.fixture +def client(): + """Create an httpx AsyncClient for testing the app.""" + from app.main import app + + transport = httpx.ASGITransport(app=app) + return httpx.AsyncClient(transport=transport, base_url="http://test") + + +async def _insert_contact(public_key, name="Alice", **overrides): + """Insert a contact into the test database.""" + data = { + "public_key": public_key, + "name": name, + "type": 0, + "flags": 0, + "last_path": None, + "last_path_len": -1, + "last_advert": None, + "lat": None, + "lon": None, + "last_seen": None, + "on_radio": False, + "last_contacted": None, + } + data.update(overrides) + await ContactRepository.upsert(data) + class TestHealthEndpoint: """Test the health check endpoint.""" @@ -54,51 +112,44 @@ class TestHealthEndpoint: class TestMessagesEndpoint: """Test message-related endpoints.""" - def test_send_direct_message_requires_connection(self): + @pytest.mark.asyncio + async def test_send_direct_message_requires_connection(self, test_db, client): """Sending message when disconnected returns 503.""" - from fastapi.testclient import TestClient - with patch("app.dependencies.radio_manager") as mock_rm: mock_rm.is_connected = False mock_rm.meshcore = None - from app.main import app - - client = TestClient(app) - - response = client.post( + response = await client.post( "/api/messages/direct", json={"destination": "abc123", "text": "Hello"} ) assert response.status_code == 503 assert "not connected" in response.json()["detail"].lower() - def test_send_channel_message_requires_connection(self): + @pytest.mark.asyncio + async def test_send_channel_message_requires_connection(self, test_db, client): """Sending channel message when disconnected returns 503.""" - from fastapi.testclient import TestClient - with patch("app.dependencies.radio_manager") as mock_rm: mock_rm.is_connected = False mock_rm.meshcore = None - from app.main import app - - client = TestClient(app) - - response = client.post( + response = await client.post( "/api/messages/channel", json={"channel_key": "0123456789ABCDEF0123456789ABCDEF", "text": "Hello"}, ) assert response.status_code == 503 - def test_send_direct_message_emits_websocket_message_event(self): + @pytest.mark.asyncio + async def test_send_direct_message_emits_websocket_message_event(self, test_db, client): """POST /messages/direct should emit a WS message event for other clients.""" - from fastapi.testclient import TestClient from meshcore import EventType + pub_key = "ab" * 32 + await _insert_contact(pub_key, "Alice") + mock_mc = MagicMock() - mock_mc.get_contact_by_key_prefix.return_value = {"public_key": "ab" * 32} + mock_mc.get_contact_by_key_prefix.return_value = {"public_key": pub_key} mock_mc.commands.add_contact = AsyncMock( return_value=MagicMock(type=EventType.OK, payload={}) ) @@ -106,22 +157,12 @@ class TestMessagesEndpoint: return_value=MagicMock(type=EventType.MSG_SENT, payload={}) ) - mock_contact = MagicMock() - mock_contact.public_key = "ab" * 32 - mock_contact.to_radio_dict.return_value = {"public_key": "ab" * 32} - def _capture_task(coro): coro.close() return MagicMock() with ( patch("app.dependencies.radio_manager") as mock_rm, - patch( - "app.repository.ContactRepository.get_by_key_or_prefix", - new=AsyncMock(return_value=mock_contact), - ), - patch("app.repository.ContactRepository.update_last_contacted", new=AsyncMock()), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=123)), patch("app.bot.run_bot_for_message", new=AsyncMock()), patch("app.routers.messages.asyncio.create_task", side_effect=_capture_task), patch("app.routers.messages.broadcast_event", create=True) as mock_broadcast, @@ -129,52 +170,42 @@ class TestMessagesEndpoint: mock_rm.is_connected = True mock_rm.meshcore = mock_mc - from app.main import app - - client = TestClient(app) - response = client.post( + response = await client.post( "/api/messages/direct", - json={"destination": mock_contact.public_key, "text": "Hello"}, + json={"destination": pub_key, "text": "Hello"}, ) assert response.status_code == 200 mock_broadcast.assert_called_once() event_type, payload = mock_broadcast.call_args.args assert event_type == "message" - assert payload["id"] == 123 assert payload["type"] == "PRIV" - def test_send_channel_message_emits_websocket_message_event(self): + # Verify message was stored in real DB + messages = await MessageRepository.get_all(conversation_key=pub_key) + assert len(messages) == 1 + assert messages[0].text == "Hello" + + @pytest.mark.asyncio + async def test_send_channel_message_emits_websocket_message_event(self, test_db, client): """POST /messages/channel should emit a WS message event for other clients.""" - from fastapi.testclient import TestClient from meshcore import EventType + chan_key = "AA" * 16 + await ChannelRepository.upsert(key=chan_key, name="Public") + mock_mc = MagicMock() mock_mc.self_info = {"name": "TestNode"} ok_result = MagicMock(type=EventType.MSG_SENT, payload={}) mock_mc.commands.set_channel = AsyncMock(return_value=ok_result) mock_mc.commands.send_chan_msg = AsyncMock(return_value=ok_result) - mock_channel = MagicMock() - mock_channel.name = "Public" - mock_channel.key = "AA" * 16 - def _capture_task(coro): coro.close() return MagicMock() with ( patch("app.dependencies.radio_manager") as mock_rm, - patch( - "app.repository.ChannelRepository.get_by_key", - new=AsyncMock(return_value=mock_channel), - ), - patch( - "app.repository.AppSettingsRepository.get", - new=AsyncMock(return_value=MagicMock(experimental_channel_double_send=False)), - ), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=456)), - patch("app.repository.MessageRepository.get_ack_count", new=AsyncMock(return_value=0)), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()), patch("app.routers.messages.asyncio.create_task", side_effect=_capture_task), @@ -183,43 +214,28 @@ class TestMessagesEndpoint: mock_rm.is_connected = True mock_rm.meshcore = mock_mc - from app.main import app - - client = TestClient(app) - response = client.post( + response = await client.post( "/api/messages/channel", - json={"channel_key": mock_channel.key, "text": "Hello room"}, + json={"channel_key": chan_key, "text": "Hello room"}, ) assert response.status_code == 200 mock_broadcast.assert_called_once() event_type, payload = mock_broadcast.call_args.args assert event_type == "message" - assert payload["id"] == 456 assert payload["type"] == "CHAN" - def test_send_direct_message_contact_not_found(self): + @pytest.mark.asyncio + async def test_send_direct_message_contact_not_found(self, test_db, client): """Sending to unknown contact returns 404.""" - from fastapi.testclient import TestClient - mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix.return_value = None - with ( - patch("app.dependencies.radio_manager") as mock_rm, - patch( - "app.repository.ContactRepository.get_by_key_or_prefix", new_callable=AsyncMock - ) as mock_get, - ): + with patch("app.dependencies.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc - mock_get.return_value = None - from app.main import app - - client = TestClient(app) - - response = client.post( + response = await client.post( "/api/messages/direct", json={"destination": "nonexistent", "text": "Hello"} ) @@ -227,38 +243,29 @@ class TestMessagesEndpoint: assert "not found" in response.json()["detail"].lower() @pytest.mark.asyncio - async def test_send_direct_message_duplicate_returns_500(self): + async def test_send_direct_message_duplicate_returns_500(self, test_db): """If MessageRepository.create returns None (duplicate), returns 500.""" from app.models import SendDirectMessageRequest from app.routers.messages import send_direct_message + pub_key = "a" * 64 + await _insert_contact(pub_key, "TestContact") + mock_mc = MagicMock() - mock_mc.get_contact_by_key_prefix.return_value = {"public_key": "a" * 64} - - mock_add_result = MagicMock() - mock_add_result.type = MagicMock() - mock_add_result.type.name = "OK" - mock_mc.commands.add_contact = AsyncMock(return_value=mock_add_result) - - mock_send_result = MagicMock() - mock_send_result.type = MagicMock() - mock_send_result.type.name = "OK" - mock_send_result.payload = {"expected_ack": b"\x00\x01"} - mock_mc.commands.send_msg = AsyncMock(return_value=mock_send_result) - - mock_contact = MagicMock() - mock_contact.public_key = "a" * 64 - mock_contact.to_radio_dict.return_value = {"public_key": "a" * 64} + mock_mc.get_contact_by_key_prefix.return_value = {"public_key": pub_key} + mock_mc.commands.add_contact = AsyncMock( + return_value=MagicMock(type=MagicMock(name="OK"), payload={}) + ) + mock_mc.commands.send_msg = AsyncMock( + return_value=MagicMock(type=MagicMock(name="OK"), payload={"expected_ack": b"\x00\x01"}) + ) with ( patch("app.dependencies.radio_manager") as mock_rm, - patch("app.repository.ContactRepository") as mock_contact_repo, patch("app.routers.messages.MessageRepository") as mock_msg_repo, ): mock_rm.is_connected = True mock_rm.meshcore = mock_mc - mock_contact_repo.get_by_key_or_prefix = AsyncMock(return_value=mock_contact) - mock_contact_repo.update_last_contacted = AsyncMock() # Simulate duplicate - create returns None mock_msg_repo.create = AsyncMock(return_value=None) @@ -266,42 +273,35 @@ class TestMessagesEndpoint: with pytest.raises(HTTPException) as exc_info: await send_direct_message( - SendDirectMessageRequest(destination="a" * 64, text="Hello") + SendDirectMessageRequest(destination=pub_key, text="Hello") ) assert exc_info.value.status_code == 500 assert "unexpected duplicate" in exc_info.value.detail.lower() @pytest.mark.asyncio - async def test_send_channel_message_duplicate_returns_500(self): + async def test_send_channel_message_duplicate_returns_500(self, test_db): """If MessageRepository.create returns None (duplicate), returns 500.""" - from app.models import AppSettings, SendChannelMessageRequest + from app.models import SendChannelMessageRequest from app.routers.messages import send_channel_message - mock_mc = MagicMock() - mock_send_result = MagicMock() - mock_send_result.type = MagicMock() - mock_send_result.type.name = "OK" - mock_send_result.payload = {} - mock_mc.commands.send_chan_msg = AsyncMock(return_value=mock_send_result) - mock_mc.commands.set_channel = AsyncMock(return_value=mock_send_result) + chan_key = "0123456789ABCDEF0123456789ABCDEF" + await ChannelRepository.upsert(key=chan_key, name="test") - mock_channel = MagicMock() - mock_channel.name = "test" - mock_channel.key = "0123456789ABCDEF0123456789ABCDEF" + mock_mc = MagicMock() + mock_mc.commands.send_chan_msg = AsyncMock( + return_value=MagicMock(type=MagicMock(name="OK"), payload={}) + ) + mock_mc.commands.set_channel = AsyncMock( + return_value=MagicMock(type=MagicMock(name="OK"), payload={}) + ) with ( patch("app.dependencies.radio_manager") as mock_rm, - patch("app.repository.ChannelRepository") as mock_chan_repo, - patch( - "app.repository.AppSettingsRepository.get", - new=AsyncMock(return_value=AppSettings()), - ), patch("app.routers.messages.MessageRepository") as mock_msg_repo, ): mock_rm.is_connected = True mock_rm.meshcore = mock_mc - mock_chan_repo.get_by_key = AsyncMock(return_value=mock_channel) # Simulate duplicate - create returns None mock_msg_repo.create = AsyncMock(return_value=None) @@ -309,9 +309,7 @@ class TestMessagesEndpoint: with pytest.raises(HTTPException) as exc_info: await send_channel_message( - SendChannelMessageRequest( - channel_key="0123456789ABCDEF0123456789ABCDEF", text="Hello" - ) + SendChannelMessageRequest(channel_key=chan_key, text="Hello") ) assert exc_info.value.status_code == 500 @@ -322,54 +320,41 @@ class TestChannelsEndpoint: """Test channel-related endpoints.""" @pytest.mark.asyncio - async def test_create_hashtag_channel_derives_key(self): + async def test_create_hashtag_channel_derives_key(self, test_db): """Creating hashtag channel derives key from name and stores in DB.""" - import hashlib - from app.routers.channels import CreateChannelRequest, create_channel - with patch("app.routers.channels.ChannelRepository") as mock_repo: - mock_repo.upsert = AsyncMock() + request = CreateChannelRequest(name="#mychannel") + result = await create_channel(request) - request = CreateChannelRequest(name="#mychannel") + # Verify the key derivation + expected_key_hex = hashlib.sha256(b"#mychannel").digest()[:16].hex().upper() + assert result.key == expected_key_hex + assert result.name == "#mychannel" - result = await create_channel(request) - - # Verify the key derivation - channel stored in DB, not pushed to radio - expected_key_hex = hashlib.sha256(b"#mychannel").digest()[:16].hex().upper() - mock_repo.upsert.assert_called_once() - call_args = mock_repo.upsert.call_args - assert call_args[1]["key"] == expected_key_hex - assert call_args[1]["name"] == "#mychannel" - assert call_args[1]["is_hashtag"] is True - assert call_args[1]["on_radio"] is False # Not pushed to radio on create - - # Verify response - assert result.key == expected_key_hex - assert result.name == "#mychannel" + # Verify stored in real DB + channel = await ChannelRepository.get_by_key(expected_key_hex) + assert channel is not None + assert channel.name == "#mychannel" + assert channel.is_hashtag is True + assert channel.on_radio is False @pytest.mark.asyncio - async def test_create_channel_with_explicit_key(self): + async def test_create_channel_with_explicit_key(self, test_db): """Creating channel with explicit key uses provided key.""" from app.routers.channels import CreateChannelRequest, create_channel - with patch("app.routers.channels.ChannelRepository") as mock_repo: - mock_repo.upsert = AsyncMock() + explicit_key = "0123456789abcdef0123456789abcdef" # 32 hex chars = 16 bytes + request = CreateChannelRequest(name="private", key=explicit_key) + result = await create_channel(request) - explicit_key = "0123456789abcdef0123456789abcdef" # 32 hex chars = 16 bytes - request = CreateChannelRequest(name="private", key=explicit_key) + assert result.key == explicit_key.upper() - result = await create_channel(request) - - # Verify key stored in DB correctly (stored as uppercase hex) - mock_repo.upsert.assert_called_once() - call_args = mock_repo.upsert.call_args - assert call_args[1]["key"] == explicit_key.upper() - assert call_args[1]["name"] == "private" - assert call_args[1]["on_radio"] is False - - # Verify response - assert result.key == explicit_key.upper() + # Verify stored in real DB + channel = await ChannelRepository.get_by_key(explicit_key.upper()) + assert channel is not None + assert channel.name == "private" + assert channel.on_radio is False class TestPacketsEndpoint: @@ -396,957 +381,369 @@ class TestReadStateEndpoints: """Test read state tracking endpoints.""" @pytest.mark.asyncio - async def test_mark_contact_read_updates_timestamp(self): + async def test_mark_contact_read_updates_timestamp(self, test_db): """Marking contact as read updates last_read_at in database.""" - import time + pub_key = "abc123def456789012345678901234567890123456789012345678901234" + await _insert_contact(pub_key, "TestContact") - import aiosqlite + before_time = int(time.time()) - from app.database import db - from app.repository import ContactRepository + updated = await ContactRepository.update_last_read_at(pub_key) + assert updated is True - # Use in-memory database for testing - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - # Create contacts table with last_read_at column - await conn.execute(""" - CREATE TABLE contacts ( - public_key TEXT PRIMARY KEY, - name TEXT, - type INTEGER DEFAULT 0, - flags INTEGER DEFAULT 0, - last_path TEXT, - last_path_len INTEGER DEFAULT -1, - last_advert INTEGER, - lat REAL, - lon REAL, - last_seen INTEGER, - on_radio INTEGER DEFAULT 0, - last_contacted INTEGER, - last_read_at INTEGER - ) - """) - - # Insert a test contact - await conn.execute( - "INSERT INTO contacts (public_key, name) VALUES (?, ?)", - ("abc123def456789012345678901234567890123456789012345678901234", "TestContact"), - ) - await conn.commit() - - original_conn = db._connection - db._connection = conn - - try: - before_time = int(time.time()) - - # Update last_read_at - updated = await ContactRepository.update_last_read_at( - "abc123def456789012345678901234567890123456789012345678901234" - ) - - assert updated is True - - # Verify the timestamp was set - contact = await ContactRepository.get_by_key( - "abc123def456789012345678901234567890123456789012345678901234" - ) - assert contact is not None - assert contact.last_read_at is not None - assert contact.last_read_at >= before_time - finally: - db._connection = original_conn - await conn.close() + contact = await ContactRepository.get_by_key(pub_key) + assert contact is not None + assert contact.last_read_at is not None + assert contact.last_read_at >= before_time @pytest.mark.asyncio - async def test_mark_channel_read_updates_timestamp(self): + async def test_mark_channel_read_updates_timestamp(self, test_db): """Marking channel as read updates last_read_at in database.""" - import time + chan_key = "0123456789ABCDEF0123456789ABCDEF" + await ChannelRepository.upsert(key=chan_key, name="#testchannel") - import aiosqlite + before_time = int(time.time()) - from app.database import db - from app.repository import ChannelRepository + updated = await ChannelRepository.update_last_read_at(chan_key) + assert updated is True - # Use in-memory database for testing - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - # Create channels table with last_read_at column - await conn.execute(""" - CREATE TABLE channels ( - key TEXT PRIMARY KEY, - name TEXT NOT NULL, - is_hashtag INTEGER DEFAULT 0, - on_radio INTEGER DEFAULT 0, - last_read_at INTEGER - ) - """) - - # Insert a test channel - await conn.execute( - "INSERT INTO channels (key, name) VALUES (?, ?)", - ("0123456789ABCDEF0123456789ABCDEF", "#testchannel"), - ) - await conn.commit() - - original_conn = db._connection - db._connection = conn - - try: - before_time = int(time.time()) - - # Update last_read_at - updated = await ChannelRepository.update_last_read_at( - "0123456789ABCDEF0123456789ABCDEF" - ) - - assert updated is True - - # Verify the timestamp was set - channel = await ChannelRepository.get_by_key("0123456789ABCDEF0123456789ABCDEF") - assert channel is not None - assert channel.last_read_at is not None - assert channel.last_read_at >= before_time - finally: - db._connection = original_conn - await conn.close() + channel = await ChannelRepository.get_by_key(chan_key) + assert channel is not None + assert channel.last_read_at is not None + assert channel.last_read_at >= before_time @pytest.mark.asyncio - async def test_mark_nonexistent_contact_returns_false(self): + async def test_mark_nonexistent_contact_returns_false(self, test_db): """Marking nonexistent contact returns False.""" - import aiosqlite + updated = await ContactRepository.update_last_read_at("nonexistent") + assert updated is False - from app.database import db - from app.repository import ContactRepository - - # Use in-memory database for testing - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - await conn.execute(""" - CREATE TABLE contacts ( - public_key TEXT PRIMARY KEY, - name TEXT, - type INTEGER DEFAULT 0, - flags INTEGER DEFAULT 0, - last_path TEXT, - last_path_len INTEGER DEFAULT -1, - last_advert INTEGER, - lat REAL, - lon REAL, - last_seen INTEGER, - on_radio INTEGER DEFAULT 0, - last_contacted INTEGER, - last_read_at INTEGER - ) - """) - await conn.commit() - - original_conn = db._connection - db._connection = conn - - try: - updated = await ContactRepository.update_last_read_at("nonexistent") - assert updated is False - finally: - db._connection = original_conn - await conn.close() - - def test_mark_contact_read_endpoint_returns_404_for_missing(self): + @pytest.mark.asyncio + async def test_mark_contact_read_endpoint_returns_404_for_missing(self, test_db, client): """Mark-read endpoint returns 404 for nonexistent contact.""" - from fastapi.testclient import TestClient + response = await client.post("/api/contacts/nonexistent/mark-read") - with patch( - "app.repository.ContactRepository.get_by_key_or_prefix", new_callable=AsyncMock - ) as mock_get: - mock_get.return_value = None + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() - from app.main import app - - client = TestClient(app) - - response = client.post("/api/contacts/nonexistent/mark-read") - - assert response.status_code == 404 - assert "not found" in response.json()["detail"].lower() - - def test_mark_channel_read_endpoint_returns_404_for_missing(self): + @pytest.mark.asyncio + async def test_mark_channel_read_endpoint_returns_404_for_missing(self, test_db, client): """Mark-read endpoint returns 404 for nonexistent channel.""" - from fastapi.testclient import TestClient + response = await client.post("/api/channels/NONEXISTENT/mark-read") - with patch( - "app.repository.ChannelRepository.get_by_key", new_callable=AsyncMock - ) as mock_get: - mock_get.return_value = None - - from app.main import app - - client = TestClient(app) - - response = client.post("/api/channels/NONEXISTENT/mark-read") - - assert response.status_code == 404 - assert "not found" in response.json()["detail"].lower() + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() @pytest.mark.asyncio - async def test_get_unreads_returns_counts_and_mentions(self): + async def test_get_unreads_returns_counts_and_mentions(self, test_db): """GET /unreads returns unread counts, mentions, and last message times.""" - import aiosqlite - - from app.database import db - from app.repository import MessageRepository - - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - # Create tables - await conn.execute(""" - CREATE TABLE contacts ( - public_key TEXT PRIMARY KEY, - name TEXT, - type INTEGER DEFAULT 0, - flags INTEGER DEFAULT 0, - last_path TEXT, - last_path_len INTEGER DEFAULT -1, - last_advert INTEGER, - lat REAL, - lon REAL, - last_seen INTEGER, - on_radio INTEGER DEFAULT 0, - last_contacted INTEGER, - last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE channels ( - key TEXT PRIMARY KEY, - name TEXT NOT NULL, - is_hashtag INTEGER DEFAULT 0, - on_radio INTEGER DEFAULT 0, - last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE messages ( - id INTEGER PRIMARY KEY, - type TEXT NOT NULL, - conversation_key TEXT NOT NULL, - text TEXT NOT NULL, - sender_timestamp INTEGER, - received_at INTEGER NOT NULL, - paths TEXT, - txt_type INTEGER DEFAULT 0, - signature TEXT, - outgoing INTEGER DEFAULT 0, - acked INTEGER DEFAULT 0, - UNIQUE(type, conversation_key, text, sender_timestamp) - ) - """) - - # Insert channel and contact - await conn.execute( - "INSERT INTO channels (key, name, last_read_at) VALUES (?, ?, ?)", - ("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1", "Public", 1000), - ) - await conn.execute( - "INSERT INTO contacts (public_key, name, last_read_at) VALUES (?, ?, ?)", - ("abcd" * 16, "Alice", 1000), - ) - - # Insert messages: 2 unread channel msgs (after last_read_at=1000), - # 1 read (before), 1 outgoing (should not count) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("CHAN", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1", "Bob: hello", 1001, 0), - ) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("CHAN", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1", "Bob: @[testuser] hey", 1002, 0), - ) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("CHAN", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1", "Bob: old msg", 999, 0), - ) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("CHAN", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1", "Me: outgoing", 1003, 1), - ) - - # Insert 1 unread DM - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("PRIV", "abcd" * 16, "hi @[TeStUsEr] there", 1005, 0), - ) - await conn.commit() - - original_conn = db._connection - db._connection = conn - - try: - result = await MessageRepository.get_unread_counts("TestUser") - - # Channel: 2 unread (1001 and 1002), one has mention - assert result["counts"]["channel-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1"] == 2 - assert result["mentions"]["channel-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1"] is True - - # Contact: 1 unread with mention (also case-insensitive) - assert result["counts"][f"contact-{'abcd' * 16}"] == 1 - assert result["mentions"][f"contact-{'abcd' * 16}"] is True - - # Last message times should include all conversations - assert "channel-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1" in result["last_message_times"] - assert result["last_message_times"]["channel-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1"] == 1003 - assert f"contact-{'abcd' * 16}" in result["last_message_times"] - assert result["last_message_times"][f"contact-{'abcd' * 16}"] == 1005 - finally: - db._connection = original_conn - await conn.close() - - @pytest.mark.asyncio - async def test_get_unreads_no_name_skips_mentions(self): - """GET /unreads without name param returns counts but no mention flags.""" - import aiosqlite - - from app.database import db - from app.repository import MessageRepository - - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - await conn.execute(""" - CREATE TABLE channels ( - key TEXT PRIMARY KEY, - name TEXT NOT NULL, - is_hashtag INTEGER DEFAULT 0, - on_radio INTEGER DEFAULT 0, - last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE contacts ( - public_key TEXT PRIMARY KEY, - name TEXT, - type INTEGER DEFAULT 0, - flags INTEGER DEFAULT 0, - last_path TEXT, - last_path_len INTEGER DEFAULT -1, - last_advert INTEGER, - lat REAL, - lon REAL, - last_seen INTEGER, - on_radio INTEGER DEFAULT 0, - last_contacted INTEGER, - last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE messages ( - id INTEGER PRIMARY KEY, - type TEXT NOT NULL, - conversation_key TEXT NOT NULL, - text TEXT NOT NULL, - sender_timestamp INTEGER, - received_at INTEGER NOT NULL, - paths TEXT, - txt_type INTEGER DEFAULT 0, - signature TEXT, - outgoing INTEGER DEFAULT 0, - acked INTEGER DEFAULT 0, - UNIQUE(type, conversation_key, text, sender_timestamp) - ) - """) - - await conn.execute( - "INSERT INTO channels (key, name, last_read_at) VALUES (?, ?, ?)", - ("CHAN1KEY1CHAN1KEY1CHAN1KEY1CHAN1KEY1", "Public", 0), - ) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("CHAN", "CHAN1KEY1CHAN1KEY1CHAN1KEY1CHAN1KEY1", "Bob: @[Alice] hey", 1001, 0), - ) - await conn.commit() - - original_conn = db._connection - db._connection = conn - - try: - result = await MessageRepository.get_unread_counts(None) - - assert result["counts"]["channel-CHAN1KEY1CHAN1KEY1CHAN1KEY1CHAN1KEY1"] == 1 - # No mentions since name was None - assert len(result["mentions"]) == 0 - finally: - db._connection = original_conn - await conn.close() - - @pytest.mark.asyncio - async def test_unreads_reset_after_mark_read(self): - """Marking a conversation as read zeroes its unread count; new messages after count again.""" - import aiosqlite - - from app.database import db - from app.repository import MessageRepository - - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - await conn.execute(""" - CREATE TABLE channels ( - key TEXT PRIMARY KEY, name TEXT NOT NULL, - is_hashtag INTEGER DEFAULT 0, on_radio INTEGER DEFAULT 0, last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE contacts ( - public_key TEXT PRIMARY KEY, name TEXT, - type INTEGER DEFAULT 0, flags INTEGER DEFAULT 0, - last_path TEXT, last_path_len INTEGER DEFAULT -1, - last_advert INTEGER, lat REAL, lon REAL, last_seen INTEGER, - on_radio INTEGER DEFAULT 0, last_contacted INTEGER, last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE messages ( - id INTEGER PRIMARY KEY, type TEXT NOT NULL, - conversation_key TEXT NOT NULL, text TEXT NOT NULL, - sender_timestamp INTEGER, received_at INTEGER NOT NULL, - paths TEXT, txt_type INTEGER DEFAULT 0, signature TEXT, - outgoing INTEGER DEFAULT 0, acked INTEGER DEFAULT 0, - UNIQUE(type, conversation_key, text, sender_timestamp) - ) - """) - chan_key = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1" - await conn.execute( - "INSERT INTO channels (key, name, last_read_at) VALUES (?, ?, ?)", - (chan_key, "Public", 1000), - ) - # 2 unread messages (received_at > last_read_at=1000) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("CHAN", chan_key, "msg1", 1001, 0), - ) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("CHAN", chan_key, "msg2", 1002, 0), - ) - await conn.commit() - - original_conn = db._connection - db._connection = conn - - try: - # Verify 2 unread - result = await MessageRepository.get_unread_counts(None) - assert result["counts"][f"channel-{chan_key}"] == 2 - - # Simulate mark-read by updating last_read_at to after all messages - await conn.execute( - "UPDATE channels SET last_read_at = ? WHERE key = ?", (1002, chan_key) - ) - await conn.commit() - - # Verify 0 unread - result = await MessageRepository.get_unread_counts(None) - assert result["counts"].get(f"channel-{chan_key}", 0) == 0 - - # New message arrives after the read point - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("CHAN", chan_key, "msg3", 1003, 0), - ) - await conn.commit() - - # Verify exactly 1 unread - result = await MessageRepository.get_unread_counts(None) - assert result["counts"][f"channel-{chan_key}"] == 1 - finally: - db._connection = original_conn - await conn.close() - - @pytest.mark.asyncio - async def test_unreads_exclude_outgoing_messages(self): - """Outgoing messages should never count as unread, even when received_at > last_read_at. - - This is critical: without the outgoing filter, every message we send would - show as an unread badge in the sidebar. - """ - import aiosqlite - - from app.database import db - from app.repository import MessageRepository - - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - await conn.execute(""" - CREATE TABLE channels ( - key TEXT PRIMARY KEY, name TEXT NOT NULL, - is_hashtag INTEGER DEFAULT 0, on_radio INTEGER DEFAULT 0, last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE contacts ( - public_key TEXT PRIMARY KEY, name TEXT, - type INTEGER DEFAULT 0, flags INTEGER DEFAULT 0, - last_path TEXT, last_path_len INTEGER DEFAULT -1, - last_advert INTEGER, lat REAL, lon REAL, last_seen INTEGER, - on_radio INTEGER DEFAULT 0, last_contacted INTEGER, last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE messages ( - id INTEGER PRIMARY KEY, type TEXT NOT NULL, - conversation_key TEXT NOT NULL, text TEXT NOT NULL, - sender_timestamp INTEGER, received_at INTEGER NOT NULL, - paths TEXT, txt_type INTEGER DEFAULT 0, signature TEXT, - outgoing INTEGER DEFAULT 0, acked INTEGER DEFAULT 0, - UNIQUE(type, conversation_key, text, sender_timestamp) - ) - """) - contact_key = "abcd" * 16 - await conn.execute( - "INSERT INTO contacts (public_key, name, last_read_at) VALUES (?, ?, ?)", - (contact_key, "Bob", 1000), - ) - # 1 incoming (should count) + 2 outgoing (should NOT count) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("PRIV", contact_key, "incoming msg", 1001, 0), - ) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("PRIV", contact_key, "my reply", 1002, 1), - ) - await conn.execute( - "INSERT INTO messages (type, conversation_key, text, received_at, outgoing) VALUES (?, ?, ?, ?, ?)", - ("PRIV", contact_key, "another reply", 1003, 1), - ) - await conn.commit() - original_conn = db._connection - db._connection = conn + await ChannelRepository.upsert(key=chan_key, name="Public") + await ChannelRepository.update_last_read_at(chan_key, 1000) + await _insert_contact(contact_key, "Alice") + await ContactRepository.update_last_read_at(contact_key, 1000) - try: - result = await MessageRepository.get_unread_counts(None) - # Only the 1 incoming message should count as unread - assert result["counts"][f"contact-{contact_key}"] == 1 - finally: - db._connection = original_conn - await conn.close() + # 2 unread channel msgs (received_at > last_read_at=1000), 1 read, 1 outgoing + await MessageRepository.create( + msg_type="CHAN", + text="Bob: hello", + received_at=1001, + conversation_key=chan_key, + sender_timestamp=1001, + ) + await MessageRepository.create( + msg_type="CHAN", + text="Bob: @[testuser] hey", + received_at=1002, + conversation_key=chan_key, + sender_timestamp=1002, + ) + await MessageRepository.create( + msg_type="CHAN", + text="Bob: old msg", + received_at=999, + conversation_key=chan_key, + sender_timestamp=999, + ) + await MessageRepository.create( + msg_type="CHAN", + text="Me: outgoing", + received_at=1003, + conversation_key=chan_key, + sender_timestamp=1003, + outgoing=True, + ) + # 1 unread DM with mention + await MessageRepository.create( + msg_type="PRIV", + text="hi @[TeStUsEr] there", + received_at=1005, + conversation_key=contact_key, + sender_timestamp=1005, + ) + + result = await MessageRepository.get_unread_counts("TestUser") + + # Channel: 2 unread (1001 and 1002), one has mention + assert result["counts"][f"channel-{chan_key}"] == 2 + assert result["mentions"][f"channel-{chan_key}"] is True + + # Contact: 1 unread with mention (case-insensitive) + assert result["counts"][f"contact-{contact_key}"] == 1 + assert result["mentions"][f"contact-{contact_key}"] is True + + # Last message times should include all conversations + assert result["last_message_times"][f"channel-{chan_key}"] == 1003 + assert result["last_message_times"][f"contact-{contact_key}"] == 1005 @pytest.mark.asyncio - async def test_mark_all_read_updates_all_conversations(self): + async def test_get_unreads_no_name_skips_mentions(self, test_db): + """GET /unreads without name param returns counts but no mention flags.""" + chan_key = "CHAN1KEY1CHAN1KEY1CHAN1KEY1CHAN1KEY1" + await ChannelRepository.upsert(key=chan_key, name="Public") + await ChannelRepository.update_last_read_at(chan_key, 0) + + await MessageRepository.create( + msg_type="CHAN", + text="Bob: @[Alice] hey", + received_at=1001, + conversation_key=chan_key, + sender_timestamp=1001, + ) + + result = await MessageRepository.get_unread_counts(None) + + assert result["counts"][f"channel-{chan_key}"] == 1 + assert len(result["mentions"]) == 0 + + @pytest.mark.asyncio + async def test_unreads_reset_after_mark_read(self, test_db): + """Marking a conversation as read zeroes its unread count; new messages after count again.""" + chan_key = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1" + await ChannelRepository.upsert(key=chan_key, name="Public") + await ChannelRepository.update_last_read_at(chan_key, 1000) + + # 2 unread messages (received_at > last_read_at=1000) + await MessageRepository.create( + msg_type="CHAN", + text="msg1", + received_at=1001, + conversation_key=chan_key, + sender_timestamp=1001, + ) + await MessageRepository.create( + msg_type="CHAN", + text="msg2", + received_at=1002, + conversation_key=chan_key, + sender_timestamp=1002, + ) + + # Verify 2 unread + result = await MessageRepository.get_unread_counts(None) + assert result["counts"][f"channel-{chan_key}"] == 2 + + # Mark as read + await ChannelRepository.update_last_read_at(chan_key, 1002) + + # Verify 0 unread + result = await MessageRepository.get_unread_counts(None) + assert result["counts"].get(f"channel-{chan_key}", 0) == 0 + + # New message arrives after the read point + await MessageRepository.create( + msg_type="CHAN", + text="msg3", + received_at=1003, + conversation_key=chan_key, + sender_timestamp=1003, + ) + + # Verify exactly 1 unread + result = await MessageRepository.get_unread_counts(None) + assert result["counts"][f"channel-{chan_key}"] == 1 + + @pytest.mark.asyncio + async def test_unreads_exclude_outgoing_messages(self, test_db): + """Outgoing messages should never count as unread.""" + contact_key = "abcd" * 16 + await _insert_contact(contact_key, "Bob") + await ContactRepository.update_last_read_at(contact_key, 1000) + + # 1 incoming (should count) + 2 outgoing (should NOT count) + await MessageRepository.create( + msg_type="PRIV", + text="incoming msg", + received_at=1001, + conversation_key=contact_key, + sender_timestamp=1001, + ) + await MessageRepository.create( + msg_type="PRIV", + text="my reply", + received_at=1002, + conversation_key=contact_key, + sender_timestamp=1002, + outgoing=True, + ) + await MessageRepository.create( + msg_type="PRIV", + text="another reply", + received_at=1003, + conversation_key=contact_key, + sender_timestamp=1003, + outgoing=True, + ) + + result = await MessageRepository.get_unread_counts(None) + # Only the 1 incoming message should count as unread + assert result["counts"][f"contact-{contact_key}"] == 1 + + @pytest.mark.asyncio + async def test_mark_all_read_updates_all_conversations(self, test_db): """Bulk mark-all-read updates all contacts and channels.""" - import time + await _insert_contact("contact1", "Alice") + await _insert_contact("contact2", "Bob") + await ChannelRepository.upsert(key="CHAN1KEY1CHAN1KEY1CHAN1KEY1CHAN1KEY1", name="#test1") + await ChannelRepository.upsert(key="CHAN2KEY2CHAN2KEY2CHAN2KEY2CHAN2KEY2", name="#test2") - import aiosqlite + before_time = int(time.time()) - from app.database import db + from app.routers.read_state import mark_all_read - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row + result = await mark_all_read() - # Create tables - await conn.execute(""" - CREATE TABLE contacts ( - public_key TEXT PRIMARY KEY, - name TEXT, - last_read_at INTEGER - ) - """) - await conn.execute(""" - CREATE TABLE channels ( - key TEXT PRIMARY KEY, - name TEXT NOT NULL, - last_read_at INTEGER - ) - """) + assert result["status"] == "ok" + assert result["timestamp"] >= before_time - # Insert test data with NULL last_read_at - await conn.execute( - "INSERT INTO contacts (public_key, name) VALUES (?, ?)", ("contact1", "Alice") - ) - await conn.execute( - "INSERT INTO contacts (public_key, name) VALUES (?, ?)", ("contact2", "Bob") - ) - await conn.execute("INSERT INTO channels (key, name) VALUES (?, ?)", ("CHAN1", "#test1")) - await conn.execute("INSERT INTO channels (key, name) VALUES (?, ?)", ("CHAN2", "#test2")) - await conn.commit() + # Verify all contacts updated + for key in ["contact1", "contact2"]: + contact = await ContactRepository.get_by_key(key) + assert contact.last_read_at >= before_time - original_conn = db._connection - db._connection = conn - - try: - before_time = int(time.time()) - - # Call the endpoint - from app.routers.read_state import mark_all_read - - result = await mark_all_read() - - assert result["status"] == "ok" - assert result["timestamp"] >= before_time - - # Verify all contacts updated - cursor = await conn.execute("SELECT last_read_at FROM contacts") - rows = await cursor.fetchall() - for row in rows: - assert row["last_read_at"] >= before_time - - # Verify all channels updated - cursor = await conn.execute("SELECT last_read_at FROM channels") - rows = await cursor.fetchall() - for row in rows: - assert row["last_read_at"] >= before_time - finally: - db._connection = original_conn - await conn.close() + # Verify all channels updated + for key in ["CHAN1KEY1CHAN1KEY1CHAN1KEY1CHAN1KEY1", "CHAN2KEY2CHAN2KEY2CHAN2KEY2CHAN2KEY2"]: + channel = await ChannelRepository.get_by_key(key) + assert channel.last_read_at >= before_time class TestRawPacketRepository: """Test raw packet storage with deduplication.""" @pytest.mark.asyncio - async def test_create_returns_id_for_new_packet(self): + async def test_create_returns_id_for_new_packet(self, test_db): """First insert of packet data returns a valid ID.""" - import aiosqlite + packet_data = b"\x01\x02\x03\x04\x05" + packet_id, is_new = await RawPacketRepository.create(packet_data, 1234567890) - from app.database import db - from app.repository import RawPacketRepository - - # Use in-memory database for testing - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - # Create the raw_packets table with payload_hash for deduplication - await conn.execute(""" - CREATE TABLE raw_packets ( - id INTEGER PRIMARY KEY, - timestamp INTEGER NOT NULL, - data BLOB NOT NULL, - message_id INTEGER, - payload_hash TEXT - ) - """) - await conn.execute( - "CREATE UNIQUE INDEX IF NOT EXISTS idx_raw_packets_payload_hash ON raw_packets(payload_hash)" - ) - await conn.commit() - - # Patch the db._connection to use our test connection - original_conn = db._connection - db._connection = conn - - try: - packet_data = b"\x01\x02\x03\x04\x05" - packet_id, is_new = await RawPacketRepository.create(packet_data, 1234567890) - - assert packet_id is not None - assert packet_id > 0 - assert is_new is True - finally: - db._connection = original_conn - await conn.close() + assert packet_id is not None + assert packet_id > 0 + assert is_new is True @pytest.mark.asyncio - async def test_different_packets_both_stored(self): + async def test_different_packets_both_stored(self, test_db): """Different packet data both get stored with unique IDs.""" - import aiosqlite + packet1 = b"\x01\x02\x03" + packet2 = b"\x04\x05\x06" - from app.database import db - from app.repository import RawPacketRepository + id1, is_new1 = await RawPacketRepository.create(packet1, 1234567890) + id2, is_new2 = await RawPacketRepository.create(packet2, 1234567891) - # Use in-memory database for testing - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - # Create the raw_packets table with payload_hash for deduplication - await conn.execute(""" - CREATE TABLE raw_packets ( - id INTEGER PRIMARY KEY, - timestamp INTEGER NOT NULL, - data BLOB NOT NULL, - message_id INTEGER, - payload_hash TEXT - ) - """) - await conn.execute( - "CREATE UNIQUE INDEX IF NOT EXISTS idx_raw_packets_payload_hash ON raw_packets(payload_hash)" - ) - await conn.commit() - - # Patch the db._connection to use our test connection - original_conn = db._connection - db._connection = conn - - try: - packet1 = b"\x01\x02\x03" - packet2 = b"\x04\x05\x06" - - id1, is_new1 = await RawPacketRepository.create(packet1, 1234567890) - id2, is_new2 = await RawPacketRepository.create(packet2, 1234567891) - - assert id1 is not None - assert id2 is not None - assert id1 != id2 - assert is_new1 is True - assert is_new2 is True - finally: - db._connection = original_conn - await conn.close() + assert id1 is not None + assert id2 is not None + assert id1 != id2 + assert is_new1 is True + assert is_new2 is True @pytest.mark.asyncio - async def test_duplicate_packet_returns_existing_id(self): + async def test_duplicate_packet_returns_existing_id(self, test_db): """Inserting same payload twice returns existing ID and is_new=False.""" - import aiosqlite + # Same packet data inserted twice + packet_data = b"\x01\x02\x03\x04\x05" + id1, is_new1 = await RawPacketRepository.create(packet_data, 1234567890) + id2, is_new2 = await RawPacketRepository.create(packet_data, 1234567891) - from app.database import db - from app.repository import RawPacketRepository - - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - # Create the raw_packets table with payload_hash for deduplication - await conn.execute(""" - CREATE TABLE raw_packets ( - id INTEGER PRIMARY KEY, - timestamp INTEGER NOT NULL, - data BLOB NOT NULL, - message_id INTEGER, - payload_hash TEXT - ) - """) - await conn.execute( - "CREATE UNIQUE INDEX IF NOT EXISTS idx_raw_packets_payload_hash ON raw_packets(payload_hash)" - ) - await conn.commit() - - original_conn = db._connection - db._connection = conn - - try: - # Same packet data inserted twice - packet_data = b"\x01\x02\x03\x04\x05" - id1, is_new1 = await RawPacketRepository.create(packet_data, 1234567890) - id2, is_new2 = await RawPacketRepository.create(packet_data, 1234567891) - - # Both should return the same ID - assert id1 == id2 - # First is new, second is not - assert is_new1 is True - assert is_new2 is False - finally: - db._connection = original_conn - await conn.close() + # Both should return the same ID + assert id1 == id2 + # First is new, second is not + assert is_new1 is True + assert is_new2 is False @pytest.mark.asyncio - async def test_malformed_packet_uses_full_data_hash(self): + async def test_malformed_packet_uses_full_data_hash(self, test_db): """Malformed packets (can't extract payload) hash full data for dedup.""" - import aiosqlite + # Single byte is too short to be valid packet (extract_payload returns None) + malformed = b"\x01" + id1, is_new1 = await RawPacketRepository.create(malformed, 1234567890) + id2, is_new2 = await RawPacketRepository.create(malformed, 1234567891) - from app.database import db - from app.repository import RawPacketRepository + # Should still deduplicate using full data hash + assert id1 == id2 + assert is_new1 is True + assert is_new2 is False - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - await conn.execute(""" - CREATE TABLE raw_packets ( - id INTEGER PRIMARY KEY, - timestamp INTEGER NOT NULL, - data BLOB NOT NULL, - message_id INTEGER, - payload_hash TEXT - ) - """) - await conn.execute( - "CREATE UNIQUE INDEX IF NOT EXISTS idx_raw_packets_payload_hash ON raw_packets(payload_hash)" - ) - await conn.commit() - - original_conn = db._connection - db._connection = conn - - try: - # Single byte is too short to be valid packet (extract_payload returns None) - malformed = b"\x01" - id1, is_new1 = await RawPacketRepository.create(malformed, 1234567890) - id2, is_new2 = await RawPacketRepository.create(malformed, 1234567891) - - # Should still deduplicate using full data hash - assert id1 == id2 - assert is_new1 is True - assert is_new2 is False - - # Different malformed packet should get different ID - different_malformed = b"\x02" - id3, is_new3 = await RawPacketRepository.create(different_malformed, 1234567892) - assert id3 != id1 - assert is_new3 is True - finally: - db._connection = original_conn - await conn.close() + # Different malformed packet should get different ID + different_malformed = b"\x02" + id3, is_new3 = await RawPacketRepository.create(different_malformed, 1234567892) + assert id3 != id1 + assert is_new3 is True @pytest.mark.asyncio - async def test_prune_old_undecrypted_deletes_old_packets(self): + async def test_prune_old_undecrypted_deletes_old_packets(self, test_db): """Prune deletes undecrypted packets older than specified days.""" - import time - - import aiosqlite - - from app.database import db - from app.repository import RawPacketRepository - - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - await conn.execute(""" - CREATE TABLE raw_packets ( - id INTEGER PRIMARY KEY, - timestamp INTEGER NOT NULL, - data BLOB NOT NULL UNIQUE, - message_id INTEGER - ) - """) - now = int(time.time()) old_timestamp = now - (15 * 86400) # 15 days ago recent_timestamp = now - (5 * 86400) # 5 days ago - # Insert old undecrypted packet (message_id NULL = undecrypted) - await conn.execute( - "INSERT INTO raw_packets (timestamp, data) VALUES (?, ?)", - (old_timestamp, b"\x01\x02\x03"), - ) - # Insert recent undecrypted packet (message_id NULL = undecrypted) - await conn.execute( - "INSERT INTO raw_packets (timestamp, data) VALUES (?, ?)", - (recent_timestamp, b"\x04\x05\x06"), - ) + # Insert old undecrypted packet + await RawPacketRepository.create(b"\x01\x02\x03", old_timestamp) + # Insert recent undecrypted packet + await RawPacketRepository.create(b"\x04\x05\x06", recent_timestamp) # Insert old but decrypted packet (should NOT be deleted) - # message_id NOT NULL = decrypted - await conn.execute( - "INSERT INTO raw_packets (timestamp, data, message_id) VALUES (?, ?, ?)", - (old_timestamp, b"\x07\x08\x09", 1), - ) - await conn.commit() + old_id, _ = await RawPacketRepository.create(b"\x07\x08\x09", old_timestamp) + await RawPacketRepository.mark_decrypted(old_id, 1) - original_conn = db._connection - db._connection = conn + # Prune packets older than 10 days + deleted = await RawPacketRepository.prune_old_undecrypted(10) - try: - # Prune packets older than 10 days - deleted = await RawPacketRepository.prune_old_undecrypted(10) - - assert deleted == 1 # Only the old undecrypted packet - - # Verify remaining packets - cursor = await conn.execute("SELECT COUNT(*) as count FROM raw_packets") - row = await cursor.fetchone() - assert row["count"] == 2 # Recent undecrypted + old decrypted - finally: - db._connection = original_conn - await conn.close() + assert deleted == 1 # Only the old undecrypted packet @pytest.mark.asyncio - async def test_prune_old_undecrypted_returns_zero_when_nothing_to_delete(self): + async def test_prune_old_undecrypted_returns_zero_when_nothing_to_delete(self, test_db): """Prune returns 0 when no packets match criteria.""" - import time - - import aiosqlite - - from app.database import db - from app.repository import RawPacketRepository - - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - await conn.execute(""" - CREATE TABLE raw_packets ( - id INTEGER PRIMARY KEY, - timestamp INTEGER NOT NULL, - data BLOB NOT NULL UNIQUE, - message_id INTEGER - ) - """) - now = int(time.time()) recent_timestamp = now - (5 * 86400) # 5 days ago - # Insert only recent packet (message_id NULL = undecrypted) - await conn.execute( - "INSERT INTO raw_packets (timestamp, data) VALUES (?, ?)", - (recent_timestamp, b"\x01\x02\x03"), - ) - await conn.commit() + # Insert only recent packet + await RawPacketRepository.create(b"\x01\x02\x03", recent_timestamp) - original_conn = db._connection - db._connection = conn - - try: - # Prune packets older than 10 days (none should match) - deleted = await RawPacketRepository.prune_old_undecrypted(10) - assert deleted == 0 - finally: - db._connection = original_conn - await conn.close() + # Prune packets older than 10 days (none should match) + deleted = await RawPacketRepository.prune_old_undecrypted(10) + assert deleted == 0 class TestMaintenanceEndpoint: """Test database maintenance endpoint.""" @pytest.mark.asyncio - async def test_maintenance_prunes_and_vacuums(self): + async def test_maintenance_prunes_and_vacuums(self, test_db): """Maintenance endpoint prunes old packets and runs vacuum.""" - import time - - import aiosqlite - - from app.database import db from app.routers.packets import MaintenanceRequest, run_maintenance - conn = await aiosqlite.connect(":memory:") - conn.row_factory = aiosqlite.Row - - await conn.execute(""" - CREATE TABLE raw_packets ( - id INTEGER PRIMARY KEY, - timestamp INTEGER NOT NULL, - data BLOB NOT NULL UNIQUE, - message_id INTEGER - ) - """) - now = int(time.time()) old_timestamp = now - (20 * 86400) # 20 days ago - # Insert old undecrypted packets (message_id NULL = undecrypted) - await conn.execute( - "INSERT INTO raw_packets (timestamp, data) VALUES (?, ?)", - (old_timestamp, b"\x01\x02\x03"), - ) - await conn.execute( - "INSERT INTO raw_packets (timestamp, data) VALUES (?, ?)", - (old_timestamp, b"\x04\x05\x06"), - ) - await conn.commit() + # Insert old undecrypted packets + await RawPacketRepository.create(b"\x01\x02\x03", old_timestamp) + await RawPacketRepository.create(b"\x04\x05\x06", old_timestamp) - original_conn = db._connection - db._connection = conn + request = MaintenanceRequest(prune_undecrypted_days=14) + result = await run_maintenance(request) - try: - request = MaintenanceRequest(prune_undecrypted_days=14) - result = await run_maintenance(request) - - assert result.packets_deleted == 2 - assert result.vacuumed is True - finally: - db._connection = original_conn - await conn.close() + assert result.packets_deleted == 2 + assert result.vacuumed is True class TestHealthEndpointDatabaseSize: diff --git a/tests/test_contacts_router.py b/tests/test_contacts_router.py index 1b9b6d8..a06f16e 100644 --- a/tests/test_contacts_router.py +++ b/tests/test_contacts_router.py @@ -3,25 +3,45 @@ Verifies the contact CRUD endpoints, sync, mark-read, delete, and add/remove from radio operations. -Uses FastAPI TestClient with mocked dependencies, consistent -with the test_api.py pattern. +Uses httpx.AsyncClient with real in-memory SQLite database. """ from unittest.mock import AsyncMock, MagicMock, patch +import httpx +import pytest from meshcore import EventType +from app.database import Database +from app.repository import ContactRepository + # Sample 64-char hex public keys for testing KEY_A = "aa" * 32 # aaaa...aa KEY_B = "bb" * 32 # bbbb...bb KEY_C = "cc" * 32 # cccc...cc -def _make_contact(public_key=KEY_A, name="Alice", **overrides): - """Create a mock Contact model instance.""" - from app.models import Contact +@pytest.fixture +async def test_db(): + """Create an in-memory test database with schema + migrations.""" + import app.repository as repo_module - defaults = { + db = Database(":memory:") + await db.connect() + + original_db = repo_module.db + repo_module.db = db + + try: + yield db + finally: + repo_module.db = original_db + await db.disconnect() + + +async def _insert_contact(public_key=KEY_A, name="Alice", on_radio=False, **overrides): + """Insert a contact into the test database.""" + data = { "public_key": public_key, "name": name, "type": 0, @@ -32,214 +52,135 @@ def _make_contact(public_key=KEY_A, name="Alice", **overrides): "lat": None, "lon": None, "last_seen": None, - "on_radio": False, + "on_radio": on_radio, "last_contacted": None, - "last_read_at": None, } - defaults.update(overrides) - return Contact(**defaults) + data.update(overrides) + await ContactRepository.upsert(data) + + +@pytest.fixture +def client(): + """Create an httpx AsyncClient for testing the app.""" + from app.main import app + + transport = httpx.ASGITransport(app=app) + return httpx.AsyncClient(transport=transport, base_url="http://test") class TestListContacts: """Test GET /api/contacts.""" - def test_list_returns_contacts(self): - from fastapi.testclient import TestClient + @pytest.mark.asyncio + async def test_list_returns_contacts(self, test_db, client): + await _insert_contact(KEY_A, "Alice") + await _insert_contact(KEY_B, "Bob") - contacts = [_make_contact(KEY_A, "Alice"), _make_contact(KEY_B, "Bob")] - - with patch( - "app.routers.contacts.ContactRepository.get_all", - new_callable=AsyncMock, - return_value=contacts, - ): - from app.main import app - - client = TestClient(app) - response = client.get("/api/contacts") + response = await client.get("/api/contacts") assert response.status_code == 200 data = response.json() assert len(data) == 2 - assert data[0]["public_key"] == KEY_A - assert data[1]["public_key"] == KEY_B + keys = {d["public_key"] for d in data} + assert KEY_A in keys + assert KEY_B in keys - def test_list_pagination_params(self): - """Pagination parameters are forwarded to repository.""" - from fastapi.testclient import TestClient + @pytest.mark.asyncio + async def test_list_pagination_params(self, test_db, client): + # Insert 3 contacts + await _insert_contact(KEY_A, "Alice") + await _insert_contact(KEY_B, "Bob") + await _insert_contact(KEY_C, "Carol") - with patch( - "app.routers.contacts.ContactRepository.get_all", - new_callable=AsyncMock, - return_value=[], - ) as mock_get_all: - from app.main import app - - client = TestClient(app) - response = client.get("/api/contacts?limit=5&offset=10") + response = await client.get("/api/contacts?limit=2&offset=0") assert response.status_code == 200 - mock_get_all.assert_called_once_with(limit=5, offset=10) + data = response.json() + assert len(data) == 2 class TestCreateContact: """Test POST /api/contacts.""" - def test_create_new_contact(self): - from fastapi.testclient import TestClient - - with ( - patch( - "app.routers.contacts.ContactRepository.get_by_key", - new_callable=AsyncMock, - return_value=None, - ), - patch( - "app.routers.contacts.ContactRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - patch( - "app.routers.contacts.MessageRepository.claim_prefix_messages", - new_callable=AsyncMock, - return_value=0, - ), - ): - from app.main import app - - client = TestClient(app) - response = client.post( - "/api/contacts", - json={"public_key": KEY_A, "name": "NewContact"}, - ) + @pytest.mark.asyncio + async def test_create_new_contact(self, test_db, client): + response = await client.post( + "/api/contacts", + json={"public_key": KEY_A, "name": "NewContact"}, + ) assert response.status_code == 200 data = response.json() assert data["public_key"] == KEY_A assert data["name"] == "NewContact" - mock_upsert.assert_called_once() - def test_create_invalid_hex(self): + # Verify in DB + contact = await ContactRepository.get_by_key(KEY_A) + assert contact is not None + assert contact.name == "NewContact" + + @pytest.mark.asyncio + async def test_create_invalid_hex(self, test_db, client): """Non-hex public key returns 400.""" - from fastapi.testclient import TestClient - - with patch( - "app.routers.contacts.ContactRepository.get_by_key", - new_callable=AsyncMock, - return_value=None, - ): - from app.main import app - - client = TestClient(app) - response = client.post( - "/api/contacts", - json={"public_key": "zz" * 32, "name": "Bad"}, - ) + response = await client.post( + "/api/contacts", + json={"public_key": "zz" * 32, "name": "Bad"}, + ) assert response.status_code == 400 assert "hex" in response.json()["detail"].lower() - def test_create_short_key_rejected(self): + @pytest.mark.asyncio + async def test_create_short_key_rejected(self, test_db, client): """Key shorter than 64 chars is rejected by pydantic validation.""" - from fastapi.testclient import TestClient - - from app.main import app - - client = TestClient(app) - response = client.post( + response = await client.post( "/api/contacts", json={"public_key": "aa" * 16, "name": "Short"}, ) assert response.status_code == 422 - def test_create_existing_updates_name(self): + @pytest.mark.asyncio + async def test_create_existing_updates_name(self, test_db, client): """Creating a contact that exists updates the name.""" - from fastapi.testclient import TestClient + await _insert_contact(KEY_A, "OldName") - existing = _make_contact(KEY_A, "OldName") - - with ( - patch( - "app.routers.contacts.ContactRepository.get_by_key", - new_callable=AsyncMock, - return_value=existing, - ), - patch( - "app.routers.contacts.ContactRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): - from app.main import app - - client = TestClient(app) - response = client.post( - "/api/contacts", - json={"public_key": KEY_A, "name": "NewName"}, - ) + response = await client.post( + "/api/contacts", + json={"public_key": KEY_A, "name": "NewName"}, + ) assert response.status_code == 200 - # Upsert called with new name - mock_upsert.assert_called_once() - upsert_data = mock_upsert.call_args[0][0] - assert upsert_data["name"] == "NewName" + # Verify name was updated in DB + contact = await ContactRepository.get_by_key(KEY_A) + assert contact.name == "NewName" class TestGetContact: """Test GET /api/contacts/{public_key}.""" - def test_get_existing(self): - from fastapi.testclient import TestClient + @pytest.mark.asyncio + async def test_get_existing(self, test_db, client): + await _insert_contact(KEY_A, "Alice") - contact = _make_contact(KEY_A, "Alice") - - with patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ): - from app.main import app - - client = TestClient(app) - response = client.get(f"/api/contacts/{KEY_A}") + response = await client.get(f"/api/contacts/{KEY_A}") assert response.status_code == 200 assert response.json()["name"] == "Alice" - def test_get_not_found(self): - from fastapi.testclient import TestClient - - with patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=None, - ): - from app.main import app - - client = TestClient(app) - response = client.get(f"/api/contacts/{KEY_A}") + @pytest.mark.asyncio + async def test_get_not_found(self, test_db, client): + response = await client.get(f"/api/contacts/{KEY_A}") assert response.status_code == 404 - def test_get_ambiguous_prefix_returns_409(self): - from fastapi.testclient import TestClient + @pytest.mark.asyncio + async def test_get_ambiguous_prefix_returns_409(self, test_db, client): + # Insert two contacts that share a prefix + await _insert_contact("abcd12" + "00" * 29, "ContactA") + await _insert_contact("abcd12" + "ff" * 29, "ContactB") - from app.repository import AmbiguousPublicKeyPrefixError - - with patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - side_effect=AmbiguousPublicKeyPrefixError( - "abcd12", - [ - "abcd120000000000000000000000000000000000000000000000000000000000", - "abcd12ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - ], - ), - ): - from app.main import app - - client = TestClient(app) - response = client.get("/api/contacts/abcd12") + response = await client.get("/api/contacts/abcd12") assert response.status_code == 409 assert "ambiguous" in response.json()["detail"].lower() @@ -248,43 +189,22 @@ class TestGetContact: class TestMarkRead: """Test POST /api/contacts/{public_key}/mark-read.""" - def test_mark_read_updates_timestamp(self): - from fastapi.testclient import TestClient + @pytest.mark.asyncio + async def test_mark_read_updates_timestamp(self, test_db, client): + await _insert_contact(KEY_A) - contact = _make_contact(KEY_A) - - with ( - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - patch( - "app.routers.contacts.ContactRepository.update_last_read_at", - new_callable=AsyncMock, - return_value=True, - ), - ): - from app.main import app - - client = TestClient(app) - response = client.post(f"/api/contacts/{KEY_A}/mark-read") + response = await client.post(f"/api/contacts/{KEY_A}/mark-read") assert response.status_code == 200 assert response.json()["status"] == "ok" - def test_mark_read_not_found(self): - from fastapi.testclient import TestClient + # Verify last_read_at was set in DB + contact = await ContactRepository.get_by_key(KEY_A) + assert contact.last_read_at is not None - with patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=None, - ): - from app.main import app - - client = TestClient(app) - response = client.post(f"/api/contacts/{KEY_A}/mark-read") + @pytest.mark.asyncio + async def test_mark_read_not_found(self, test_db, client): + response = await client.post(f"/api/contacts/{KEY_A}/mark-read") assert response.status_code == 404 @@ -292,79 +212,44 @@ class TestMarkRead: class TestDeleteContact: """Test DELETE /api/contacts/{public_key}.""" - def test_delete_existing(self): - from fastapi.testclient import TestClient + @pytest.mark.asyncio + async def test_delete_existing(self, test_db, client): + await _insert_contact(KEY_A) - contact = _make_contact(KEY_A) - - with ( - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - patch( - "app.routers.contacts.ContactRepository.delete", - new_callable=AsyncMock, - ), - patch("app.routers.contacts.radio_manager") as mock_rm, - ): + with patch("app.routers.contacts.radio_manager") as mock_rm: mock_rm.is_connected = False mock_rm.meshcore = None - from app.main import app - - client = TestClient(app) - response = client.delete(f"/api/contacts/{KEY_A}") + response = await client.delete(f"/api/contacts/{KEY_A}") assert response.status_code == 200 assert response.json()["status"] == "ok" - def test_delete_not_found(self): - from fastapi.testclient import TestClient + # Verify deleted from DB + contact = await ContactRepository.get_by_key(KEY_A) + assert contact is None - with patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=None, - ): - from app.main import app - - client = TestClient(app) - response = client.delete(f"/api/contacts/{KEY_A}") + @pytest.mark.asyncio + async def test_delete_not_found(self, test_db, client): + response = await client.delete(f"/api/contacts/{KEY_A}") assert response.status_code == 404 - def test_delete_removes_from_radio_if_connected(self): + @pytest.mark.asyncio + async def test_delete_removes_from_radio_if_connected(self, test_db, client): """When radio is connected and contact is on radio, remove it first.""" - from fastapi.testclient import TestClient - - contact = _make_contact(KEY_A, on_radio=True) + await _insert_contact(KEY_A, on_radio=True) mock_radio_contact = MagicMock() mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=mock_radio_contact) mock_mc.commands.remove_contact = AsyncMock() - with ( - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - patch( - "app.routers.contacts.ContactRepository.delete", - new_callable=AsyncMock, - ), - patch("app.routers.contacts.radio_manager") as mock_rm, - ): + with patch("app.routers.contacts.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc - from app.main import app - - client = TestClient(app) - response = client.delete(f"/api/contacts/{KEY_A}") + response = await client.delete(f"/api/contacts/{KEY_A}") assert response.status_code == 200 mock_mc.commands.remove_contact.assert_called_once_with(mock_radio_contact) @@ -373,9 +258,8 @@ class TestDeleteContact: class TestSyncContacts: """Test POST /api/contacts/sync.""" - def test_sync_from_radio(self): - from fastapi.testclient import TestClient - + @pytest.mark.asyncio + async def test_sync_from_radio(self, test_db, client): mock_mc = MagicMock() mock_result = MagicMock() mock_result.type = EventType.OK @@ -385,35 +269,27 @@ class TestSyncContacts: } mock_mc.commands.get_contacts = AsyncMock(return_value=mock_result) - with ( - patch("app.dependencies.radio_manager") as mock_dep_rm, - patch( - "app.routers.contacts.ContactRepository.upsert", new_callable=AsyncMock - ) as mock_upsert, - ): + with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc - from app.main import app - - client = TestClient(app) - response = client.post("/api/contacts/sync") + response = await client.post("/api/contacts/sync") assert response.status_code == 200 assert response.json()["synced"] == 2 - assert mock_upsert.call_count == 2 - def test_sync_requires_connection(self): - from fastapi.testclient import TestClient + # Verify contacts are in real DB + alice = await ContactRepository.get_by_key(KEY_A) + assert alice is not None + assert alice.name == "Alice" + @pytest.mark.asyncio + async def test_sync_requires_connection(self, test_db, client): with patch("app.dependencies.radio_manager") as mock_rm: mock_rm.is_connected = False mock_rm.meshcore = None - from app.main import app - - client = TestClient(app) - response = client.post("/api/contacts/sync") + response = await client.post("/api/contacts/sync") assert response.status_code == 503 @@ -421,71 +297,50 @@ class TestSyncContacts: class TestAddRemoveRadio: """Test add-to-radio and remove-from-radio endpoints.""" - def test_add_to_radio(self): - from fastapi.testclient import TestClient + @pytest.mark.asyncio + async def test_add_to_radio(self, test_db, client): + await _insert_contact(KEY_A) - contact = _make_contact(KEY_A) mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=None) # Not on radio mock_result = MagicMock() mock_result.type = EventType.OK mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - with ( - patch("app.dependencies.radio_manager") as mock_dep_rm, - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - patch( - "app.routers.contacts.ContactRepository.set_on_radio", - new_callable=AsyncMock, - ) as mock_set_on_radio, - ): + with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc - from app.main import app - - client = TestClient(app) - response = client.post(f"/api/contacts/{KEY_A}/add-to-radio") + response = await client.post(f"/api/contacts/{KEY_A}/add-to-radio") assert response.status_code == 200 mock_mc.commands.add_contact.assert_called_once() - mock_set_on_radio.assert_called_once_with(KEY_A, True) - def test_add_already_on_radio(self): + # Verify on_radio flag updated in DB + contact = await ContactRepository.get_by_key(KEY_A) + assert contact.on_radio is True + + @pytest.mark.asyncio + async def test_add_already_on_radio(self, test_db, client): """Adding a contact already on radio returns ok without calling add_contact.""" - from fastapi.testclient import TestClient + await _insert_contact(KEY_A, on_radio=True) - contact = _make_contact(KEY_A, on_radio=True) mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # On radio - with ( - patch("app.dependencies.radio_manager") as mock_dep_rm, - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - ): + with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc - from app.main import app - - client = TestClient(app) - response = client.post(f"/api/contacts/{KEY_A}/add-to-radio") + response = await client.post(f"/api/contacts/{KEY_A}/add-to-radio") assert response.status_code == 200 assert "already" in response.json()["message"].lower() - def test_remove_from_radio(self): - from fastapi.testclient import TestClient + @pytest.mark.asyncio + async def test_remove_from_radio(self, test_db, client): + await _insert_contact(KEY_A, on_radio=True) - contact = _make_contact(KEY_A, on_radio=True) mock_radio_contact = MagicMock() mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=mock_radio_contact) @@ -493,63 +348,37 @@ class TestAddRemoveRadio: mock_result.type = EventType.OK mock_mc.commands.remove_contact = AsyncMock(return_value=mock_result) - with ( - patch("app.dependencies.radio_manager") as mock_dep_rm, - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - patch( - "app.routers.contacts.ContactRepository.set_on_radio", - new_callable=AsyncMock, - ) as mock_set_on_radio, - ): + with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc - from app.main import app - - client = TestClient(app) - response = client.post(f"/api/contacts/{KEY_A}/remove-from-radio") + response = await client.post(f"/api/contacts/{KEY_A}/remove-from-radio") assert response.status_code == 200 mock_mc.commands.remove_contact.assert_called_once_with(mock_radio_contact) - mock_set_on_radio.assert_called_once_with(KEY_A, False) - def test_add_requires_connection(self): - from fastapi.testclient import TestClient + # Verify on_radio flag updated in DB + contact = await ContactRepository.get_by_key(KEY_A) + assert contact.on_radio is False + @pytest.mark.asyncio + async def test_add_requires_connection(self, test_db, client): with patch("app.dependencies.radio_manager") as mock_rm: mock_rm.is_connected = False mock_rm.meshcore = None - from app.main import app - - client = TestClient(app) - response = client.post(f"/api/contacts/{KEY_A}/add-to-radio") + response = await client.post(f"/api/contacts/{KEY_A}/add-to-radio") assert response.status_code == 503 - def test_remove_not_found(self): - from fastapi.testclient import TestClient - + @pytest.mark.asyncio + async def test_remove_not_found(self, test_db, client): mock_mc = MagicMock() - with ( - patch("app.dependencies.radio_manager") as mock_dep_rm, - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=None, - ), - ): + with patch("app.dependencies.radio_manager") as mock_dep_rm: mock_dep_rm.is_connected = True mock_dep_rm.meshcore = mock_mc - from app.main import app - - client = TestClient(app) - response = client.post(f"/api/contacts/{KEY_A}/remove-from-radio") + response = await client.post(f"/api/contacts/{KEY_A}/remove-from-radio") assert response.status_code == 404 diff --git a/tests/test_event_handlers.py b/tests/test_event_handlers.py index ea8cf2d..c28b5ad 100644 --- a/tests/test_event_handlers.py +++ b/tests/test_event_handlers.py @@ -1,7 +1,7 @@ """Tests for event handler logic. These tests verify the ACK tracking mechanism for direct message -delivery confirmation. +delivery confirmation, contact message handling, and event registration. """ import time @@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from app.database import Database from app.event_handlers import ( _active_subscriptions, _cleanup_expired_acks, @@ -16,7 +17,28 @@ from app.event_handlers import ( register_event_handlers, track_pending_ack, ) -from app.repository import AmbiguousPublicKeyPrefixError +from app.repository import ( + ContactRepository, + MessageRepository, +) + + +@pytest.fixture +async def test_db(): + """Create an in-memory test database with schema + migrations.""" + import app.repository as repo_module + + db = Database(":memory:") + await db.connect() + + original_db = repo_module.db + repo_module.db = db + + try: + yield db + finally: + repo_module.db = original_db + await db.disconnect() @pytest.fixture(autouse=True) @@ -79,93 +101,93 @@ class TestAckEventHandler: """Test the on_ack event handler.""" @pytest.mark.asyncio - async def test_ack_matches_pending_message(self): + async def test_ack_matches_pending_message(self, test_db): """Matching ACK code updates message and broadcasts.""" from app.event_handlers import on_ack - # Setup pending ACK - track_pending_ack("deadbeef", message_id=123, timeout_ms=10000) + # Insert a real message to get a valid ID + msg_id = await MessageRepository.create( + msg_type="PRIV", + text="Hello", + received_at=1700000000, + conversation_key="aa" * 32, + sender_timestamp=1700000000, + ) - # Mock dependencies - with ( - patch("app.event_handlers.MessageRepository") as mock_repo, - patch("app.event_handlers.broadcast_event") as mock_broadcast, - ): - mock_repo.increment_ack_count = AsyncMock(return_value=1) + # Setup pending ACK with the real message ID + track_pending_ack("deadbeef", message_id=msg_id, timeout_ms=10000) + + with patch("app.event_handlers.broadcast_event") as mock_broadcast: - # Create mock event class MockEvent: payload = {"code": "deadbeef"} await on_ack(MockEvent()) - # Verify ack count incremented - mock_repo.increment_ack_count.assert_called_once_with(123) + # Verify ack count incremented (real DB) + ack_count = await MessageRepository.get_ack_count(msg_id) + assert ack_count == 1 # Verify broadcast sent with ack_count mock_broadcast.assert_called_once_with( - "message_acked", {"message_id": 123, "ack_count": 1} + "message_acked", {"message_id": msg_id, "ack_count": 1} ) # Verify pending ACK removed assert "deadbeef" not in _pending_acks @pytest.mark.asyncio - async def test_ack_no_match_does_nothing(self): + async def test_ack_no_match_does_nothing(self, test_db): """Non-matching ACK code is ignored.""" from app.event_handlers import on_ack - track_pending_ack("expected", message_id=1, timeout_ms=10000) + msg_id = await MessageRepository.create( + msg_type="PRIV", + text="Hello", + received_at=1700000000, + conversation_key="aa" * 32, + sender_timestamp=1700000000, + ) + track_pending_ack("expected", message_id=msg_id, timeout_ms=10000) - with ( - patch("app.event_handlers.MessageRepository") as mock_repo, - patch("app.event_handlers.broadcast_event") as mock_broadcast, - ): - mock_repo.increment_ack_count = AsyncMock() + with patch("app.event_handlers.broadcast_event") as mock_broadcast: class MockEvent: payload = {"code": "different"} await on_ack(MockEvent()) - mock_repo.increment_ack_count.assert_not_called() + # Ack count should remain 0 + ack_count = await MessageRepository.get_ack_count(msg_id) + assert ack_count == 0 + mock_broadcast.assert_not_called() assert "expected" in _pending_acks @pytest.mark.asyncio - async def test_ack_empty_code_ignored(self): + async def test_ack_empty_code_ignored(self, test_db): """ACK with empty code is ignored.""" from app.event_handlers import on_ack - with patch("app.event_handlers.MessageRepository") as mock_repo: - mock_repo.increment_ack_count = AsyncMock() + with patch("app.event_handlers.broadcast_event") as mock_broadcast: class MockEvent: payload = {"code": ""} await on_ack(MockEvent()) - mock_repo.increment_ack_count.assert_not_called() + mock_broadcast.assert_not_called() class TestContactMessageCLIFiltering: - """Test that CLI responses (txt_type=1) are filtered out. - - This prevents duplicate messages when sending CLI commands to repeaters: - the command endpoint returns the response directly, so we must NOT also - persist/broadcast it via the normal message handler. - """ + """Test that CLI responses (txt_type=1) are filtered out.""" @pytest.mark.asyncio - async def test_cli_response_skipped_not_stored(self): + async def test_cli_response_skipped_not_stored(self, test_db): """CLI responses (txt_type=1) are not stored in database.""" from app.event_handlers import on_contact_message - with ( - patch("app.event_handlers.MessageRepository") as mock_repo, - patch("app.event_handlers.ContactRepository") as mock_contact_repo, - patch("app.event_handlers.broadcast_event") as mock_broadcast, - ): + with patch("app.event_handlers.broadcast_event") as mock_broadcast: class MockEvent: payload = { @@ -177,15 +199,15 @@ class TestContactMessageCLIFiltering: await on_contact_message(MockEvent()) - # Should NOT store in database - mock_repo.create.assert_not_called() # Should NOT broadcast via WebSocket mock_broadcast.assert_not_called() - # Should NOT update contact last_contacted - mock_contact_repo.update_last_contacted.assert_not_called() + + # Should NOT have stored anything in DB + messages = await MessageRepository.get_all() + assert len(messages) == 0 @pytest.mark.asyncio - async def test_normal_message_schedules_bot_in_background(self): + async def test_normal_message_schedules_bot_in_background(self, test_db): """Normal messages should schedule bot execution without blocking.""" from app.event_handlers import on_contact_message @@ -194,14 +216,10 @@ class TestContactMessageCLIFiltering: return MagicMock() with ( - patch("app.event_handlers.MessageRepository") as mock_repo, - patch("app.event_handlers.ContactRepository") as mock_contact_repo, patch("app.event_handlers.broadcast_event"), patch("app.event_handlers.asyncio.create_task", side_effect=_capture_task) as mock_task, patch("app.bot.run_bot_for_message", new_callable=AsyncMock) as mock_bot, ): - mock_repo.create = AsyncMock(return_value=42) - mock_contact_repo.get_by_key_or_prefix = AsyncMock(return_value=None) class MockEvent: payload = { @@ -217,18 +235,14 @@ class TestContactMessageCLIFiltering: mock_bot.assert_called_once() @pytest.mark.asyncio - async def test_normal_message_still_processed(self): + async def test_normal_message_still_processed(self, test_db): """Normal messages (txt_type=0) are still processed normally.""" from app.event_handlers import on_contact_message with ( - patch("app.event_handlers.MessageRepository") as mock_repo, - patch("app.event_handlers.ContactRepository") as mock_contact_repo, patch("app.event_handlers.broadcast_event") as mock_broadcast, patch("app.bot.run_bot_for_message", new_callable=AsyncMock), ): - mock_repo.create = AsyncMock(return_value=42) - mock_contact_repo.get_by_key_or_prefix = AsyncMock(return_value=None) class MockEvent: payload = { @@ -240,24 +254,23 @@ class TestContactMessageCLIFiltering: await on_contact_message(MockEvent()) - # SHOULD store in database - mock_repo.create.assert_called_once() + # SHOULD be stored in database + messages = await MessageRepository.get_all() + assert len(messages) == 1 + assert messages[0].text == "Hello, this is a normal message" + # SHOULD broadcast via WebSocket mock_broadcast.assert_called_once() @pytest.mark.asyncio - async def test_broadcast_payload_has_correct_acked_type(self): + async def test_broadcast_payload_has_correct_acked_type(self, test_db): """Broadcast payload should have acked as integer 0, not boolean False.""" from app.event_handlers import on_contact_message with ( - patch("app.event_handlers.MessageRepository") as mock_repo, - patch("app.event_handlers.ContactRepository") as mock_contact_repo, patch("app.event_handlers.broadcast_event") as mock_broadcast, patch("app.bot.run_bot_for_message", new_callable=AsyncMock), ): - mock_repo.create = AsyncMock(return_value=42) - mock_contact_repo.get_by_key_or_prefix = AsyncMock(return_value=None) class MockEvent: payload = { @@ -281,18 +294,14 @@ class TestContactMessageCLIFiltering: assert isinstance(payload["acked"], int) @pytest.mark.asyncio - async def test_missing_txt_type_defaults_to_normal(self): + async def test_missing_txt_type_defaults_to_normal(self, test_db): """Messages without txt_type field are treated as normal (not filtered).""" from app.event_handlers import on_contact_message with ( - patch("app.event_handlers.MessageRepository") as mock_repo, - patch("app.event_handlers.ContactRepository") as mock_contact_repo, patch("app.event_handlers.broadcast_event"), patch("app.bot.run_bot_for_message", new_callable=AsyncMock), ): - mock_repo.create = AsyncMock(return_value=42) - mock_contact_repo.get_by_key_or_prefix = AsyncMock(return_value=None) class MockEvent: payload = { @@ -305,29 +314,36 @@ class TestContactMessageCLIFiltering: await on_contact_message(MockEvent()) # SHOULD still be processed (defaults to txt_type=0) - mock_repo.create.assert_called_once() + messages = await MessageRepository.get_all() + assert len(messages) == 1 @pytest.mark.asyncio - async def test_ambiguous_prefix_stores_dm_under_prefix(self): + async def test_ambiguous_prefix_stores_dm_under_prefix(self, test_db): """Ambiguous sender prefixes should still be stored under the prefix key.""" from app.event_handlers import on_contact_message + # Insert two contacts that share the same prefix to trigger ambiguity + await ContactRepository.upsert( + { + "public_key": "abc123" + "00" * 29, + "name": "ContactA", + "type": 1, + "flags": 0, + } + ) + await ContactRepository.upsert( + { + "public_key": "abc123" + "ff" * 29, + "name": "ContactB", + "type": 1, + "flags": 0, + } + ) + with ( - patch("app.event_handlers.MessageRepository") as mock_repo, - patch("app.event_handlers.ContactRepository") as mock_contact_repo, patch("app.event_handlers.broadcast_event") as mock_broadcast, patch("app.bot.run_bot_for_message", new_callable=AsyncMock), ): - mock_repo.create = AsyncMock(return_value=77) - mock_contact_repo.get_by_key_or_prefix = AsyncMock( - side_effect=AmbiguousPublicKeyPrefixError( - "abc123", - [ - "abc1230000000000000000000000000000000000000000000000000000000000", - "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - ], - ) - ) class MockEvent: payload = { @@ -339,8 +355,10 @@ class TestContactMessageCLIFiltering: await on_contact_message(MockEvent()) - mock_repo.create.assert_called_once() - assert mock_repo.create.await_args.kwargs["conversation_key"] == "abc123" + # Should store in DB under the prefix key + messages = await MessageRepository.get_all() + assert len(messages) == 1 + assert messages[0].conversation_key == "abc123" mock_broadcast.assert_called_once() _, payload = mock_broadcast.call_args.args @@ -414,7 +432,6 @@ class TestEventHandlerRegistration: mock_meshcore.subscribe.return_value = MagicMock() # Create subscriptions where unsubscribe raises an exception - # (simulates old dispatcher being in a bad state after reconnect) bad_sub = MagicMock() bad_sub.unsubscribe.side_effect = RuntimeError("Dispatcher is dead") _active_subscriptions.append(bad_sub) @@ -437,86 +454,92 @@ class TestOnPathUpdate: """Test the on_path_update event handler.""" @pytest.mark.asyncio - async def test_updates_path_for_existing_contact(self): + async def test_updates_path_for_existing_contact(self, test_db): """Path is updated when the contact exists in the database.""" from app.event_handlers import on_path_update - mock_contact = MagicMock() - mock_contact.public_key = "aa" * 32 + await ContactRepository.upsert( + { + "public_key": "aa" * 32, + "name": "Alice", + "type": 1, + "flags": 0, + } + ) - with patch("app.event_handlers.ContactRepository") as mock_repo: - mock_repo.get_by_key_prefix = AsyncMock(return_value=mock_contact) - mock_repo.update_path = AsyncMock() + class MockEvent: + payload = { + "pubkey_prefix": "aaaaaa", + "path": "0102", + "path_len": 2, + } - class MockEvent: - payload = { - "pubkey_prefix": "aaaaaa", - "path": "0102", - "path_len": 2, - } + await on_path_update(MockEvent()) - await on_path_update(MockEvent()) - - mock_repo.get_by_key_prefix.assert_called_once_with("aaaaaa") - mock_repo.update_path.assert_called_once_with("aa" * 32, "0102", 2) + # Verify path was updated in DB + contact = await ContactRepository.get_by_key("aa" * 32) + assert contact is not None + assert contact.last_path == "0102" + assert contact.last_path_len == 2 @pytest.mark.asyncio - async def test_does_nothing_when_contact_not_found(self): + async def test_does_nothing_when_contact_not_found(self, test_db): """No update is attempted when the contact is not in the database.""" from app.event_handlers import on_path_update - with patch("app.event_handlers.ContactRepository") as mock_repo: - mock_repo.get_by_key_prefix = AsyncMock(return_value=None) - mock_repo.update_path = AsyncMock() + class MockEvent: + payload = { + "pubkey_prefix": "unknown", + "path": "0102", + "path_len": 2, + } - class MockEvent: - payload = { - "pubkey_prefix": "unknown", - "path": "0102", - "path_len": 2, - } - - await on_path_update(MockEvent()) - - mock_repo.get_by_key_prefix.assert_called_once_with("unknown") - mock_repo.update_path.assert_not_called() + # Should not raise + await on_path_update(MockEvent()) @pytest.mark.asyncio - async def test_uses_defaults_for_missing_payload_fields(self): + async def test_uses_defaults_for_missing_payload_fields(self, test_db): """Missing payload fields fall back to defaults (empty path, -1 length).""" from app.event_handlers import on_path_update - mock_contact = MagicMock() - mock_contact.public_key = "bb" * 32 + await ContactRepository.upsert( + { + "public_key": "bb" * 32, + "name": "Bob", + "type": 1, + "flags": 0, + } + ) - with patch("app.event_handlers.ContactRepository") as mock_repo: - mock_repo.get_by_key_prefix = AsyncMock(return_value=mock_contact) - mock_repo.update_path = AsyncMock() + class MockEvent: + payload = {} - class MockEvent: - payload = {} + await on_path_update(MockEvent()) - await on_path_update(MockEvent()) - - mock_repo.get_by_key_prefix.assert_called_once_with("") - mock_repo.update_path.assert_called_once_with("bb" * 32, "", -1) + # With empty prefix, get_by_key_prefix("") should return None since + # no key starts with "" uniquely (if multiple contacts exist) or + # the single contact if only one. But with prefix="", the LIKE query + # matches all contacts. With exactly one contact, it returns it. + # The update_path call sets path="" and path_len=-1. + contact = await ContactRepository.get_by_key("bb" * 32) + assert contact is not None + assert contact.last_path == "" + assert contact.last_path_len == -1 class TestOnNewContact: """Test the on_new_contact event handler.""" @pytest.mark.asyncio - async def test_creates_contact_and_broadcasts(self): + async def test_creates_contact_and_broadcasts(self, test_db): """Valid new contact is upserted and broadcast via WebSocket.""" from app.event_handlers import on_new_contact with ( - patch("app.event_handlers.ContactRepository") as mock_repo, patch("app.event_handlers.broadcast_event") as mock_broadcast, patch("app.event_handlers.time") as mock_time, ): mock_time.time.return_value = 1700000000 - mock_repo.upsert = AsyncMock() class MockEvent: payload = { @@ -528,13 +551,12 @@ class TestOnNewContact: await on_new_contact(MockEvent()) - mock_repo.upsert.assert_called_once() - upserted_data = mock_repo.upsert.call_args[0][0] - - assert upserted_data["public_key"] == "cc" * 32 - assert upserted_data["name"] == "Charlie" - assert upserted_data["on_radio"] is True - assert upserted_data["last_seen"] == 1700000000 + # Verify contact was created in real DB + contact = await ContactRepository.get_by_key("cc" * 32) + assert contact is not None + assert contact.name == "Charlie" + assert contact.on_radio is True + assert contact.last_seen == 1700000000 mock_broadcast.assert_called_once() event_type, contact_data = mock_broadcast.call_args[0] @@ -542,55 +564,50 @@ class TestOnNewContact: assert contact_data["public_key"] == "cc" * 32 @pytest.mark.asyncio - async def test_returns_early_on_empty_public_key(self): + async def test_returns_early_on_empty_public_key(self, test_db): """Handler exits without upserting when public_key is empty.""" from app.event_handlers import on_new_contact - with ( - patch("app.event_handlers.ContactRepository") as mock_repo, - patch("app.event_handlers.broadcast_event") as mock_broadcast, - ): - mock_repo.upsert = AsyncMock() + with patch("app.event_handlers.broadcast_event") as mock_broadcast: class MockEvent: payload = {"public_key": "", "adv_name": "Ghost"} await on_new_contact(MockEvent()) - mock_repo.upsert.assert_not_called() mock_broadcast.assert_not_called() + # No contacts should exist + contacts = await ContactRepository.get_all() + assert len(contacts) == 0 + @pytest.mark.asyncio - async def test_returns_early_on_missing_public_key(self): + async def test_returns_early_on_missing_public_key(self, test_db): """Handler exits without upserting when public_key field is absent.""" from app.event_handlers import on_new_contact - with ( - patch("app.event_handlers.ContactRepository") as mock_repo, - patch("app.event_handlers.broadcast_event") as mock_broadcast, - ): - mock_repo.upsert = AsyncMock() + with patch("app.event_handlers.broadcast_event") as mock_broadcast: class MockEvent: payload = {"adv_name": "NoKey"} await on_new_contact(MockEvent()) - mock_repo.upsert.assert_not_called() mock_broadcast.assert_not_called() + contacts = await ContactRepository.get_all() + assert len(contacts) == 0 + @pytest.mark.asyncio - async def test_sets_on_radio_true(self): + async def test_sets_on_radio_true(self, test_db): """Contact data passed to upsert has on_radio=True.""" from app.event_handlers import on_new_contact with ( - patch("app.event_handlers.ContactRepository") as mock_repo, patch("app.event_handlers.broadcast_event"), patch("app.event_handlers.time") as mock_time, ): mock_time.time.return_value = 1700000000 - mock_repo.upsert = AsyncMock() class MockEvent: payload = { @@ -602,21 +619,20 @@ class TestOnNewContact: await on_new_contact(MockEvent()) - upserted_data = mock_repo.upsert.call_args[0][0] - assert upserted_data["on_radio"] is True + contact = await ContactRepository.get_by_key("dd" * 32) + assert contact is not None + assert contact.on_radio is True @pytest.mark.asyncio - async def test_sets_last_seen_to_current_timestamp(self): + async def test_sets_last_seen_to_current_timestamp(self, test_db): """Contact data includes last_seen set to current time.""" from app.event_handlers import on_new_contact with ( - patch("app.event_handlers.ContactRepository") as mock_repo, patch("app.event_handlers.broadcast_event"), patch("app.event_handlers.time") as mock_time, ): mock_time.time.return_value = 1700099999 - mock_repo.upsert = AsyncMock() class MockEvent: payload = { @@ -628,5 +644,6 @@ class TestOnNewContact: await on_new_contact(MockEvent()) - upserted_data = mock_repo.upsert.call_args[0][0] - assert upserted_data["last_seen"] == 1700099999 + contact = await ContactRepository.get_by_key("ee" * 32) + assert contact is not None + assert contact.last_seen == 1700099999 diff --git a/tests/test_radio_sync.py b/tests/test_radio_sync.py index 729e509..0ea602f 100644 --- a/tests/test_radio_sync.py +++ b/tests/test_radio_sync.py @@ -1,7 +1,7 @@ """Tests for radio_sync module. -These tests verify the polling pause mechanism that prevents -message polling from interfering with repeater CLI operations. +These tests verify the polling pause mechanism, radio time sync, +contact/channel sync operations, and default channel management. """ from unittest.mock import AsyncMock, MagicMock, patch @@ -9,13 +9,38 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from meshcore import EventType -from app.models import Contact, Favorite +from app.database import Database +from app.models import Favorite from app.radio_sync import ( is_polling_paused, pause_polling, sync_radio_time, sync_recent_contacts_to_radio, ) +from app.repository import ( + AppSettingsRepository, + ChannelRepository, + ContactRepository, + MessageRepository, +) + + +@pytest.fixture +async def test_db(): + """Create an in-memory test database with schema + migrations.""" + import app.repository as repo_module + + db = Database(":memory:") + await db.connect() + + original_db = repo_module.db + repo_module.db = db + + try: + yield db + finally: + repo_module.db = original_db + await db.disconnect() @pytest.fixture(autouse=True) @@ -30,6 +55,37 @@ def reset_sync_state(): radio_sync._last_contact_sync = 0.0 +KEY_A = "aa" * 32 +KEY_B = "bb" * 32 + + +async def _insert_contact( + public_key=KEY_A, + name="Alice", + on_radio=False, + contact_type=0, + last_contacted=None, + last_advert=None, +): + """Insert a contact into the test database.""" + await ContactRepository.upsert( + { + "public_key": public_key, + "name": name, + "type": contact_type, + "flags": 0, + "last_path": None, + "last_path_len": -1, + "last_advert": last_advert, + "lat": None, + "lon": None, + "last_seen": None, + "on_radio": on_radio, + "last_contacted": last_contacted, + } + ) + + class TestPollingPause: """Test the polling pause mechanism.""" @@ -165,38 +221,14 @@ class TestSyncRadioTime: assert result is False -KEY_A = "aa" * 32 -KEY_B = "bb" * 32 - - -def _make_contact(public_key=KEY_A, name="Alice", on_radio=False, **overrides): - """Create a Contact model instance for testing.""" - defaults = { - "public_key": public_key, - "name": name, - "type": 0, - "flags": 0, - "last_path": None, - "last_path_len": -1, - "last_advert": None, - "lat": None, - "lon": None, - "last_seen": None, - "on_radio": on_radio, - "last_contacted": None, - "last_read_at": None, - } - defaults.update(overrides) - return Contact(**defaults) - - class TestSyncRecentContactsToRadio: """Test the sync_recent_contacts_to_radio function.""" @pytest.mark.asyncio - async def test_loads_contacts_not_on_radio(self): + async def test_loads_contacts_not_on_radio(self, test_db): """Contacts not on radio are added via add_contact.""" - contacts = [_make_contact(KEY_A, "Alice"), _make_contact(KEY_B, "Bob")] + await _insert_contact(KEY_A, "Alice", last_contacted=2000) + await _insert_contact(KEY_B, "Bob", last_contacted=1000) mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=None) @@ -204,40 +236,31 @@ class TestSyncRecentContactsToRadio: mock_result.type = EventType.OK mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - mock_settings = MagicMock() - mock_settings.max_radio_contacts = 200 - mock_settings.favorites = [] - - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.get_recent_non_repeaters", - new_callable=AsyncMock, - return_value=contacts, - ), - patch( - "app.radio_sync.ContactRepository.set_on_radio", - new_callable=AsyncMock, - ) as mock_set_on_radio, - patch( - "app.radio_sync.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=mock_settings, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc result = await sync_recent_contacts_to_radio() assert result["loaded"] == 2 - assert mock_set_on_radio.call_count == 2 + # Verify contacts are now marked as on_radio in DB + alice = await ContactRepository.get_by_key(KEY_A) + bob = await ContactRepository.get_by_key(KEY_B) + assert alice.on_radio is True + assert bob.on_radio is True @pytest.mark.asyncio - async def test_favorites_loaded_before_recent_contacts(self): + async def test_favorites_loaded_before_recent_contacts(self, test_db): """Favorite contacts are loaded first, then recents until limit.""" - favorite_contact = _make_contact(KEY_A, "Alice") - recent_contacts = [_make_contact(KEY_B, "Bob"), _make_contact("cc" * 32, "Carol")] + await _insert_contact(KEY_A, "Alice", last_contacted=100) + await _insert_contact(KEY_B, "Bob", last_contacted=2000) + await _insert_contact("cc" * 32, "Carol", last_contacted=1000) + + # Set max_radio_contacts=2 and add KEY_A as favorite + await AppSettingsRepository.update( + max_radio_contacts=2, + favorites=[Favorite(type="contact", id=KEY_A)], + ) mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=None) @@ -245,49 +268,29 @@ class TestSyncRecentContactsToRadio: mock_result.type = EventType.OK mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - mock_settings = MagicMock() - mock_settings.max_radio_contacts = 2 - mock_settings.favorites = [Favorite(type="contact", id=KEY_A)] - - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=favorite_contact, - ) as mock_get_by_key_or_prefix, - patch( - "app.radio_sync.ContactRepository.get_recent_non_repeaters", - new_callable=AsyncMock, - return_value=recent_contacts, - ), - patch( - "app.radio_sync.ContactRepository.set_on_radio", - new_callable=AsyncMock, - ), - patch( - "app.radio_sync.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=mock_settings, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc result = await sync_recent_contacts_to_radio() assert result["loaded"] == 2 - mock_get_by_key_or_prefix.assert_called_once_with(KEY_A) + # KEY_A (favorite) should be loaded first, then KEY_B (most recent) loaded_keys = [ call.args[0]["public_key"] for call in mock_mc.commands.add_contact.call_args_list ] assert loaded_keys == [KEY_A, KEY_B] @pytest.mark.asyncio - async def test_favorite_contact_not_loaded_twice_if_also_recent(self): + async def test_favorite_contact_not_loaded_twice_if_also_recent(self, test_db): """A favorite contact that is also recent is loaded only once.""" - favorite_contact = _make_contact(KEY_A, "Alice") - recent_contacts = [favorite_contact, _make_contact(KEY_B, "Bob")] + await _insert_contact(KEY_A, "Alice", last_contacted=2000) + await _insert_contact(KEY_B, "Bob", last_contacted=1000) + + await AppSettingsRepository.update( + max_radio_contacts=2, + favorites=[Favorite(type="contact", id=KEY_A)], + ) mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=None) @@ -295,32 +298,7 @@ class TestSyncRecentContactsToRadio: mock_result.type = EventType.OK mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - mock_settings = MagicMock() - mock_settings.max_radio_contacts = 2 - mock_settings.favorites = [Favorite(type="contact", id=KEY_A)] - - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=favorite_contact, - ), - patch( - "app.radio_sync.ContactRepository.get_recent_non_repeaters", - new_callable=AsyncMock, - return_value=recent_contacts, - ), - patch( - "app.radio_sync.ContactRepository.set_on_radio", - new_callable=AsyncMock, - ), - patch( - "app.radio_sync.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=mock_settings, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -333,35 +311,15 @@ class TestSyncRecentContactsToRadio: assert loaded_keys == [KEY_A, KEY_B] @pytest.mark.asyncio - async def test_skips_contacts_already_on_radio(self): + async def test_skips_contacts_already_on_radio(self, test_db): """Contacts already on radio are counted but not re-added.""" - contacts = [_make_contact(KEY_A, "Alice", on_radio=True)] + await _insert_contact(KEY_A, "Alice", on_radio=True) mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # Found mock_mc.commands.add_contact = AsyncMock() - mock_settings = MagicMock() - mock_settings.max_radio_contacts = 200 - mock_settings.favorites = [] - - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.get_recent_non_repeaters", - new_callable=AsyncMock, - return_value=contacts, - ), - patch( - "app.radio_sync.ContactRepository.set_on_radio", - new_callable=AsyncMock, - ), - patch( - "app.radio_sync.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=mock_settings, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -372,28 +330,12 @@ class TestSyncRecentContactsToRadio: mock_mc.commands.add_contact.assert_not_called() @pytest.mark.asyncio - async def test_throttled_when_called_quickly(self): + async def test_throttled_when_called_quickly(self, test_db): """Second call within throttle window returns throttled result.""" mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=None) - mock_settings = MagicMock() - mock_settings.max_radio_contacts = 200 - mock_settings.favorites = [] - - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.get_recent_non_repeaters", - new_callable=AsyncMock, - return_value=[], - ), - patch( - "app.radio_sync.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=mock_settings, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -407,27 +349,11 @@ class TestSyncRecentContactsToRadio: assert result2["loaded"] == 0 @pytest.mark.asyncio - async def test_force_bypasses_throttle(self): + async def test_force_bypasses_throttle(self, test_db): """force=True bypasses the throttle window.""" mock_mc = MagicMock() - mock_settings = MagicMock() - mock_settings.max_radio_contacts = 200 - mock_settings.favorites = [] - - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.get_recent_non_repeaters", - new_callable=AsyncMock, - return_value=[], - ), - patch( - "app.radio_sync.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=mock_settings, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -451,34 +377,14 @@ class TestSyncRecentContactsToRadio: assert "error" in result @pytest.mark.asyncio - async def test_marks_on_radio_when_found_but_not_flagged(self): + async def test_marks_on_radio_when_found_but_not_flagged(self, test_db): """Contact found on radio but not flagged gets set_on_radio(True).""" - contact = _make_contact(KEY_A, "Alice", on_radio=False) + await _insert_contact(KEY_A, "Alice", on_radio=False) mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=MagicMock()) # Found - mock_settings = MagicMock() - mock_settings.max_radio_contacts = 200 - mock_settings.favorites = [] - - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.get_recent_non_repeaters", - new_callable=AsyncMock, - return_value=[contact], - ), - patch( - "app.radio_sync.ContactRepository.set_on_radio", - new_callable=AsyncMock, - ) as mock_set_on_radio, - patch( - "app.radio_sync.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=mock_settings, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -486,12 +392,13 @@ class TestSyncRecentContactsToRadio: assert result["already_on_radio"] == 1 # Should update the flag since contact.on_radio was False - mock_set_on_radio.assert_called_once_with(KEY_A, True) + contact = await ContactRepository.get_by_key(KEY_A) + assert contact.on_radio is True @pytest.mark.asyncio - async def test_handles_add_failure(self): + async def test_handles_add_failure(self, test_db): """Failed add_contact increments the failed counter.""" - contacts = [_make_contact(KEY_A, "Alice")] + await _insert_contact(KEY_A, "Alice") mock_mc = MagicMock() mock_mc.get_contact_by_key_prefix = MagicMock(return_value=None) @@ -500,27 +407,7 @@ class TestSyncRecentContactsToRadio: mock_result.payload = {"error": "Radio full"} mock_mc.commands.add_contact = AsyncMock(return_value=mock_result) - mock_settings = MagicMock() - mock_settings.max_radio_contacts = 200 - mock_settings.favorites = [] - - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.get_recent_non_repeaters", - new_callable=AsyncMock, - return_value=contacts, - ), - patch( - "app.radio_sync.ContactRepository.set_on_radio", - new_callable=AsyncMock, - ), - patch( - "app.radio_sync.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=mock_settings, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -549,7 +436,7 @@ class TestSyncAndOffloadContacts: assert "error" in result @pytest.mark.asyncio - async def test_syncs_and_removes_contacts(self): + async def test_syncs_and_removes_contacts(self, test_db): """Contacts are upserted to DB and removed from radio.""" from app.radio_sync import sync_and_offload_contacts @@ -569,18 +456,7 @@ class TestSyncAndOffloadContacts: mock_mc.commands.get_contacts = AsyncMock(return_value=mock_get_result) mock_mc.commands.remove_contact = AsyncMock(return_value=mock_remove_result) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - patch( - "app.radio_sync.MessageRepository.claim_prefix_messages", - new_callable=AsyncMock, - return_value=0, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -588,14 +464,29 @@ class TestSyncAndOffloadContacts: assert result["synced"] == 2 assert result["removed"] == 2 - assert mock_upsert.call_count == 2 - assert mock_mc.commands.remove_contact.call_count == 2 + + # Verify contacts are in real DB + alice = await ContactRepository.get_by_key(KEY_A) + bob = await ContactRepository.get_by_key(KEY_B) + assert alice is not None + assert alice.name == "Alice" + assert bob is not None + assert bob.name == "Bob" @pytest.mark.asyncio - async def test_claims_prefix_messages_for_each_contact(self): + async def test_claims_prefix_messages_for_each_contact(self, test_db): """claim_prefix_messages is called for each synced contact.""" from app.radio_sync import sync_and_offload_contacts + # Pre-insert a message with a prefix key that matches KEY_A + await MessageRepository.create( + msg_type="PRIV", + text="Hello from prefix", + received_at=1700000000, + conversation_key=KEY_A[:12], + sender_timestamp=1700000000, + ) + contact_payload = {KEY_A: {"adv_name": "Alice", "type": 1, "flags": 0}} mock_get_result = MagicMock() @@ -609,27 +500,19 @@ class TestSyncAndOffloadContacts: mock_mc.commands.get_contacts = AsyncMock(return_value=mock_get_result) mock_mc.commands.remove_contact = AsyncMock(return_value=mock_remove_result) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.upsert", - new_callable=AsyncMock, - ), - patch( - "app.radio_sync.MessageRepository.claim_prefix_messages", - new_callable=AsyncMock, - return_value=3, - ) as mock_claim, - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc await sync_and_offload_contacts() - mock_claim.assert_called_once_with(KEY_A.lower()) + # Verify the prefix message was claimed (promoted to full key) + messages = await MessageRepository.get_all(conversation_key=KEY_A) + assert len(messages) == 1 + assert messages[0].conversation_key == KEY_A.lower() @pytest.mark.asyncio - async def test_handles_remove_failure_gracefully(self): + async def test_handles_remove_failure_gracefully(self, test_db): """Failed remove_contact logs warning but continues to next contact.""" from app.radio_sync import sync_and_offload_contacts @@ -654,18 +537,7 @@ class TestSyncAndOffloadContacts: # First remove fails, second succeeds mock_mc.commands.remove_contact = AsyncMock(side_effect=[mock_fail_result, mock_ok_result]) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.upsert", - new_callable=AsyncMock, - ), - patch( - "app.radio_sync.MessageRepository.claim_prefix_messages", - new_callable=AsyncMock, - return_value=0, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -676,7 +548,7 @@ class TestSyncAndOffloadContacts: assert result["removed"] == 1 @pytest.mark.asyncio - async def test_handles_remove_exception_gracefully(self): + async def test_handles_remove_exception_gracefully(self, test_db): """Exception during remove_contact is caught and processing continues.""" from app.radio_sync import sync_and_offload_contacts @@ -690,18 +562,7 @@ class TestSyncAndOffloadContacts: mock_mc.commands.get_contacts = AsyncMock(return_value=mock_get_result) mock_mc.commands.remove_contact = AsyncMock(side_effect=Exception("Timeout")) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.upsert", - new_callable=AsyncMock, - ), - patch( - "app.radio_sync.MessageRepository.claim_prefix_messages", - new_callable=AsyncMock, - return_value=0, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -733,7 +594,7 @@ class TestSyncAndOffloadContacts: assert "error" in result @pytest.mark.asyncio - async def test_upserts_with_on_radio_false(self): + async def test_upserts_with_on_radio_false(self, test_db): """Contacts are upserted with on_radio=False (being removed from radio).""" from app.radio_sync import sync_and_offload_contacts @@ -750,25 +611,15 @@ class TestSyncAndOffloadContacts: mock_mc.commands.get_contacts = AsyncMock(return_value=mock_get_result) mock_mc.commands.remove_contact = AsyncMock(return_value=mock_remove_result) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ContactRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - patch( - "app.radio_sync.MessageRepository.claim_prefix_messages", - new_callable=AsyncMock, - return_value=0, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc await sync_and_offload_contacts() - upserted_data = mock_upsert.call_args[0][0] - assert upserted_data["on_radio"] is False + contact = await ContactRepository.get_by_key(KEY_A) + assert contact is not None + assert contact.on_radio is False class TestSyncAndOffloadChannels: @@ -790,7 +641,7 @@ class TestSyncAndOffloadChannels: assert "error" in result @pytest.mark.asyncio - async def test_syncs_valid_channel_and_clears(self): + async def test_syncs_valid_channel_and_clears(self, test_db): """Valid channel is upserted to DB and cleared from radio.""" from app.radio_sync import sync_and_offload_channels @@ -812,13 +663,7 @@ class TestSyncAndOffloadChannels: clear_result.type = EventType.OK mock_mc.commands.set_channel = AsyncMock(return_value=clear_result) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -826,12 +671,13 @@ class TestSyncAndOffloadChannels: assert result["synced"] == 1 assert result["cleared"] == 1 - mock_upsert.assert_called_once_with( - key="8B3387E9C5CDEA6AC9E5EDBAA115CD72", - name="#general", - is_hashtag=True, - on_radio=False, - ) + + # Verify channel is in real DB + channel = await ChannelRepository.get_by_key("8B3387E9C5CDEA6AC9E5EDBAA115CD72") + assert channel is not None + assert channel.name == "#general" + assert channel.is_hashtag is True + assert channel.on_radio is False @pytest.mark.asyncio async def test_skips_empty_channel_name(self): @@ -853,13 +699,7 @@ class TestSyncAndOffloadChannels: side_effect=[empty_name_result] + [other_result] * 39 ) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -867,7 +707,6 @@ class TestSyncAndOffloadChannels: assert result["synced"] == 0 assert result["cleared"] == 0 - mock_upsert.assert_not_called() @pytest.mark.asyncio async def test_skips_channel_with_zero_key(self): @@ -889,23 +728,16 @@ class TestSyncAndOffloadChannels: side_effect=[zero_key_result] + [other_result] * 39 ) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc result = await sync_and_offload_channels() assert result["synced"] == 0 - mock_upsert.assert_not_called() @pytest.mark.asyncio - async def test_non_hashtag_channel_detected(self): + async def test_non_hashtag_channel_detected(self, test_db): """Channel without '#' prefix has is_hashtag=False.""" from app.radio_sync import sync_and_offload_channels @@ -926,23 +758,18 @@ class TestSyncAndOffloadChannels: clear_result.type = EventType.OK mock_mc.commands.set_channel = AsyncMock(return_value=clear_result) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc await sync_and_offload_channels() - mock_upsert.assert_called_once() - assert mock_upsert.call_args.kwargs["is_hashtag"] is False + channel = await ChannelRepository.get_by_key("8B3387E9C5CDEA6AC9E5EDBAA115CD72") + assert channel is not None + assert channel.is_hashtag is False @pytest.mark.asyncio - async def test_clears_channel_with_empty_name_and_zero_key(self): + async def test_clears_channel_with_empty_name_and_zero_key(self, test_db): """Cleared channels are set with empty name and 16 zero bytes.""" from app.radio_sync import sync_and_offload_channels @@ -963,13 +790,7 @@ class TestSyncAndOffloadChannels: clear_result.type = EventType.OK mock_mc.commands.set_channel = AsyncMock(return_value=clear_result) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -982,7 +803,7 @@ class TestSyncAndOffloadChannels: ) @pytest.mark.asyncio - async def test_handles_clear_failure_gracefully(self): + async def test_handles_clear_failure_gracefully(self, test_db): """Failed set_channel logs warning but continues processing.""" from app.radio_sync import sync_and_offload_channels @@ -1011,13 +832,7 @@ class TestSyncAndOffloadChannels: mock_mc.commands.set_channel = AsyncMock(side_effect=[fail_result, ok_result]) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -1037,13 +852,7 @@ class TestSyncAndOffloadChannels: mock_mc = MagicMock() mock_mc.commands.get_channel = AsyncMock(return_value=empty_result) - with ( - patch("app.radio_sync.radio_manager") as mock_rm, - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ), - ): + with patch("app.radio_sync.radio_manager") as mock_rm: mock_rm.is_connected = True mock_rm.meshcore = mock_mc @@ -1060,104 +869,68 @@ class TestEnsureDefaultChannels: PUBLIC_KEY = "8B3387E9C5CDEA6AC9E5EDBAA115CD72" @pytest.mark.asyncio - async def test_creates_public_channel_when_missing(self): + async def test_creates_public_channel_when_missing(self, test_db): """Public channel is created when it does not exist.""" from app.radio_sync import ensure_default_channels - with ( - patch( - "app.radio_sync.ChannelRepository.get_by_key", - new_callable=AsyncMock, - return_value=None, - ) as mock_get, - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): - await ensure_default_channels() + await ensure_default_channels() - mock_get.assert_called_once_with(self.PUBLIC_KEY) - mock_upsert.assert_called_once_with( + channel = await ChannelRepository.get_by_key(self.PUBLIC_KEY) + assert channel is not None + assert channel.name == "Public" + assert channel.is_hashtag is False + assert channel.on_radio is False + + @pytest.mark.asyncio + async def test_fixes_public_channel_with_wrong_name(self, test_db): + """Public channel name is corrected when it exists with wrong name.""" + from app.radio_sync import ensure_default_channels + + # Pre-insert with wrong name + await ChannelRepository.upsert( + key=self.PUBLIC_KEY, + name="public", # Wrong case + is_hashtag=False, + on_radio=True, + ) + + await ensure_default_channels() + + channel = await ChannelRepository.get_by_key(self.PUBLIC_KEY) + assert channel.name == "Public" + assert channel.on_radio is True # Preserves existing on_radio state + + @pytest.mark.asyncio + async def test_no_op_when_public_channel_exists_correctly(self, test_db): + """No upsert when Public channel already exists with correct name.""" + from app.radio_sync import ensure_default_channels + + await ChannelRepository.upsert( key=self.PUBLIC_KEY, name="Public", is_hashtag=False, on_radio=False, ) - @pytest.mark.asyncio - async def test_fixes_public_channel_with_wrong_name(self): - """Public channel name is corrected when it exists with wrong name.""" - from app.radio_sync import ensure_default_channels + await ensure_default_channels() - existing = MagicMock() - existing.name = "public" # Wrong case - existing.on_radio = True - - with ( - patch( - "app.radio_sync.ChannelRepository.get_by_key", - new_callable=AsyncMock, - return_value=existing, - ), - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): - await ensure_default_channels() - - mock_upsert.assert_called_once_with( - key=self.PUBLIC_KEY, - name="Public", - is_hashtag=False, - on_radio=True, # Preserves existing on_radio state - ) + # Still exists and unchanged + channel = await ChannelRepository.get_by_key(self.PUBLIC_KEY) + assert channel.name == "Public" @pytest.mark.asyncio - async def test_no_op_when_public_channel_exists_correctly(self): - """No upsert when Public channel already exists with correct name.""" - from app.radio_sync import ensure_default_channels - - existing = MagicMock() - existing.name = "Public" - existing.on_radio = False - - with ( - patch( - "app.radio_sync.ChannelRepository.get_by_key", - new_callable=AsyncMock, - return_value=existing, - ), - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): - await ensure_default_channels() - - mock_upsert.assert_not_called() - - @pytest.mark.asyncio - async def test_preserves_on_radio_state_when_fixing_name(self): + async def test_preserves_on_radio_state_when_fixing_name(self, test_db): """existing.on_radio is passed through when fixing the channel name.""" from app.radio_sync import ensure_default_channels - existing = MagicMock() - existing.name = "Pub" - existing.on_radio = True + await ChannelRepository.upsert( + key=self.PUBLIC_KEY, + name="Pub", + is_hashtag=False, + on_radio=True, + ) - with ( - patch( - "app.radio_sync.ChannelRepository.get_by_key", - new_callable=AsyncMock, - return_value=existing, - ), - patch( - "app.radio_sync.ChannelRepository.upsert", - new_callable=AsyncMock, - ) as mock_upsert, - ): - await ensure_default_channels() + await ensure_default_channels() - assert mock_upsert.call_args.kwargs["on_radio"] is True + channel = await ChannelRepository.get_by_key(self.PUBLIC_KEY) + assert channel.on_radio is True diff --git a/tests/test_repeater_routes.py b/tests/test_repeater_routes.py index 90701f4..371c3bc 100644 --- a/tests/test_repeater_routes.py +++ b/tests/test_repeater_routes.py @@ -6,12 +6,32 @@ import pytest from fastapi import HTTPException from meshcore import EventType -from app.models import CommandRequest, Contact, TelemetryRequest +from app.database import Database +from app.models import CommandRequest, TelemetryRequest +from app.repository import ContactRepository from app.routers.contacts import request_telemetry, request_trace, send_repeater_command KEY_A = "aa" * 32 +@pytest.fixture +async def test_db(): + """Create an in-memory test database with schema + migrations.""" + import app.repository as repo_module + + db = Database(":memory:") + await db.connect() + + original_db = repo_module.db + repo_module.db = db + + try: + yield db + finally: + repo_module.db = original_db + await db.disconnect() + + def _radio_result(event_type=EventType.OK, payload=None): result = MagicMock() result.type = event_type @@ -19,8 +39,24 @@ def _radio_result(event_type=EventType.OK, payload=None): return result -def _make_contact(public_key: str, contact_type: int, name: str = "Node") -> Contact: - return Contact(public_key=public_key, name=name, type=contact_type) +async def _insert_contact(public_key: str, name: str = "Node", contact_type: int = 0): + """Insert a contact into the test database.""" + await ContactRepository.upsert( + { + "public_key": public_key, + "name": name, + "type": contact_type, + "flags": 0, + "last_path": None, + "last_path_len": -1, + "last_advert": None, + "lat": None, + "lon": None, + "last_seen": None, + "on_radio": False, + "last_contacted": None, + } + ) def _mock_mc(): @@ -41,33 +77,20 @@ def _mock_mc(): class TestTelemetryRoute: @pytest.mark.asyncio - async def test_returns_404_when_contact_missing(self): + async def test_returns_404_when_contact_missing(self, test_db): mc = _mock_mc() - with ( - patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=None, - ), - ): + with patch("app.routers.contacts.require_connected", return_value=mc): with pytest.raises(HTTPException) as exc: await request_telemetry(KEY_A, TelemetryRequest(password="pw")) assert exc.value.status_code == 404 @pytest.mark.asyncio - async def test_returns_400_for_non_repeater_contact(self): + async def test_returns_400_for_non_repeater_contact(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=1, name="Client") - with ( - patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - ): + await _insert_contact(KEY_A, name="Client", contact_type=1) + + with patch("app.routers.contacts.require_connected", return_value=mc): with pytest.raises(HTTPException) as exc: await request_telemetry(KEY_A, TelemetryRequest(password="pw")) @@ -75,18 +98,13 @@ class TestTelemetryRoute: assert "not a repeater" in exc.value.detail.lower() @pytest.mark.asyncio - async def test_status_retry_timeout_returns_504(self): + async def test_status_retry_timeout_returns_504(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=2, name="Repeater") + await _insert_contact(KEY_A, name="Repeater", contact_type=2) mc.commands.req_status_sync = AsyncMock(side_effect=[None, None, None]) with ( patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), patch( "app.routers.contacts.prepare_repeater_connection", new_callable=AsyncMock, @@ -100,9 +118,9 @@ class TestTelemetryRoute: mock_prepare.assert_awaited_once() @pytest.mark.asyncio - async def test_clock_timeout_uses_fallback_message_and_restores_auto_fetch(self): + async def test_clock_timeout_uses_fallback_message_and_restores_auto_fetch(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=2, name="Repeater") + await _insert_contact(KEY_A, name="Repeater", contact_type=2) mc.commands.req_status_sync = AsyncMock( return_value={ "pubkey_pre": "aaaaaaaaaaaa", @@ -119,16 +137,6 @@ class TestTelemetryRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - patch( - "app.routers.contacts.ContactRepository.get_by_key_prefix", - new_callable=AsyncMock, - return_value=None, - ), patch( "app.routers.contacts.prepare_repeater_connection", new_callable=AsyncMock, @@ -147,21 +155,14 @@ class TestTelemetryRoute: class TestRepeaterCommandRoute: @pytest.mark.asyncio - async def test_send_cmd_error_raises_and_restores_auto_fetch(self): + async def test_send_cmd_error_raises_and_restores_auto_fetch(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=2, name="Repeater") + await _insert_contact(KEY_A, name="Repeater", contact_type=2) mc.commands.send_cmd = AsyncMock( return_value=_radio_result(EventType.ERROR, {"err": "bad"}) ) - with ( - patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - ): + with patch("app.routers.contacts.require_connected", return_value=mc): with pytest.raises(HTTPException) as exc: await send_repeater_command(KEY_A, CommandRequest(command="ver")) @@ -169,20 +170,13 @@ class TestRepeaterCommandRoute: mc.start_auto_message_fetching.assert_awaited_once() @pytest.mark.asyncio - async def test_timeout_returns_no_response_message(self): + async def test_timeout_returns_no_response_message(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=2, name="Repeater") + await _insert_contact(KEY_A, name="Repeater", contact_type=2) mc.commands.send_cmd = AsyncMock(return_value=_radio_result(EventType.OK)) mc.wait_for_event = AsyncMock(return_value=None) - with ( - patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - ): + with patch("app.routers.contacts.require_connected", return_value=mc): response = await send_repeater_command(KEY_A, CommandRequest(command="ver")) assert response.command == "ver" @@ -190,9 +184,9 @@ class TestRepeaterCommandRoute: mc.start_auto_message_fetching.assert_awaited_once() @pytest.mark.asyncio - async def test_success_returns_command_response_text_and_timestamp(self): + async def test_success_returns_command_response_text_and_timestamp(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=2, name="Repeater") + await _insert_contact(KEY_A, name="Repeater", contact_type=2) mc.commands.send_cmd = AsyncMock(return_value=_radio_result(EventType.OK)) mc.wait_for_event = AsyncMock(return_value=MagicMock()) mc.commands.get_msg = AsyncMock( @@ -202,14 +196,7 @@ class TestRepeaterCommandRoute: ) ) - with ( - patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), - ): + with patch("app.routers.contacts.require_connected", return_value=mc): response = await send_repeater_command(KEY_A, CommandRequest(command="ver")) assert response.command == "ver" @@ -219,20 +206,15 @@ class TestRepeaterCommandRoute: class TestTraceRoute: @pytest.mark.asyncio - async def test_send_trace_error_returns_500(self): + async def test_send_trace_error_returns_500(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=1, name="Client") + await _insert_contact(KEY_A, name="Client", contact_type=1) mc.commands.send_trace = AsyncMock( return_value=_radio_result(EventType.ERROR, {"err": "x"}) ) with ( patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), patch("app.routers.contacts.random.randint", return_value=1234), ): with pytest.raises(HTTPException) as exc: @@ -241,19 +223,14 @@ class TestTraceRoute: assert exc.value.status_code == 500 @pytest.mark.asyncio - async def test_wait_timeout_returns_504(self): + async def test_wait_timeout_returns_504(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=1, name="Client") + await _insert_contact(KEY_A, name="Client", contact_type=1) mc.commands.send_trace = AsyncMock(return_value=_radio_result(EventType.OK)) mc.wait_for_event = AsyncMock(return_value=None) with ( patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), patch("app.routers.contacts.random.randint", return_value=1234), ): with pytest.raises(HTTPException) as exc: @@ -262,9 +239,9 @@ class TestTraceRoute: assert exc.value.status_code == 504 @pytest.mark.asyncio - async def test_success_returns_remote_and_local_snr(self): + async def test_success_returns_remote_and_local_snr(self, test_db): mc = _mock_mc() - contact = _make_contact(KEY_A, contact_type=1, name="Client") + await _insert_contact(KEY_A, name="Client", contact_type=1) mc.commands.send_trace = AsyncMock(return_value=_radio_result(EventType.OK)) mc.wait_for_event = AsyncMock( return_value=MagicMock(payload={"path": [{"snr": 5.5}, {"snr": 3.2}], "path_len": 2}) @@ -272,11 +249,6 @@ class TestTraceRoute: with ( patch("app.routers.contacts.require_connected", return_value=mc), - patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", - new_callable=AsyncMock, - return_value=contact, - ), patch("app.routers.contacts.random.randint", return_value=1234), ): response = await request_trace(KEY_A) diff --git a/tests/test_send_messages.py b/tests/test_send_messages.py index 396105a..647d058 100644 --- a/tests/test_send_messages.py +++ b/tests/test_send_messages.py @@ -7,17 +7,37 @@ import pytest from fastapi import HTTPException from meshcore import EventType +from app.database import Database from app.models import ( - AppSettings, - Channel, - Contact, SendChannelMessageRequest, SendDirectMessageRequest, ) -from app.repository import AmbiguousPublicKeyPrefixError +from app.repository import ( + AppSettingsRepository, + ChannelRepository, + ContactRepository, +) from app.routers.messages import send_channel_message, send_direct_message +@pytest.fixture +async def test_db(): + """Create an in-memory test database with schema + migrations.""" + import app.repository as repo_module + + db = Database(":memory:") + await db.connect() + + original_db = repo_module.db + repo_module.db = db + + try: + yield db + finally: + repo_module.db = original_db + await db.disconnect() + + def _make_radio_result(payload=None): """Create a mock radio command result.""" result = MagicMock() @@ -39,28 +59,41 @@ def _make_mc(name="TestNode"): return mc +async def _insert_contact(public_key, name="Alice"): + """Insert a contact into the test database.""" + await ContactRepository.upsert( + { + "public_key": public_key, + "name": name, + "type": 0, + "flags": 0, + "last_path": None, + "last_path_len": -1, + "last_advert": None, + "lat": None, + "lon": None, + "last_seen": None, + "on_radio": False, + "last_contacted": None, + } + ) + + class TestOutgoingDMBotTrigger: """Test that sending a DM triggers bots with is_outgoing=True.""" @pytest.mark.asyncio - async def test_send_dm_triggers_bot(self): + async def test_send_dm_triggers_bot(self, test_db): """Sending a DM creates a background task to run bots.""" mc = _make_mc() - db_contact = Contact(public_key="ab" * 32, name="Alice") + pub_key = "ab" * 32 + await _insert_contact(pub_key, "Alice") with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ContactRepository.get_by_key_or_prefix", - new=AsyncMock(return_value=db_contact), - ), - patch("app.repository.ContactRepository.update_last_contacted", new=AsyncMock()), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=1)), patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot, ): - request = SendDirectMessageRequest( - destination=db_contact.public_key, text="!lasttime Alice" - ) + request = SendDirectMessageRequest(destination=pub_key, text="!lasttime Alice") await send_direct_message(request) # Let the background task run @@ -71,14 +104,15 @@ class TestOutgoingDMBotTrigger: assert call_kwargs["message_text"] == "!lasttime Alice" assert call_kwargs["is_dm"] is True assert call_kwargs["is_outgoing"] is True - assert call_kwargs["sender_key"] == db_contact.public_key + assert call_kwargs["sender_key"] == pub_key assert call_kwargs["channel_key"] is None @pytest.mark.asyncio - async def test_send_dm_bot_does_not_block_response(self): + async def test_send_dm_bot_does_not_block_response(self, test_db): """Bot trigger runs in background and doesn't delay the message response.""" mc = _make_mc() - db_contact = Contact(public_key="ab" * 32, name="Alice") + pub_key = "ab" * 32 + await _insert_contact(pub_key, "Alice") # Bot that would take a long time async def _slow(**kw): @@ -88,37 +122,26 @@ class TestOutgoingDMBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ContactRepository.get_by_key_or_prefix", - new=AsyncMock(return_value=db_contact), - ), - patch("app.repository.ContactRepository.update_last_contacted", new=AsyncMock()), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=1)), patch("app.bot.run_bot_for_message", new=slow_bot), ): - request = SendDirectMessageRequest(destination=db_contact.public_key, text="Hello") + request = SendDirectMessageRequest(destination=pub_key, text="Hello") # This should return immediately, not wait 10 seconds message = await send_direct_message(request) assert message.text == "Hello" assert message.outgoing is True @pytest.mark.asyncio - async def test_send_dm_passes_no_sender_name(self): + async def test_send_dm_passes_no_sender_name(self, test_db): """Outgoing DMs pass sender_name=None (we are the sender).""" mc = _make_mc() - db_contact = Contact(public_key="cd" * 32, name="Bob") + pub_key = "cd" * 32 + await _insert_contact(pub_key, "Bob") with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ContactRepository.get_by_key_or_prefix", - new=AsyncMock(return_value=db_contact), - ), - patch("app.repository.ContactRepository.update_last_contacted", new=AsyncMock()), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=1)), patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot, ): - request = SendDirectMessageRequest(destination=db_contact.public_key, text="test") + request = SendDirectMessageRequest(destination=pub_key, text="test") await send_direct_message(request) await asyncio.sleep(0) @@ -126,25 +149,15 @@ class TestOutgoingDMBotTrigger: assert call_kwargs["sender_name"] is None @pytest.mark.asyncio - async def test_send_dm_ambiguous_prefix_returns_409(self): + async def test_send_dm_ambiguous_prefix_returns_409(self, test_db): """Ambiguous destination prefix should fail instead of selecting a random contact.""" mc = _make_mc() - with ( - patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ContactRepository.get_by_key_or_prefix", - new=AsyncMock( - side_effect=AmbiguousPublicKeyPrefixError( - "abc123", - [ - "abc1230000000000000000000000000000000000000000000000000000000000", - "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - ], - ) - ), - ), - ): + # Insert two contacts that share the prefix "abc123" + await _insert_contact("abc123" + "00" * 29, "ContactA") + await _insert_contact("abc123" + "ff" * 29, "ContactB") + + with patch("app.routers.messages.require_connected", return_value=mc): with pytest.raises(HTTPException) as exc_info: await send_direct_message( SendDirectMessageRequest(destination="abc123", text="Hello") @@ -158,29 +171,18 @@ class TestOutgoingChannelBotTrigger: """Test that sending a channel message triggers bots with is_outgoing=True.""" @pytest.mark.asyncio - async def test_send_channel_msg_triggers_bot(self): + async def test_send_channel_msg_triggers_bot(self, test_db): """Sending a channel message creates a background task to run bots.""" mc = _make_mc(name="MyNode") - db_channel = Channel(key="aa" * 16, name="#general") + chan_key = "aa" * 16 + await ChannelRepository.upsert(key=chan_key, name="#general") with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ChannelRepository.get_by_key", - new=AsyncMock(return_value=db_channel), - ), - patch( - "app.repository.AppSettingsRepository.get", - new=AsyncMock(return_value=AppSettings()), - ), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=1)), - patch("app.repository.MessageRepository.get_ack_count", new=AsyncMock(return_value=0)), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot, ): - request = SendChannelMessageRequest( - channel_key=db_channel.key, text="!lasttime5 someone" - ) + request = SendChannelMessageRequest(channel_key=chan_key, text="!lasttime5 someone") await send_channel_message(request) await asyncio.sleep(0) @@ -189,33 +191,24 @@ class TestOutgoingChannelBotTrigger: assert call_kwargs["message_text"] == "!lasttime5 someone" assert call_kwargs["is_dm"] is False assert call_kwargs["is_outgoing"] is True - assert call_kwargs["channel_key"] == db_channel.key.upper() + assert call_kwargs["channel_key"] == chan_key.upper() assert call_kwargs["channel_name"] == "#general" assert call_kwargs["sender_name"] == "MyNode" assert call_kwargs["sender_key"] is None @pytest.mark.asyncio - async def test_send_channel_msg_no_radio_name(self): + async def test_send_channel_msg_no_radio_name(self, test_db): """When radio has no name, sender_name is None.""" mc = _make_mc(name="") - db_channel = Channel(key="bb" * 16, name="#test") + chan_key = "bb" * 16 + await ChannelRepository.upsert(key=chan_key, name="#test") with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ChannelRepository.get_by_key", - new=AsyncMock(return_value=db_channel), - ), - patch( - "app.repository.AppSettingsRepository.get", - new=AsyncMock(return_value=AppSettings()), - ), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=1)), - patch("app.repository.MessageRepository.get_ack_count", new=AsyncMock(return_value=0)), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()) as mock_bot, ): - request = SendChannelMessageRequest(channel_key=db_channel.key, text="hello") + request = SendChannelMessageRequest(channel_key=chan_key, text="hello") await send_channel_message(request) await asyncio.sleep(0) @@ -223,10 +216,11 @@ class TestOutgoingChannelBotTrigger: assert call_kwargs["sender_name"] is None @pytest.mark.asyncio - async def test_send_channel_msg_bot_does_not_block_response(self): + async def test_send_channel_msg_bot_does_not_block_response(self, test_db): """Bot trigger runs in background and doesn't delay the message response.""" mc = _make_mc(name="MyNode") - db_channel = Channel(key="cc" * 16, name="#slow") + chan_key = "cc" * 16 + await ChannelRepository.upsert(key=chan_key, name="#slow") async def _slow(**kw): await asyncio.sleep(10) @@ -235,44 +229,28 @@ class TestOutgoingChannelBotTrigger: with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ChannelRepository.get_by_key", - new=AsyncMock(return_value=db_channel), - ), - patch( - "app.repository.AppSettingsRepository.get", - new=AsyncMock(return_value=AppSettings()), - ), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=1)), - patch("app.repository.MessageRepository.get_ack_count", new=AsyncMock(return_value=0)), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=slow_bot), ): - request = SendChannelMessageRequest(channel_key=db_channel.key, text="test") + request = SendChannelMessageRequest(channel_key=chan_key, text="test") message = await send_channel_message(request) assert message.outgoing is True @pytest.mark.asyncio - async def test_send_channel_msg_double_send_when_experimental_enabled(self): + async def test_send_channel_msg_double_send_when_experimental_enabled(self, test_db): """Experimental setting triggers an immediate byte-perfect duplicate send.""" mc = _make_mc(name="MyNode") - db_channel = Channel(key="dd" * 16, name="#double") - settings = AppSettings(experimental_channel_double_send=True) + chan_key = "dd" * 16 + await ChannelRepository.upsert(key=chan_key, name="#double") + await AppSettingsRepository.update(experimental_channel_double_send=True) with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ChannelRepository.get_by_key", - new=AsyncMock(return_value=db_channel), - ), - patch("app.repository.AppSettingsRepository.get", new=AsyncMock(return_value=settings)), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=1)), - patch("app.repository.MessageRepository.get_ack_count", new=AsyncMock(return_value=0)), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()), patch("app.routers.messages.asyncio.sleep", new=AsyncMock()) as mock_sleep, ): - request = SendChannelMessageRequest(channel_key=db_channel.key, text="same bytes") + request = SendChannelMessageRequest(channel_key=chan_key, text="same bytes") await send_channel_message(request) assert mc.commands.send_chan_msg.await_count == 2 @@ -284,54 +262,37 @@ class TestOutgoingChannelBotTrigger: assert first_call["timestamp"] == second_call["timestamp"] @pytest.mark.asyncio - async def test_send_channel_msg_single_send_when_experimental_disabled(self): + async def test_send_channel_msg_single_send_when_experimental_disabled(self, test_db): """Default setting keeps channel sends to a single radio command.""" mc = _make_mc(name="MyNode") - db_channel = Channel(key="ee" * 16, name="#single") + chan_key = "ee" * 16 + await ChannelRepository.upsert(key=chan_key, name="#single") with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ChannelRepository.get_by_key", - new=AsyncMock(return_value=db_channel), - ), - patch( - "app.repository.AppSettingsRepository.get", - new=AsyncMock(return_value=AppSettings()), - ), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=1)), - patch("app.repository.MessageRepository.get_ack_count", new=AsyncMock(return_value=0)), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()), ): - request = SendChannelMessageRequest(channel_key=db_channel.key, text="single send") + request = SendChannelMessageRequest(channel_key=chan_key, text="single send") await send_channel_message(request) assert mc.commands.send_chan_msg.await_count == 1 @pytest.mark.asyncio - async def test_send_channel_msg_response_includes_current_ack_count(self): + async def test_send_channel_msg_response_includes_current_ack_count(self, test_db): """Send response reflects latest DB ack count at response time.""" mc = _make_mc(name="MyNode") - db_channel = Channel(key="ff" * 16, name="#acked") + chan_key = "ff" * 16 + await ChannelRepository.upsert(key=chan_key, name="#acked") with ( patch("app.routers.messages.require_connected", return_value=mc), - patch( - "app.repository.ChannelRepository.get_by_key", - new=AsyncMock(return_value=db_channel), - ), - patch( - "app.repository.AppSettingsRepository.get", - new=AsyncMock(return_value=AppSettings()), - ), - patch("app.repository.MessageRepository.create", new=AsyncMock(return_value=123)), - patch("app.repository.MessageRepository.get_ack_count", new=AsyncMock(return_value=2)), patch("app.decoder.calculate_channel_hash", return_value="abcd"), patch("app.bot.run_bot_for_message", new=AsyncMock()), ): - request = SendChannelMessageRequest(channel_key=db_channel.key, text="acked now") + request = SendChannelMessageRequest(channel_key=chan_key, text="acked now") message = await send_channel_message(request) - assert message.id == 123 - assert message.acked == 2 + # Fresh message has acked=0 + assert message.id is not None + assert message.acked == 0 diff --git a/tests/test_settings_router.py b/tests/test_settings_router.py index b93f9ed..a8d24ed 100644 --- a/tests/test_settings_router.py +++ b/tests/test_settings_router.py @@ -1,11 +1,11 @@ """Tests for settings router endpoints and validation behavior.""" -from unittest.mock import AsyncMock, patch - import pytest from fastapi import HTTPException -from app.models import AppSettings, BotConfig, Favorite +from app.database import Database +from app.models import AppSettings, BotConfig +from app.repository import AppSettingsRepository from app.routers.settings import ( AppSettingsUpdate, FavoriteRequest, @@ -16,71 +16,46 @@ from app.routers.settings import ( ) -def _settings( - *, - favorites: list[Favorite] | None = None, - migrated: bool = False, - max_radio_contacts: int = 200, - experimental_channel_double_send: bool = False, -) -> AppSettings: - return AppSettings( - max_radio_contacts=max_radio_contacts, - experimental_channel_double_send=experimental_channel_double_send, - favorites=favorites or [], - auto_decrypt_dm_on_advert=False, - sidebar_sort_order="recent", - last_message_times={}, - preferences_migrated=migrated, - advert_interval=0, - last_advert_time=0, - bots=[], - ) +@pytest.fixture +async def test_db(): + """Create an in-memory test database with schema + migrations.""" + import app.repository as repo_module + + db = Database(":memory:") + await db.connect() + + original_db = repo_module.db + repo_module.db = db + + try: + yield db + finally: + repo_module.db = original_db + await db.disconnect() class TestUpdateSettings: @pytest.mark.asyncio - async def test_forwards_only_provided_fields(self): - updated = _settings(max_radio_contacts=321) - with patch( - "app.routers.settings.AppSettingsRepository.update", - new_callable=AsyncMock, - return_value=updated, - ) as mock_update: - result = await update_settings( - AppSettingsUpdate( - max_radio_contacts=321, - advert_interval=3600, - experimental_channel_double_send=True, - ) + async def test_forwards_only_provided_fields(self, test_db): + result = await update_settings( + AppSettingsUpdate( + max_radio_contacts=321, + advert_interval=3600, + experimental_channel_double_send=True, ) + ) assert result.max_radio_contacts == 321 - assert mock_update.call_count == 1 - assert mock_update.call_args.kwargs == { - "max_radio_contacts": 321, - "advert_interval": 3600, - "experimental_channel_double_send": True, - } + assert result.advert_interval == 3600 + assert result.experimental_channel_double_send is True @pytest.mark.asyncio - async def test_empty_patch_returns_current_settings(self): - current = _settings() - with ( - patch( - "app.routers.settings.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=current, - ) as mock_get, - patch( - "app.routers.settings.AppSettingsRepository.update", - new_callable=AsyncMock, - ) as mock_update, - ): - result = await update_settings(AppSettingsUpdate()) + async def test_empty_patch_returns_current_settings(self, test_db): + result = await update_settings(AppSettingsUpdate()) - assert result == current - mock_get.assert_awaited_once() - mock_update.assert_not_awaited() + # Should return default settings without error + assert isinstance(result, AppSettings) + assert result.max_radio_contacts == 200 # default @pytest.mark.asyncio async def test_invalid_bot_syntax_returns_400(self): @@ -100,102 +75,61 @@ class TestUpdateSettings: class TestToggleFavorite: @pytest.mark.asyncio - async def test_adds_when_not_favorited(self): - initial = _settings(favorites=[]) - updated = _settings(favorites=[Favorite(type="contact", id="aa" * 32)]) + async def test_adds_when_not_favorited(self, test_db): request = FavoriteRequest(type="contact", id="aa" * 32) + result = await toggle_favorite(request) - with ( - patch( - "app.routers.settings.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=initial, - ), - patch( - "app.routers.settings.AppSettingsRepository.add_favorite", - new_callable=AsyncMock, - return_value=updated, - ) as mock_add, - patch( - "app.routers.settings.AppSettingsRepository.remove_favorite", - new_callable=AsyncMock, - ) as mock_remove, - ): - result = await toggle_favorite(request) - - assert result.favorites == updated.favorites - mock_add.assert_awaited_once_with("contact", "aa" * 32) - mock_remove.assert_not_awaited() + assert len(result.favorites) == 1 + assert result.favorites[0].type == "contact" + assert result.favorites[0].id == "aa" * 32 @pytest.mark.asyncio - async def test_removes_when_already_favorited(self): - initial = _settings(favorites=[Favorite(type="channel", id="ABCD")]) - updated = _settings(favorites=[]) - request = FavoriteRequest(type="channel", id="ABCD") + async def test_removes_when_already_favorited(self, test_db): + # Pre-add a favorite + await AppSettingsRepository.add_favorite("channel", "ABCD") - with ( - patch( - "app.routers.settings.AppSettingsRepository.get", - new_callable=AsyncMock, - return_value=initial, - ), - patch( - "app.routers.settings.AppSettingsRepository.remove_favorite", - new_callable=AsyncMock, - return_value=updated, - ) as mock_remove, - patch( - "app.routers.settings.AppSettingsRepository.add_favorite", - new_callable=AsyncMock, - ) as mock_add, - ): - result = await toggle_favorite(request) + request = FavoriteRequest(type="channel", id="ABCD") + result = await toggle_favorite(request) assert result.favorites == [] - mock_remove.assert_awaited_once_with("channel", "ABCD") - mock_add.assert_not_awaited() class TestMigratePreferences: @pytest.mark.asyncio - async def test_maps_frontend_payload_and_returns_migrated_true(self): + async def test_maps_frontend_payload_and_returns_migrated_true(self, test_db): request = MigratePreferencesRequest( favorites=[FavoriteRequest(type="contact", id="aa" * 32)], sort_order="alpha", last_message_times={"contact-aaaaaaaaaaaa": 123}, ) - settings = _settings(favorites=[Favorite(type="contact", id="aa" * 32)], migrated=True) - with patch( - "app.routers.settings.AppSettingsRepository.migrate_preferences_from_frontend", - new_callable=AsyncMock, - return_value=(settings, True), - ) as mock_migrate: - response = await migrate_preferences(request) + response = await migrate_preferences(request) assert response.migrated is True - assert response.settings == settings - assert mock_migrate.call_args.kwargs == { - "favorites": [{"type": "contact", "id": "aa" * 32}], - "sort_order": "alpha", - "last_message_times": {"contact-aaaaaaaaaaaa": 123}, - } + assert response.settings.preferences_migrated is True + assert response.settings.sidebar_sort_order == "alpha" + assert len(response.settings.favorites) == 1 + assert response.settings.favorites[0].type == "contact" + assert response.settings.favorites[0].id == "aa" * 32 + assert response.settings.last_message_times == {"contact-aaaaaaaaaaaa": 123} @pytest.mark.asyncio - async def test_returns_migrated_false_when_already_done(self): - request = MigratePreferencesRequest( + async def test_returns_migrated_false_when_already_done(self, test_db): + # First migration + first_request = MigratePreferencesRequest( + favorites=[FavoriteRequest(type="contact", id="bb" * 32)], + sort_order="recent", + last_message_times={}, + ) + await migrate_preferences(first_request) + + # Second attempt should be no-op + second_request = MigratePreferencesRequest( favorites=[], sort_order="recent", last_message_times={}, ) - settings = _settings(migrated=True) - - with patch( - "app.routers.settings.AppSettingsRepository.migrate_preferences_from_frontend", - new_callable=AsyncMock, - return_value=(settings, False), - ): - response = await migrate_preferences(request) + response = await migrate_preferences(second_request) assert response.migrated is False assert response.settings.preferences_migrated is True