diff --git a/app/main.py b/app/main.py index 59914d8..57de687 100644 --- a/app/main.py +++ b/app/main.py @@ -38,6 +38,13 @@ async def lifespan(app: FastAPI): await db.connect() logger.info("Database connected") + # Ensure default channels exist in the database even before the radio + # connects. Without this, a fresh or disconnected instance would return + # zero channels from GET /channels until the first successful radio sync. + from app.radio_sync import ensure_default_channels + + await ensure_default_channels() + try: await radio_manager.connect() logger.info("Connected to radio") diff --git a/app/repository.py b/app/repository.py index 46722fa..834514d 100644 --- a/app/repository.py +++ b/app/repository.py @@ -724,7 +724,6 @@ class MessageRepository: """ SELECT m.conversation_key, COUNT(*) as unread_count, - MAX(m.received_at) as last_message_time, SUM(CASE WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 ELSE 0 @@ -749,7 +748,6 @@ class MessageRepository: """ SELECT m.conversation_key, COUNT(*) as unread_count, - MAX(m.received_at) as last_message_time, SUM(CASE WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 ELSE 0 diff --git a/app/routers/channels.py b/app/routers/channels.py index 6b95554..a54799c 100644 --- a/app/routers/channels.py +++ b/app/routers/channels.py @@ -8,7 +8,6 @@ from pydantic import BaseModel, Field from app.dependencies import require_connected from app.models import Channel from app.radio import radio_manager -from app.radio_sync import ensure_default_channels from app.repository import ChannelRepository logger = logging.getLogger(__name__) @@ -26,8 +25,6 @@ class CreateChannelRequest(BaseModel): @router.get("", response_model=list[Channel]) async def list_channels() -> list[Channel]: """List all channels from the database.""" - # Ensure Public channel always exists (self-healing) - await ensure_default_channels() return await ChannelRepository.get_all() diff --git a/app/routers/packets.py b/app/routers/packets.py index 317302b..734e878 100644 --- a/app/routers/packets.py +++ b/app/routers/packets.py @@ -1,5 +1,6 @@ import logging from hashlib import sha256 +from sqlite3 import OperationalError import aiosqlite from fastapi import APIRouter, BackgroundTasks @@ -288,7 +289,9 @@ async def run_maintenance(request: MaintenanceRequest) -> MaintenanceResult: await vacuum_conn.executescript("VACUUM;") vacuumed = True logger.info("Database vacuumed") - except Exception as e: + except OperationalError as e: logger.warning("VACUUM skipped (database busy): %s", e) + except Exception as e: + logger.error("VACUUM failed unexpectedly: %s", e) return MaintenanceResult(packets_deleted=deleted, vacuumed=vacuumed) diff --git a/tests/test_event_handlers.py b/tests/test_event_handlers.py index aee87d6..d732e3a 100644 --- a/tests/test_event_handlers.py +++ b/tests/test_event_handlers.py @@ -463,6 +463,38 @@ class TestContactMessageCLIFiltering: _, payload = mock_broadcast.call_args.args assert payload["conversation_key"] == "abc123" + @pytest.mark.asyncio + async def test_repeater_message_skipped_not_stored(self, test_db): + """Messages from repeater contacts (type=2) are dropped, not stored.""" + from app.event_handlers import on_contact_message + + repeater_key = "dd" * 32 + await ContactRepository.upsert( + { + "public_key": repeater_key, + "name": "MyRepeater", + "type": 2, # CONTACT_TYPE_REPEATER + "flags": 0, + } + ) + + with patch("app.event_handlers.broadcast_event") as mock_broadcast: + + class MockEvent: + payload = { + "public_key": repeater_key, + "text": "Some repeater noise", + "txt_type": 0, + "sender_timestamp": 1700000000, + } + + await on_contact_message(MockEvent()) + + mock_broadcast.assert_not_called() + + messages = await MessageRepository.get_all() + assert len(messages) == 0 + class TestEventHandlerRegistration: """Test event handler registration and cleanup.""" diff --git a/tests/test_message_pagination.py b/tests/test_message_pagination.py index 7743ffe..1256786 100644 --- a/tests/test_message_pagination.py +++ b/tests/test_message_pagination.py @@ -24,16 +24,18 @@ async def test_db(): await db.disconnect() +CHAN_KEY = "ABC123DEF456ABC123DEF456ABC12345" +DM_KEY = "aa" * 32 + + @pytest.mark.asyncio async def test_cursor_pagination_avoids_overlap(test_db): - key = "ABC123DEF456ABC123DEF456ABC12345" - ids = [] for received_at, text in [(200, "m1"), (200, "m2"), (150, "m3"), (100, "m4")]: msg_id = await MessageRepository.create( msg_type="CHAN", text=text, - conversation_key=key, + conversation_key=CHAN_KEY, sender_timestamp=received_at, received_at=received_at, ) @@ -42,7 +44,7 @@ async def test_cursor_pagination_avoids_overlap(test_db): page1 = await MessageRepository.get_all( msg_type="CHAN", - conversation_key=key, + conversation_key=CHAN_KEY, limit=2, offset=0, ) @@ -51,7 +53,7 @@ async def test_cursor_pagination_avoids_overlap(test_db): oldest = page1[-1] page2 = await MessageRepository.get_all( msg_type="CHAN", - conversation_key=key, + conversation_key=CHAN_KEY, limit=2, offset=0, before=oldest.received_at, @@ -62,3 +64,163 @@ async def test_cursor_pagination_avoids_overlap(test_db): ids_page1 = {m.id for m in page1} ids_page2 = {m.id for m in page2} assert ids_page1.isdisjoint(ids_page2) + + +@pytest.mark.asyncio +async def test_empty_page_when_no_messages(test_db): + """Pagination on a conversation with no messages returns empty list.""" + result = await MessageRepository.get_all( + msg_type="CHAN", + conversation_key=CHAN_KEY, + limit=50, + ) + assert result == [] + + +@pytest.mark.asyncio +async def test_empty_page_after_oldest_message(test_db): + """Requesting a page before the oldest message returns empty list.""" + msg_id = await MessageRepository.create( + msg_type="CHAN", + text="only message", + conversation_key=CHAN_KEY, + sender_timestamp=100, + received_at=100, + ) + assert msg_id is not None + + # Use before cursor pointing at the only message — should get nothing + result = await MessageRepository.get_all( + msg_type="CHAN", + conversation_key=CHAN_KEY, + limit=50, + before=100, + before_id=msg_id, + ) + assert result == [] + + +@pytest.mark.asyncio +async def test_timestamp_tie_uses_id_tiebreaker(test_db): + """Multiple messages with the same received_at are ordered by id DESC.""" + ids = [] + for text in ["first", "second", "third"]: + msg_id = await MessageRepository.create( + msg_type="CHAN", + text=text, + conversation_key=CHAN_KEY, + sender_timestamp=500, + received_at=500, + ) + assert msg_id is not None + ids.append(msg_id) + + # All three at same timestamp; page of 2 should get the two highest IDs + page1 = await MessageRepository.get_all( + msg_type="CHAN", + conversation_key=CHAN_KEY, + limit=2, + ) + assert len(page1) == 2 + assert page1[0].id == ids[2] # "third" (highest id) + assert page1[1].id == ids[1] # "second" + + # Cursor from page1's last entry should get the remaining one + page2 = await MessageRepository.get_all( + msg_type="CHAN", + conversation_key=CHAN_KEY, + limit=2, + before=page1[-1].received_at, + before_id=page1[-1].id, + ) + assert len(page2) == 1 + assert page2[0].id == ids[0] # "first" (lowest id) + + +@pytest.mark.asyncio +async def test_conversation_key_isolates_messages(test_db): + """Messages from different conversations don't leak into each other's pages.""" + other_key = "FF" * 16 + + await MessageRepository.create( + msg_type="CHAN", + text="chan1", + conversation_key=CHAN_KEY, + sender_timestamp=100, + received_at=100, + ) + await MessageRepository.create( + msg_type="CHAN", + text="chan2", + conversation_key=other_key, + sender_timestamp=100, + received_at=100, + ) + + result = await MessageRepository.get_all( + msg_type="CHAN", + conversation_key=CHAN_KEY, + limit=50, + ) + assert len(result) == 1 + assert result[0].text == "chan1" + + +@pytest.mark.asyncio +async def test_limit_respected(test_db): + """Returned page never exceeds the requested limit.""" + for i in range(10): + await MessageRepository.create( + msg_type="CHAN", + text=f"msg{i}", + conversation_key=CHAN_KEY, + sender_timestamp=100 + i, + received_at=100 + i, + ) + + result = await MessageRepository.get_all( + msg_type="CHAN", + conversation_key=CHAN_KEY, + limit=3, + ) + assert len(result) == 3 + + +@pytest.mark.asyncio +async def test_full_walk_collects_all_messages(test_db): + """Walking through all pages collects every message exactly once.""" + total = 7 + for i in range(total): + await MessageRepository.create( + msg_type="CHAN", + text=f"msg{i}", + conversation_key=CHAN_KEY, + sender_timestamp=100 + i, + received_at=100 + i, + ) + + collected_ids: list[int] = [] + before = None + before_id = None + + for _ in range(total): # safety bound + kwargs: dict = { + "msg_type": "CHAN", + "conversation_key": CHAN_KEY, + "limit": 3, + } + if before is not None: + kwargs["before"] = before + kwargs["before_id"] = before_id + else: + kwargs["offset"] = 0 + + page = await MessageRepository.get_all(**kwargs) + if not page: + break + collected_ids.extend(m.id for m in page) + before = page[-1].received_at + before_id = page[-1].id + + assert len(collected_ids) == total + assert len(set(collected_ids)) == total # no duplicates