Miscellaneous cleanup — message filtering, pagination tests, and related fixes.

This commit is contained in:
Jack Kingsman
2026-02-24 21:03:24 -08:00
parent 684724913f
commit 1c2fb148bc
6 changed files with 210 additions and 11 deletions

View File

@@ -38,6 +38,13 @@ async def lifespan(app: FastAPI):
await db.connect()
logger.info("Database connected")
# Ensure default channels exist in the database even before the radio
# connects. Without this, a fresh or disconnected instance would return
# zero channels from GET /channels until the first successful radio sync.
from app.radio_sync import ensure_default_channels
await ensure_default_channels()
try:
await radio_manager.connect()
logger.info("Connected to radio")

View File

@@ -724,7 +724,6 @@ class MessageRepository:
"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
MAX(m.received_at) as last_message_time,
SUM(CASE
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
ELSE 0
@@ -749,7 +748,6 @@ class MessageRepository:
"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
MAX(m.received_at) as last_message_time,
SUM(CASE
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
ELSE 0

View File

@@ -8,7 +8,6 @@ from pydantic import BaseModel, Field
from app.dependencies import require_connected
from app.models import Channel
from app.radio import radio_manager
from app.radio_sync import ensure_default_channels
from app.repository import ChannelRepository
logger = logging.getLogger(__name__)
@@ -26,8 +25,6 @@ class CreateChannelRequest(BaseModel):
@router.get("", response_model=list[Channel])
async def list_channels() -> list[Channel]:
"""List all channels from the database."""
# Ensure Public channel always exists (self-healing)
await ensure_default_channels()
return await ChannelRepository.get_all()

View File

@@ -1,5 +1,6 @@
import logging
from hashlib import sha256
from sqlite3 import OperationalError
import aiosqlite
from fastapi import APIRouter, BackgroundTasks
@@ -288,7 +289,9 @@ async def run_maintenance(request: MaintenanceRequest) -> MaintenanceResult:
await vacuum_conn.executescript("VACUUM;")
vacuumed = True
logger.info("Database vacuumed")
except Exception as e:
except OperationalError as e:
logger.warning("VACUUM skipped (database busy): %s", e)
except Exception as e:
logger.error("VACUUM failed unexpectedly: %s", e)
return MaintenanceResult(packets_deleted=deleted, vacuumed=vacuumed)

View File

@@ -463,6 +463,38 @@ class TestContactMessageCLIFiltering:
_, payload = mock_broadcast.call_args.args
assert payload["conversation_key"] == "abc123"
@pytest.mark.asyncio
async def test_repeater_message_skipped_not_stored(self, test_db):
    """Messages from repeater contacts (type=2) are dropped, not stored."""
    from app.event_handlers import on_contact_message

    repeater_key = "dd" * 32
    contact_row = {
        "public_key": repeater_key,
        "name": "MyRepeater",
        "type": 2,  # CONTACT_TYPE_REPEATER
        "flags": 0,
    }
    await ContactRepository.upsert(contact_row)

    class MockEvent:
        payload = {
            "public_key": repeater_key,
            "text": "Some repeater noise",
            "txt_type": 0,
            "sender_timestamp": 1700000000,
        }

    with patch("app.event_handlers.broadcast_event") as mock_broadcast:
        await on_contact_message(MockEvent())

    # Repeater traffic must produce no broadcast and no stored rows.
    mock_broadcast.assert_not_called()
    stored = await MessageRepository.get_all()
    assert len(stored) == 0
class TestEventHandlerRegistration:
"""Test event handler registration and cleanup."""

View File

@@ -24,16 +24,18 @@ async def test_db():
await db.disconnect()
CHAN_KEY = "ABC123DEF456ABC123DEF456ABC12345"
DM_KEY = "aa" * 32
@pytest.mark.asyncio
async def test_cursor_pagination_avoids_overlap(test_db):
key = "ABC123DEF456ABC123DEF456ABC12345"
ids = []
for received_at, text in [(200, "m1"), (200, "m2"), (150, "m3"), (100, "m4")]:
msg_id = await MessageRepository.create(
msg_type="CHAN",
text=text,
conversation_key=key,
conversation_key=CHAN_KEY,
sender_timestamp=received_at,
received_at=received_at,
)
@@ -42,7 +44,7 @@ async def test_cursor_pagination_avoids_overlap(test_db):
page1 = await MessageRepository.get_all(
msg_type="CHAN",
conversation_key=key,
conversation_key=CHAN_KEY,
limit=2,
offset=0,
)
@@ -51,7 +53,7 @@ async def test_cursor_pagination_avoids_overlap(test_db):
oldest = page1[-1]
page2 = await MessageRepository.get_all(
msg_type="CHAN",
conversation_key=key,
conversation_key=CHAN_KEY,
limit=2,
offset=0,
before=oldest.received_at,
@@ -62,3 +64,163 @@ async def test_cursor_pagination_avoids_overlap(test_db):
ids_page1 = {m.id for m in page1}
ids_page2 = {m.id for m in page2}
assert ids_page1.isdisjoint(ids_page2)
@pytest.mark.asyncio
async def test_empty_page_when_no_messages(test_db):
    """Pagination on a conversation with no messages returns empty list."""
    page = await MessageRepository.get_all(
        msg_type="CHAN",
        conversation_key=CHAN_KEY,
        limit=50,
    )
    assert page == []
@pytest.mark.asyncio
async def test_empty_page_after_oldest_message(test_db):
    """Requesting a page before the oldest message returns empty list."""
    only_id = await MessageRepository.create(
        msg_type="CHAN",
        text="only message",
        conversation_key=CHAN_KEY,
        sender_timestamp=100,
        received_at=100,
    )
    assert only_id is not None

    # Cursor sits on the sole message, so the page strictly before it is empty.
    page = await MessageRepository.get_all(
        msg_type="CHAN",
        conversation_key=CHAN_KEY,
        limit=50,
        before=100,
        before_id=only_id,
    )
    assert page == []
@pytest.mark.asyncio
async def test_timestamp_tie_uses_id_tiebreaker(test_db):
    """Multiple messages with the same received_at are ordered by id DESC."""
    created = []
    for label in ("first", "second", "third"):
        new_id = await MessageRepository.create(
            msg_type="CHAN",
            text=label,
            conversation_key=CHAN_KEY,
            sender_timestamp=500,
            received_at=500,
        )
        assert new_id is not None
        created.append(new_id)

    # All three share one timestamp; a page of 2 must hold the two highest ids.
    first_page = await MessageRepository.get_all(
        msg_type="CHAN",
        conversation_key=CHAN_KEY,
        limit=2,
    )
    assert len(first_page) == 2
    assert first_page[0].id == created[2]  # "third" (highest id)
    assert first_page[1].id == created[1]  # "second"

    # A cursor built from page one's tail yields the remaining message.
    tail = first_page[-1]
    second_page = await MessageRepository.get_all(
        msg_type="CHAN",
        conversation_key=CHAN_KEY,
        limit=2,
        before=tail.received_at,
        before_id=tail.id,
    )
    assert len(second_page) == 1
    assert second_page[0].id == created[0]  # "first" (lowest id)
@pytest.mark.asyncio
async def test_conversation_key_isolates_messages(test_db):
    """Messages from different conversations don't leak into each other's pages."""
    other_key = "FF" * 16
    # One message in each of two distinct conversations, same timestamp.
    for key, body in ((CHAN_KEY, "chan1"), (other_key, "chan2")):
        await MessageRepository.create(
            msg_type="CHAN",
            text=body,
            conversation_key=key,
            sender_timestamp=100,
            received_at=100,
        )

    page = await MessageRepository.get_all(
        msg_type="CHAN",
        conversation_key=CHAN_KEY,
        limit=50,
    )
    assert len(page) == 1
    assert page[0].text == "chan1"
@pytest.mark.asyncio
async def test_limit_respected(test_db):
    """Returned page never exceeds the requested limit."""
    for step in range(10):
        await MessageRepository.create(
            msg_type="CHAN",
            text=f"msg{step}",
            conversation_key=CHAN_KEY,
            sender_timestamp=100 + step,
            received_at=100 + step,
        )

    page = await MessageRepository.get_all(
        msg_type="CHAN",
        conversation_key=CHAN_KEY,
        limit=3,
    )
    assert len(page) == 3
@pytest.mark.asyncio
async def test_full_walk_collects_all_messages(test_db):
    """Walking through all pages collects every message exactly once."""
    total = 7
    for i in range(total):
        await MessageRepository.create(
            msg_type="CHAN",
            text=f"msg{i}",
            conversation_key=CHAN_KEY,
            sender_timestamp=100 + i,
            received_at=100 + i,
        )

    seen: list[int] = []
    cursor = None  # (received_at, id) of the oldest message fetched so far
    for _ in range(total):  # safety bound against runaway paging
        params: dict = {
            "msg_type": "CHAN",
            "conversation_key": CHAN_KEY,
            "limit": 3,
        }
        if cursor is None:
            params["offset"] = 0
        else:
            params["before"], params["before_id"] = cursor
        page = await MessageRepository.get_all(**params)
        if not page:
            break
        seen.extend(m.id for m in page)
        cursor = (page[-1].received_at, page[-1].id)

    assert len(seen) == total
    assert len(set(seen)) == total  # no duplicates