diff --git a/app/database.py b/app/database.py
index 6ec9d22..197664e 100644
--- a/app/database.py
+++ b/app/database.py
@@ -1,4 +1,7 @@
+import asyncio
 import logging
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
 from pathlib import Path
 
 import aiosqlite
@@ -165,9 +168,74 @@ CREATE INDEX IF NOT EXISTS idx_repeater_telemetry_pk_ts
 
 
 class Database:
+    """Single-connection aiosqlite wrapper with coroutine-level serialization.
+
+    Why the lock: aiosqlite runs one ``sqlite3.Connection`` on a background
+    worker thread and serializes statement execution there. But SQLite's
+    ``COMMIT`` fails with ``OperationalError: cannot commit transaction -
+    SQL statements in progress`` whenever *any* cursor on the connection has
+    a live prepared statement (a ``SELECT`` that returned ``SQLITE_ROW`` but
+    hasn't been fully consumed or closed). Under concurrent coroutines, one
+    task's in-flight ``fetchone()`` can still be in ``SQLITE_ROW`` state when
+    another task's ``commit()`` runs on the worker — triggering the error.
+
+    Fix: all DB work goes through ``tx()`` (writes) or ``readonly()`` (reads),
+    both of which acquire ``self._lock``. The lock is non-reentrant (asyncio
+    default) by design — nested ``tx()`` calls are a bug. Repository methods
+    that compose multiple operations factor the raw SQL into private helpers
+    that take a ``conn`` and don't lock; the public method acquires the lock
+    once and calls those helpers.
+
+    Why reads are also locked: a read sitting in ``SQLITE_ROW`` state is
+    precisely the live statement that breaks a concurrent writer's commit.
+    Single-connection aiosqlite cannot safely overlap reads and writes. If
+    we ever split reader/writer connections in the future, ``readonly()``
+    becomes the seam to point at the reader pool.
+    """
+
     def __init__(self, db_path: str):
         self.db_path = db_path
         self._connection: aiosqlite.Connection | None = None
+        self._lock = asyncio.Lock()
+
+    @asynccontextmanager
+    async def tx(self) -> AsyncIterator[aiosqlite.Connection]:
+        """Acquire the connection for a write transaction.
+
+        Commits on clean exit, rolls back on exception. Callers MUST close
+        every cursor opened inside the block (use ``async with conn.execute(...)
+        as cursor:``) so no prepared statement is alive when commit runs.
+
+        The lock serializes concurrent writers AND ensures no reader's cursor
+        is alive during the commit. Nested calls will deadlock — factor shared
+        SQL into helpers that accept ``conn`` and do not re-enter ``tx()``.
+        """
+        async with self._lock:
+            if self._connection is None:
+                raise RuntimeError("Database not connected")
+            conn = self._connection
+            try:
+                yield conn
+            except BaseException:
+                await conn.rollback()
+                raise
+            else:
+                await conn.commit()
+
+    @asynccontextmanager
+    async def readonly(self) -> AsyncIterator[aiosqlite.Connection]:
+        """Acquire the connection for a read. No commit, no rollback.
+
+        Locked for the same reason writes are: on a single connection, an
+        active read statement blocks a concurrent writer's commit. Callers
+        MUST fully consume or close cursors before the block exits (use
+        ``async with conn.execute(...) as cursor:`` + ``fetchall`` /
+        ``fetchone``; avoid holding a cursor across ``await`` on other IO).
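+
+        Illustrative sketch of the call pattern (``SELECT 1`` stands in
+        for any real query)::
+
+            async with db.readonly() as conn:
+                async with conn.execute("SELECT 1") as cursor:
+                    row = await cursor.fetchone()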
+ """ + async with self._lock: + if self._connection is None: + raise RuntimeError("Database not connected") + yield self._connection async def connect(self) -> None: logger.info("Connecting to database at %s", self.db_path) diff --git a/app/repository/channels.py b/app/repository/channels.py index 47f232c..efe5a90 100644 --- a/app/repository/channels.py +++ b/app/repository/channels.py @@ -8,31 +8,33 @@ class ChannelRepository: @staticmethod async def upsert(key: str, name: str, is_hashtag: bool = False, on_radio: bool = False) -> None: """Upsert a channel. Key is 32-char hex string.""" - await db.conn.execute( - """ - INSERT INTO channels (key, name, is_hashtag, on_radio, flood_scope_override) - VALUES (?, ?, ?, ?, NULL) - ON CONFLICT(key) DO UPDATE SET - name = excluded.name, - is_hashtag = excluded.is_hashtag, - on_radio = excluded.on_radio - """, - (key.upper(), name, is_hashtag, on_radio), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + """ + INSERT INTO channels (key, name, is_hashtag, on_radio, flood_scope_override) + VALUES (?, ?, ?, ?, NULL) + ON CONFLICT(key) DO UPDATE SET + name = excluded.name, + is_hashtag = excluded.is_hashtag, + on_radio = excluded.on_radio + """, + (key.upper(), name, is_hashtag, on_radio), + ): + pass @staticmethod async def get_by_key(key: str) -> Channel | None: """Get a channel by its key (32-char hex string).""" - cursor = await db.conn.execute( - """ - SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite - FROM channels - WHERE key = ? - """, - (key.upper(),), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite + FROM channels + WHERE key = ? + """, + (key.upper(),), + ) as cursor: + row = await cursor.fetchone() if row: return Channel( key=row["key"], @@ -48,14 +50,15 @@ class ChannelRepository: @staticmethod async def get_all() -> list[Channel]: - cursor = await db.conn.execute( - """ - SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite - FROM channels - ORDER BY name - """ - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite + FROM channels + ORDER BY name + """ + ) as cursor: + rows = await cursor.fetchall() return [ Channel( key=row["key"], @@ -73,21 +76,23 @@ class ChannelRepository: @staticmethod async def set_favorite(key: str, value: bool) -> bool: """Set or clear the favorite flag for a channel. Returns True if row was found.""" - cursor = await db.conn.execute( - "UPDATE channels SET favorite = ? WHERE key = ?", - (1 if value else 0, key.upper()), - ) - await db.conn.commit() - return cursor.rowcount > 0 + async with db.tx() as conn: + async with conn.execute( + "UPDATE channels SET favorite = ? 
WHERE key = ?", + (1 if value else 0, key.upper()), + ) as cursor: + rowcount = cursor.rowcount + return rowcount > 0 @staticmethod async def delete(key: str) -> None: """Delete a channel by key.""" - await db.conn.execute( - "DELETE FROM channels WHERE key = ?", - (key.upper(),), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + "DELETE FROM channels WHERE key = ?", + (key.upper(),), + ): + pass @staticmethod async def update_last_read_at(key: str, timestamp: int | None = None) -> bool: @@ -96,35 +101,39 @@ class ChannelRepository: Returns True if a row was updated, False if channel not found. """ ts = timestamp if timestamp is not None else int(time.time()) - cursor = await db.conn.execute( - "UPDATE channels SET last_read_at = ? WHERE key = ?", - (ts, key.upper()), - ) - await db.conn.commit() - return cursor.rowcount > 0 + async with db.tx() as conn: + async with conn.execute( + "UPDATE channels SET last_read_at = ? WHERE key = ?", + (ts, key.upper()), + ) as cursor: + rowcount = cursor.rowcount + return rowcount > 0 @staticmethod async def update_flood_scope_override(key: str, flood_scope_override: str | None) -> bool: """Set or clear a channel's flood-scope override.""" - cursor = await db.conn.execute( - "UPDATE channels SET flood_scope_override = ? WHERE key = ?", - (flood_scope_override, key.upper()), - ) - await db.conn.commit() - return cursor.rowcount > 0 + async with db.tx() as conn: + async with conn.execute( + "UPDATE channels SET flood_scope_override = ? WHERE key = ?", + (flood_scope_override, key.upper()), + ) as cursor: + rowcount = cursor.rowcount + return rowcount > 0 @staticmethod async def update_path_hash_mode_override(key: str, path_hash_mode_override: int | None) -> bool: """Set or clear a channel's path hash mode override.""" - cursor = await db.conn.execute( - "UPDATE channels SET path_hash_mode_override = ? WHERE key = ?", - (path_hash_mode_override, key.upper()), - ) - await db.conn.commit() - return cursor.rowcount > 0 + async with db.tx() as conn: + async with conn.execute( + "UPDATE channels SET path_hash_mode_override = ? WHERE key = ?", + (path_hash_mode_override, key.upper()), + ) as cursor: + rowcount = cursor.rowcount + return rowcount > 0 @staticmethod async def mark_all_read(timestamp: int) -> None: """Mark all channels as read at the given timestamp.""" - await db.conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,)) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,)): + pass diff --git a/app/repository/contacts.py b/app/repository/contacts.py index ff542fa..9a5a45e 100644 --- a/app/repository/contacts.py +++ b/app/repository/contacts.py @@ -61,71 +61,72 @@ class ContactRepository: ) ) - await db.conn.execute( - """ - INSERT INTO contacts (public_key, name, type, flags, direct_path, direct_path_len, - direct_path_hash_mode, direct_path_updated_at, - route_override_path, route_override_len, - route_override_hash_mode, - last_advert, lat, lon, last_seen, - on_radio, last_contacted, first_seen) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ON CONFLICT(public_key) DO UPDATE SET - name = COALESCE(excluded.name, contacts.name), - type = CASE WHEN excluded.type = 0 THEN contacts.type ELSE excluded.type END, - flags = excluded.flags, - direct_path = COALESCE(excluded.direct_path, contacts.direct_path), - direct_path_len = COALESCE(excluded.direct_path_len, contacts.direct_path_len), - direct_path_hash_mode = COALESCE( - excluded.direct_path_hash_mode, contacts.direct_path_hash_mode + async with db.tx() as conn: + async with conn.execute( + """ + INSERT INTO contacts (public_key, name, type, flags, direct_path, direct_path_len, + direct_path_hash_mode, direct_path_updated_at, + route_override_path, route_override_len, + route_override_hash_mode, + last_advert, lat, lon, last_seen, + on_radio, last_contacted, first_seen) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(public_key) DO UPDATE SET + name = COALESCE(excluded.name, contacts.name), + type = CASE WHEN excluded.type = 0 THEN contacts.type ELSE excluded.type END, + flags = excluded.flags, + direct_path = COALESCE(excluded.direct_path, contacts.direct_path), + direct_path_len = COALESCE(excluded.direct_path_len, contacts.direct_path_len), + direct_path_hash_mode = COALESCE( + excluded.direct_path_hash_mode, contacts.direct_path_hash_mode + ), + direct_path_updated_at = COALESCE( + excluded.direct_path_updated_at, contacts.direct_path_updated_at + ), + route_override_path = COALESCE( + excluded.route_override_path, contacts.route_override_path + ), + route_override_len = COALESCE( + excluded.route_override_len, contacts.route_override_len + ), + route_override_hash_mode = COALESCE( + excluded.route_override_hash_mode, contacts.route_override_hash_mode + ), + last_advert = COALESCE(excluded.last_advert, contacts.last_advert), + lat = COALESCE(excluded.lat, contacts.lat), + lon = COALESCE(excluded.lon, contacts.lon), + last_seen = CASE + WHEN excluded.last_seen IS NULL THEN contacts.last_seen + WHEN contacts.last_seen IS NULL THEN excluded.last_seen + WHEN excluded.last_seen > contacts.last_seen THEN excluded.last_seen + ELSE contacts.last_seen + END, + on_radio = COALESCE(excluded.on_radio, contacts.on_radio), + last_contacted = COALESCE(excluded.last_contacted, contacts.last_contacted), + first_seen = COALESCE(contacts.first_seen, excluded.first_seen) + """, + ( + contact_row.public_key.lower(), + contact_row.name, + contact_row.type, + contact_row.flags, + direct_path, + direct_path_len, + direct_path_hash_mode, + contact_row.direct_path_updated_at, + route_override_path, + route_override_len, + route_override_hash_mode, + contact_row.last_advert, + contact_row.lat, + contact_row.lon, + contact_row.last_seen, + contact_row.on_radio, + contact_row.last_contacted, + contact_row.first_seen, ), - direct_path_updated_at = COALESCE( - excluded.direct_path_updated_at, contacts.direct_path_updated_at - ), - route_override_path = COALESCE( - excluded.route_override_path, contacts.route_override_path - ), - route_override_len = COALESCE( - excluded.route_override_len, contacts.route_override_len - ), - route_override_hash_mode = COALESCE( - excluded.route_override_hash_mode, contacts.route_override_hash_mode - ), - last_advert = COALESCE(excluded.last_advert, contacts.last_advert), - lat = COALESCE(excluded.lat, contacts.lat), - lon = COALESCE(excluded.lon, contacts.lon), - last_seen = CASE - WHEN excluded.last_seen IS NULL THEN contacts.last_seen - WHEN contacts.last_seen IS NULL THEN excluded.last_seen - WHEN excluded.last_seen > contacts.last_seen 
THEN excluded.last_seen - ELSE contacts.last_seen - END, - on_radio = COALESCE(excluded.on_radio, contacts.on_radio), - last_contacted = COALESCE(excluded.last_contacted, contacts.last_contacted), - first_seen = COALESCE(contacts.first_seen, excluded.first_seen) - """, - ( - contact_row.public_key.lower(), - contact_row.name, - contact_row.type, - contact_row.flags, - direct_path, - direct_path_len, - direct_path_hash_mode, - contact_row.direct_path_updated_at, - route_override_path, - route_override_len, - route_override_hash_mode, - contact_row.last_advert, - contact_row.lat, - contact_row.lon, - contact_row.last_seen, - contact_row.on_radio, - contact_row.last_contacted, - contact_row.first_seen, - ), - ) - await db.conn.commit() + ): + pass @staticmethod def _row_to_contact(row) -> Contact: @@ -183,10 +184,11 @@ class ContactRepository: @staticmethod async def get_by_key(public_key: str) -> Contact | None: - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE public_key = ?", (public_key.lower(),) - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM contacts WHERE public_key = ?", (public_key.lower(),) + ) as cursor: + row = await cursor.fetchone() return ContactRepository._row_to_contact(row) if row else None @staticmethod @@ -200,11 +202,12 @@ class ContactRepository: exact = await ContactRepository.get_by_key(normalized_prefix) if exact: return exact - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2", - (f"{normalized_prefix}%",), - ) - rows = list(await cursor.fetchall()) + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2", + (f"{normalized_prefix}%",), + ) as cursor: + rows = list(await cursor.fetchall()) if len(rows) != 1: return None return ContactRepository._row_to_contact(rows[0]) @@ -212,11 +215,12 @@ class ContactRepository: @staticmethod async def _get_prefix_matches(prefix: str, limit: int = 2) -> list[Contact]: """Get contacts matching a key prefix, up to limit.""" - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT ?", - (f"{prefix.lower()}%", limit), - ) - rows = list(await cursor.fetchall()) + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM contacts WHERE public_key LIKE ? 
ORDER BY public_key LIMIT ?", + (f"{prefix.lower()}%", limit), + ) as cursor: + rows = list(await cursor.fetchall()) return [ContactRepository._row_to_contact(row) for row in rows] @staticmethod @@ -242,8 +246,9 @@ class ContactRepository: @staticmethod async def get_by_name(name: str) -> list[Contact]: """Get all contacts with the given exact name.""" - cursor = await db.conn.execute("SELECT * FROM contacts WHERE name = ?", (name,)) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute("SELECT * FROM contacts WHERE name = ?", (name,)) as cursor: + rows = await cursor.fetchall() return [ContactRepository._row_to_contact(row) for row in rows] @staticmethod @@ -259,8 +264,9 @@ class ContactRepository: normalized = [p.lower() for p in prefixes] conditions = " OR ".join(["public_key LIKE ?"] * len(normalized)) params = [f"{p}%" for p in normalized] - cursor = await db.conn.execute(f"SELECT * FROM contacts WHERE {conditions}", params) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute(f"SELECT * FROM contacts WHERE {conditions}", params) as cursor: + rows = await cursor.fetchall() # Group by which prefix each row matches prefix_to_rows: dict[str, list] = {p: [] for p in normalized} for row in rows: @@ -277,63 +283,67 @@ class ContactRepository: @staticmethod async def get_all(limit: int = 100, offset: int = 0) -> list[Contact]: - cursor = await db.conn.execute( - "SELECT * FROM contacts ORDER BY COALESCE(name, public_key) LIMIT ? OFFSET ?", - (limit, offset), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM contacts ORDER BY COALESCE(name, public_key) LIMIT ? OFFSET ?", + (limit, offset), + ) as cursor: + rows = await cursor.fetchall() return [ContactRepository._row_to_contact(row) for row in rows] @staticmethod async def get_recently_contacted_non_repeaters(limit: int = 200) -> list[Contact]: """Get recently interacted-with non-repeater contacts.""" - cursor = await db.conn.execute( - """ - SELECT * FROM contacts - WHERE type != 2 AND last_contacted IS NOT NULL AND length(public_key) = 64 - ORDER BY last_contacted DESC - LIMIT ? - """, - (limit,), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT * FROM contacts + WHERE type != 2 AND last_contacted IS NOT NULL AND length(public_key) = 64 + ORDER BY last_contacted DESC + LIMIT ? + """, + (limit,), + ) as cursor: + rows = await cursor.fetchall() return [ContactRepository._row_to_contact(row) for row in rows] @staticmethod async def get_recently_dm_active_non_repeaters(limit: int = 200) -> list[Contact]: """Get non-repeater contacts with the most recent DM activity (sent or received).""" - cursor = await db.conn.execute( - """ - SELECT c.* - FROM contacts c - INNER JOIN ( - SELECT conversation_key, MAX(received_at) AS last_dm - FROM messages - WHERE type = 'PRIV' - GROUP BY conversation_key - ) m ON c.public_key = m.conversation_key - WHERE c.type != 2 AND length(c.public_key) = 64 - ORDER BY m.last_dm DESC - LIMIT ? - """, - (limit,), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT c.* + FROM contacts c + INNER JOIN ( + SELECT conversation_key, MAX(received_at) AS last_dm + FROM messages + WHERE type = 'PRIV' + GROUP BY conversation_key + ) m ON c.public_key = m.conversation_key + WHERE c.type != 2 AND length(c.public_key) = 64 + ORDER BY m.last_dm DESC + LIMIT ? 
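+                -- inner SELECT: newest PRIV timestamp per conversation;
+                -- the join then orders contacts by that recency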
+ """, + (limit,), + ) as cursor: + rows = await cursor.fetchall() return [ContactRepository._row_to_contact(row) for row in rows] @staticmethod async def get_recently_advertised_non_repeaters(limit: int = 200) -> list[Contact]: """Get recently advert-heard non-repeater contacts.""" - cursor = await db.conn.execute( - """ - SELECT * FROM contacts - WHERE type != 2 AND last_advert IS NOT NULL AND length(public_key) = 64 - ORDER BY last_advert DESC - LIMIT ? - """, - (limit,), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT * FROM contacts + WHERE type != 2 AND last_advert IS NOT NULL AND length(public_key) = 64 + ORDER BY last_advert DESC + LIMIT ? + """, + (limit,), + ) as cursor: + rows = await cursor.fetchall() return [ContactRepository._row_to_contact(row) for row in rows] @staticmethod @@ -359,28 +369,29 @@ class ContactRepository: path_hash_mode, ) ts = updated_at if updated_at is not None else int(time.time()) - await db.conn.execute( - """UPDATE contacts SET direct_path = ?, direct_path_len = ?, - direct_path_hash_mode = COALESCE(?, direct_path_hash_mode), - direct_path_updated_at = ?, - last_seen = CASE - WHEN last_seen IS NULL THEN ? - WHEN ? > last_seen THEN ? - ELSE last_seen - END - WHERE public_key = ?""", - ( - normalized_path, - normalized_path_len, - normalized_hash_mode, - ts, - ts, - ts, - ts, - public_key.lower(), - ), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + """UPDATE contacts SET direct_path = ?, direct_path_len = ?, + direct_path_hash_mode = COALESCE(?, direct_path_hash_mode), + direct_path_updated_at = ?, + last_seen = CASE + WHEN last_seen IS NULL THEN ? + WHEN ? > last_seen THEN ? + ELSE last_seen + END + WHERE public_key = ?""", + ( + normalized_path, + normalized_path_len, + normalized_hash_mode, + ts, + ts, + ts, + ts, + public_key.lower(), + ), + ): + pass @staticmethod async def set_routing_override( @@ -394,65 +405,71 @@ class ContactRepository: path_len, path_hash_mode, ) - await db.conn.execute( - """ - UPDATE contacts - SET route_override_path = ?, route_override_len = ?, route_override_hash_mode = ? - WHERE public_key = ? - """, - ( - normalized_path, - normalized_len, - normalized_hash_mode, - public_key.lower(), - ), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + """ + UPDATE contacts + SET route_override_path = ?, route_override_len = ?, route_override_hash_mode = ? + WHERE public_key = ? + """, + ( + normalized_path, + normalized_len, + normalized_hash_mode, + public_key.lower(), + ), + ): + pass @staticmethod async def clear_routing_override(public_key: str) -> None: - await db.conn.execute( - """ - UPDATE contacts - SET route_override_path = NULL, - route_override_len = NULL, - route_override_hash_mode = NULL - WHERE public_key = ? - """, - (public_key.lower(),), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + """ + UPDATE contacts + SET route_override_path = NULL, + route_override_len = NULL, + route_override_hash_mode = NULL + WHERE public_key = ? + """, + (public_key.lower(),), + ): + pass @staticmethod async def clear_on_radio_except(keep_keys: list[str]) -> None: """Set on_radio=False for all contacts NOT in keep_keys.""" - if not keep_keys: - await db.conn.execute("UPDATE contacts SET on_radio = 0 WHERE on_radio = 1") - else: - placeholders = ",".join("?" 
* len(keep_keys)) - await db.conn.execute( - f"UPDATE contacts SET on_radio = 0 WHERE on_radio = 1 AND public_key NOT IN ({placeholders})", - keep_keys, - ) - await db.conn.commit() + async with db.tx() as conn: + if not keep_keys: + async with conn.execute("UPDATE contacts SET on_radio = 0 WHERE on_radio = 1"): + pass + else: + placeholders = ",".join("?" * len(keep_keys)) + async with conn.execute( + f"UPDATE contacts SET on_radio = 0 WHERE on_radio = 1 AND public_key NOT IN ({placeholders})", + keep_keys, + ): + pass @staticmethod async def get_favorites() -> list[Contact]: """Return all contacts marked as favorite.""" - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE favorite = 1 AND LENGTH(public_key) = 64" - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM contacts WHERE favorite = 1 AND LENGTH(public_key) = 64" + ) as cursor: + rows = await cursor.fetchall() return [ContactRepository._row_to_contact(row) for row in rows] @staticmethod async def set_favorite(public_key: str, value: bool) -> None: """Set or clear the favorite flag for a contact.""" - await db.conn.execute( - "UPDATE contacts SET favorite = ? WHERE public_key = ?", - (1 if value else 0, public_key.lower()), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + "UPDATE contacts SET favorite = ? WHERE public_key = ?", + (1 if value else 0, public_key.lower()), + ): + pass @staticmethod async def delete(public_key: str) -> None: @@ -460,8 +477,9 @@ class ContactRepository: # contact_name_history and contact_advert_paths cascade via FK. # Messages are intentionally preserved so history re-surfaces # if the contact is re-added later. - await db.conn.execute("DELETE FROM contacts WHERE public_key = ?", (normalized,)) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute("DELETE FROM contacts WHERE public_key = ?", (normalized,)): + pass @staticmethod async def update_last_contacted(public_key: str, timestamp: int | None = None) -> None: @@ -477,11 +495,12 @@ class ContactRepository: ``last_seen`` via :meth:`touch_last_seen` on incoming DMs only. """ ts = timestamp if timestamp is not None else int(time.time()) - await db.conn.execute( - "UPDATE contacts SET last_contacted = ? WHERE public_key = ?", - (ts, public_key.lower()), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + "UPDATE contacts SET last_contacted = ? WHERE public_key = ?", + (ts, public_key.lower()), + ): + pass @staticmethod async def touch_last_seen(public_key: str, timestamp: int) -> None: @@ -491,19 +510,20 @@ class ContactRepository: exist. Use this from packet-ingest paths that have attributed a packet to a specific contact pubkey (advert, incoming DM, decrypted PATH, etc.). """ - await db.conn.execute( - """ - UPDATE contacts - SET last_seen = CASE - WHEN last_seen IS NULL THEN ? - WHEN ? > last_seen THEN ? - ELSE last_seen - END - WHERE public_key = ? - """, - (timestamp, timestamp, timestamp, public_key.lower()), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + """ + UPDATE contacts + SET last_seen = CASE + WHEN last_seen IS NULL THEN ? + WHEN ? > last_seen THEN ? + ELSE last_seen + END + WHERE public_key = ? 
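+                -- the CASE keeps last_seen monotonic: it only ever moves forward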
+ """, + (timestamp, timestamp, timestamp, public_key.lower()), + ): + pass @staticmethod async def update_last_read_at(public_key: str, timestamp: int | None = None) -> bool: @@ -512,22 +532,25 @@ class ContactRepository: Returns True if a row was updated, False if contact not found. """ ts = timestamp if timestamp is not None else int(time.time()) - cursor = await db.conn.execute( - "UPDATE contacts SET last_read_at = ? WHERE public_key = ?", - (ts, public_key.lower()), - ) - await db.conn.commit() - return cursor.rowcount > 0 + async with db.tx() as conn: + async with conn.execute( + "UPDATE contacts SET last_read_at = ? WHERE public_key = ?", + (ts, public_key.lower()), + ) as cursor: + rowcount = cursor.rowcount + return rowcount > 0 @staticmethod async def promote_prefix_placeholders(full_key: str) -> list[str]: """Promote prefix-only placeholder contacts to a resolved full key. Returns the placeholder public keys that were merged into the full key. + All operations for the promotion happen inside one ``db.tx()`` so + partial promotions never leak to readers between steps. """ - async def migrate_child_rows(old_key: str, new_key: str) -> None: - await db.conn.execute( + async def migrate_child_rows(conn, old_key: str, new_key: str) -> None: + async with conn.execute( """ INSERT INTO contact_name_history (public_key, name, first_seen, last_seen) SELECT ?, name, first_seen, last_seen @@ -538,8 +561,9 @@ class ContactRepository: last_seen = MAX(contact_name_history.last_seen, excluded.last_seen) """, (new_key, old_key), - ) - await db.conn.execute( + ): + pass + async with conn.execute( """ INSERT INTO contact_advert_paths (public_key, path_hex, path_len, first_seen, last_seen, heard_count) @@ -552,132 +576,138 @@ class ContactRepository: heard_count = contact_advert_paths.heard_count + excluded.heard_count """, (new_key, old_key), - ) - await db.conn.execute( + ): + pass + async with conn.execute( "DELETE FROM contact_name_history WHERE public_key = ?", (old_key,), - ) - await db.conn.execute( + ): + pass + async with conn.execute( "DELETE FROM contact_advert_paths WHERE public_key = ?", (old_key,), - ) + ): + pass normalized_full_key = full_key.lower() - cursor = await db.conn.execute( - """ - SELECT public_key, last_seen, last_contacted, first_seen, last_read_at - FROM contacts - WHERE length(public_key) < 64 - AND ? LIKE public_key || '%' - ORDER BY length(public_key) DESC, public_key - """, - (normalized_full_key,), - ) - rows = list(await cursor.fetchall()) - if not rows: - return [] - promoted_keys: list[str] = [] - - for row in rows: - old_key = row["public_key"] - if old_key == normalized_full_key: - continue - - match_cursor = await db.conn.execute( + async with db.tx() as conn: + async with conn.execute( """ - SELECT COUNT(*) AS match_count + SELECT public_key, last_seen, last_contacted, first_seen, last_read_at FROM contacts - WHERE length(public_key) = 64 - AND public_key LIKE ? || '%' + WHERE length(public_key) < 64 + AND ? 
LIKE public_key || '%'
+                ORDER BY length(public_key) DESC, public_key
                 """,
-                (old_key,),
-            )
-            match_row = await match_cursor.fetchone()
-            match_count = match_row["match_count"] if match_row is not None else 0
-            if match_count != 1:
-                logger.warning(
-                    "Skipping prefix promotion for %s: %d full-key contacts match (expected 1)",
-                    old_key,
-                    match_count,
-                )
-                continue
+                (normalized_full_key,),
+            ) as cursor:
+                rows = list(await cursor.fetchall())
+            if not rows:
+                return []
+            promoted_keys: list[str] = []
 
-            await migrate_child_rows(old_key, normalized_full_key)
+            for row in rows:
+                old_key = row["public_key"]
+                if old_key == normalized_full_key:
+                    continue
 
-            # Merge timestamp metadata from the old prefix contact into the
-            # full-key contact (which all callers guarantee already exists),
-            # then delete the prefix placeholder.
-            await db.conn.execute(
-                """
-                UPDATE contacts
-                SET last_seen = CASE
-                    WHEN contacts.last_seen IS NULL THEN ?
-                    WHEN ? IS NULL THEN contacts.last_seen
-                    WHEN ? > contacts.last_seen THEN ?
-                    ELSE contacts.last_seen
-                END,
-                last_contacted = CASE
-                    WHEN contacts.last_contacted IS NULL THEN ?
-                    WHEN ? IS NULL THEN contacts.last_contacted
-                    WHEN ? > contacts.last_contacted THEN ?
-                    ELSE contacts.last_contacted
-                END,
-                first_seen = CASE
-                    WHEN contacts.first_seen IS NULL THEN ?
-                    WHEN ? IS NULL THEN contacts.first_seen
-                    WHEN ? < contacts.first_seen THEN ?
-                    ELSE contacts.first_seen
-                END,
-                last_read_at = CASE
-                    WHEN contacts.last_read_at IS NULL THEN ?
-                    WHEN ? IS NULL THEN contacts.last_read_at
-                    WHEN ? > contacts.last_read_at THEN ?
-                    ELSE contacts.last_read_at
-                END
-                WHERE public_key = ?
-                """,
-                (
-                    row["last_seen"],
-                    row["last_seen"],
-                    row["last_seen"],
-                    row["last_seen"],
-                    row["last_contacted"],
-                    row["last_contacted"],
-                    row["last_contacted"],
-                    row["last_contacted"],
-                    row["first_seen"],
-                    row["first_seen"],
-                    row["first_seen"],
-                    row["first_seen"],
-                    row["last_read_at"],
-                    row["last_read_at"],
-                    row["last_read_at"],
-                    row["last_read_at"],
-                    normalized_full_key,
-                ),
-            )
-            await db.conn.execute("DELETE FROM contacts WHERE public_key = ?", (old_key,))
+                async with conn.execute(
+                    """
+                    SELECT COUNT(*) AS match_count
+                    FROM contacts
+                    WHERE length(public_key) = 64
+                      AND public_key LIKE ? || '%'
+                    """,
+                    (old_key,),
+                ) as match_cursor:
+                    match_row = await match_cursor.fetchone()
+                match_count = match_row["match_count"] if match_row is not None else 0
+                if match_count != 1:
+                    logger.warning(
+                        "Skipping prefix promotion for %s: %d full-key contacts match (expected 1)",
+                        old_key,
+                        match_count,
+                    )
+                    continue
 
-            promoted_keys.append(old_key)
+                await migrate_child_rows(conn, old_key, normalized_full_key)
+
+                # Merge timestamp metadata from the old prefix contact into the
+                # full-key contact (which all callers guarantee already exists),
+                # then delete the prefix placeholder.
+                async with conn.execute(
+                    """
+                    UPDATE contacts
+                    SET last_seen = CASE
+                        WHEN contacts.last_seen IS NULL THEN ?
+                        WHEN ? IS NULL THEN contacts.last_seen
+                        WHEN ? > contacts.last_seen THEN ?
+                        ELSE contacts.last_seen
+                    END,
+                    last_contacted = CASE
+                        WHEN contacts.last_contacted IS NULL THEN ?
+                        WHEN ? IS NULL THEN contacts.last_contacted
+                        WHEN ? > contacts.last_contacted THEN ?
+                        ELSE contacts.last_contacted
+                    END,
+                    first_seen = CASE
+                        WHEN contacts.first_seen IS NULL THEN ?
+                        WHEN ? IS NULL THEN contacts.first_seen
+                        WHEN ? < contacts.first_seen THEN ?
+                        ELSE contacts.first_seen
+                    END,
+                    last_read_at = CASE
+                        WHEN contacts.last_read_at IS NULL THEN ?
+                        WHEN ? IS NULL THEN contacts.last_read_at
+                        WHEN ? 
> contacts.last_read_at THEN ? + ELSE contacts.last_read_at + END + WHERE public_key = ? + """, + ( + row["last_seen"], + row["last_seen"], + row["last_seen"], + row["last_seen"], + row["last_contacted"], + row["last_contacted"], + row["last_contacted"], + row["last_contacted"], + row["first_seen"], + row["first_seen"], + row["first_seen"], + row["first_seen"], + row["last_read_at"], + row["last_read_at"], + row["last_read_at"], + row["last_read_at"], + normalized_full_key, + ), + ): + pass + async with conn.execute("DELETE FROM contacts WHERE public_key = ?", (old_key,)): + pass + + promoted_keys.append(old_key) - await db.conn.commit() return promoted_keys @staticmethod async def mark_all_read(timestamp: int) -> None: """Mark all contacts as read at the given timestamp.""" - await db.conn.execute("UPDATE contacts SET last_read_at = ?", (timestamp,)) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute("UPDATE contacts SET last_read_at = ?", (timestamp,)): + pass @staticmethod async def get_by_pubkey_first_byte(hex_byte: str) -> list[Contact]: """Get contacts whose public key starts with the given hex byte (2 chars).""" - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE substr(public_key, 1, 2) = ?", - (hex_byte.lower(),), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM contacts WHERE substr(public_key, 1, 2) = ?", + (hex_byte.lower(),), + ) as cursor: + rows = await cursor.fetchall() return [ContactRepository._row_to_contact(row) for row in rows] @@ -716,71 +746,75 @@ class ContactAdvertPathRepository: normalized_path = path_hex.lower() path_len = hop_count if hop_count is not None else len(normalized_path) // 2 - await db.conn.execute( - """ - INSERT INTO contact_advert_paths - (public_key, path_hex, path_len, first_seen, last_seen, heard_count) - VALUES (?, ?, ?, ?, ?, 1) - ON CONFLICT(public_key, path_hex, path_len) DO UPDATE SET - last_seen = MAX(contact_advert_paths.last_seen, excluded.last_seen), - heard_count = contact_advert_paths.heard_count + 1 - """, - (normalized_key, normalized_path, path_len, timestamp, timestamp), - ) + async with db.tx() as conn: + async with conn.execute( + """ + INSERT INTO contact_advert_paths + (public_key, path_hex, path_len, first_seen, last_seen, heard_count) + VALUES (?, ?, ?, ?, ?, 1) + ON CONFLICT(public_key, path_hex, path_len) DO UPDATE SET + last_seen = MAX(contact_advert_paths.last_seen, excluded.last_seen), + heard_count = contact_advert_paths.heard_count + 1 + """, + (normalized_key, normalized_path, path_len, timestamp, timestamp), + ): + pass - # Keep only the N most recent unique paths per contact. - await db.conn.execute( - """ - DELETE FROM contact_advert_paths - WHERE public_key = ? - AND id NOT IN ( - SELECT id - FROM contact_advert_paths - WHERE public_key = ? - ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC - LIMIT ? - ) - """, - (normalized_key, normalized_key, max_paths), - ) - await db.conn.commit() + # Keep only the N most recent unique paths per contact. + async with conn.execute( + """ + DELETE FROM contact_advert_paths + WHERE public_key = ? + AND id NOT IN ( + SELECT id + FROM contact_advert_paths + WHERE public_key = ? + ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC + LIMIT ? 
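+                    -- the subquery is the keep-list: the newest max_paths rows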
+ ) + """, + (normalized_key, normalized_key, max_paths), + ): + pass @staticmethod async def get_recent_for_contact(public_key: str, limit: int = 10) -> list[ContactAdvertPath]: - cursor = await db.conn.execute( - """ - SELECT path_hex, path_len, first_seen, last_seen, heard_count - FROM contact_advert_paths - WHERE public_key = ? - ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC - LIMIT ? - """, - (public_key.lower(), limit), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT path_hex, path_len, first_seen, last_seen, heard_count + FROM contact_advert_paths + WHERE public_key = ? + ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC + LIMIT ? + """, + (public_key.lower(), limit), + ) as cursor: + rows = await cursor.fetchall() return [ContactAdvertPathRepository._row_to_path(row) for row in rows] @staticmethod async def get_recent_for_all_contacts( limit_per_contact: int = 10, ) -> list[ContactAdvertPathSummary]: - cursor = await db.conn.execute( - """ - SELECT public_key, path_hex, path_len, first_seen, last_seen, heard_count - FROM ( - SELECT *, - ROW_NUMBER() OVER ( - PARTITION BY public_key - ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC - ) AS rn - FROM contact_advert_paths - ) - WHERE rn <= ? - ORDER BY public_key ASC, last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC - """, - (limit_per_contact,), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT public_key, path_hex, path_len, first_seen, last_seen, heard_count + FROM ( + SELECT *, + ROW_NUMBER() OVER ( + PARTITION BY public_key + ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC + ) AS rn + FROM contact_advert_paths + ) + WHERE rn <= ? + ORDER BY public_key ASC, last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC + """, + (limit_per_contact,), + ) as cursor: + rows = await cursor.fetchall() grouped: dict[str, list[ContactAdvertPath]] = {} for row in rows: @@ -802,29 +836,31 @@ class ContactNameHistoryRepository: @staticmethod async def record_name(public_key: str, name: str, timestamp: int) -> None: """Record a name observation. Upserts: updates last_seen if name already known.""" - await db.conn.execute( - """ - INSERT INTO contact_name_history (public_key, name, first_seen, last_seen) - VALUES (?, ?, ?, ?) - ON CONFLICT(public_key, name) DO UPDATE SET - last_seen = MAX(contact_name_history.last_seen, excluded.last_seen) - """, - (public_key.lower(), name, timestamp, timestamp), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + """ + INSERT INTO contact_name_history (public_key, name, first_seen, last_seen) + VALUES (?, ?, ?, ?) + ON CONFLICT(public_key, name) DO UPDATE SET + last_seen = MAX(contact_name_history.last_seen, excluded.last_seen) + """, + (public_key.lower(), name, timestamp, timestamp), + ): + pass @staticmethod async def get_history(public_key: str) -> list[ContactNameHistory]: - cursor = await db.conn.execute( - """ - SELECT name, first_seen, last_seen - FROM contact_name_history - WHERE public_key = ? - ORDER BY last_seen DESC - """, - (public_key.lower(),), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT name, first_seen, last_seen + FROM contact_name_history + WHERE public_key = ? 
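+                -- most recently seen name first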
+ ORDER BY last_seen DESC + """, + (public_key.lower(),), + ) as cursor: + rows = await cursor.fetchall() return [ ContactNameHistory( name=row["name"], first_seen=row["first_seen"], last_seen=row["last_seen"] diff --git a/app/repository/fanout.py b/app/repository/fanout.py index 76fb31d..91372a6 100644 --- a/app/repository/fanout.py +++ b/app/repository/fanout.py @@ -6,6 +6,8 @@ import time import uuid from typing import Any +import aiosqlite + from app.database import db logger = logging.getLogger(__name__) @@ -31,26 +33,37 @@ def _row_to_dict(row: Any) -> dict[str, Any]: return result +async def _get_in_conn(conn: aiosqlite.Connection, config_id: str) -> dict[str, Any] | None: + """Fetch a config using an already-acquired connection. + + Used by ``create`` and ``update`` to return the freshly-written row + without re-entering the non-reentrant DB lock. + """ + async with conn.execute("SELECT * FROM fanout_configs WHERE id = ?", (config_id,)) as cursor: + row = await cursor.fetchone() + if row is None: + return None + return _row_to_dict(row) + + class FanoutConfigRepository: """CRUD operations for fanout_configs table.""" @staticmethod async def get_all() -> list[dict[str, Any]]: """Get all fanout configs ordered by sort_order.""" - cursor = await db.conn.execute( - "SELECT * FROM fanout_configs ORDER BY sort_order, created_at" - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM fanout_configs ORDER BY sort_order, created_at" + ) as cursor: + rows = await cursor.fetchall() return [_row_to_dict(row) for row in rows] @staticmethod async def get(config_id: str) -> dict[str, Any] | None: """Get a single fanout config by ID.""" - cursor = await db.conn.execute("SELECT * FROM fanout_configs WHERE id = ?", (config_id,)) - row = await cursor.fetchone() - if row is None: - return None - return _row_to_dict(row) + async with db.readonly() as conn: + return await _get_in_conn(conn, config_id) @staticmethod async def create( @@ -65,39 +78,41 @@ class FanoutConfigRepository: new_id = config_id or str(uuid.uuid4()) now = int(time.time()) - # Get next sort_order - cursor = await db.conn.execute( - "SELECT COALESCE(MAX(sort_order), -1) + 1 FROM fanout_configs" - ) - row = await cursor.fetchone() - sort_order = row[0] if row else 0 + async with db.tx() as conn: + # Determine next sort_order under the same lock as the insert, + # so two concurrent ``create()`` calls cannot collide. + async with conn.execute( + "SELECT COALESCE(MAX(sort_order), -1) + 1 FROM fanout_configs" + ) as cursor: + row = await cursor.fetchone() + sort_order = row[0] if row else 0 - await db.conn.execute( - """ - INSERT INTO fanout_configs (id, type, name, enabled, config, scope, sort_order, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - new_id, - config_type, - name, - 1 if enabled else 0, - json.dumps(config), - json.dumps(scope), - sort_order, - now, - ), - ) - await db.conn.commit() + async with conn.execute( + """ + INSERT INTO fanout_configs (id, type, name, enabled, config, scope, sort_order, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + new_id, + config_type, + name, + 1 if enabled else 0, + json.dumps(config), + json.dumps(scope), + sort_order, + now, + ), + ): + pass - result = await FanoutConfigRepository.get(new_id) + result = await _get_in_conn(conn, new_id) assert result is not None return result @staticmethod async def update(config_id: str, **fields: Any) -> dict[str, Any] | None: """Update a fanout config. 
Only provided fields are updated.""" - updates = [] + updates: list[str] = [] params: list[Any] = [] for field in ("name", "enabled", "config", "scope", "sort_order"): @@ -115,23 +130,25 @@ class FanoutConfigRepository: params.append(config_id) query = f"UPDATE fanout_configs SET {', '.join(updates)} WHERE id = ?" - await db.conn.execute(query, params) - await db.conn.commit() - - return await FanoutConfigRepository.get(config_id) + async with db.tx() as conn: + async with conn.execute(query, params): + pass + return await _get_in_conn(conn, config_id) @staticmethod async def delete(config_id: str) -> None: """Delete a fanout config.""" - await db.conn.execute("DELETE FROM fanout_configs WHERE id = ?", (config_id,)) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute("DELETE FROM fanout_configs WHERE id = ?", (config_id,)): + pass _configs_cache.pop(config_id, None) @staticmethod async def get_enabled() -> list[dict[str, Any]]: """Get all enabled fanout configs.""" - cursor = await db.conn.execute( - "SELECT * FROM fanout_configs WHERE enabled = 1 ORDER BY sort_order, created_at" - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM fanout_configs WHERE enabled = 1 ORDER BY sort_order, created_at" + ) as cursor: + rows = await cursor.fetchall() return [_row_to_dict(row) for row in rows] diff --git a/app/repository/messages.py b/app/repository/messages.py index 2562034..186effb 100644 --- a/app/repository/messages.py +++ b/app/repository/messages.py @@ -89,32 +89,34 @@ class MessageRepository: # Normalize sender_key to lowercase so queries can match without LOWER(). normalized_sender_key = sender_key.lower() if sender_key else sender_key - cursor = await db.conn.execute( - """ - INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp, - received_at, paths, txt_type, signature, outgoing, - sender_name, sender_key) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - msg_type, - conversation_key, - text, - sender_timestamp, - received_at, - paths_json, - txt_type, - signature, - outgoing, - sender_name, - normalized_sender_key, - ), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + """ + INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp, + received_at, paths, txt_type, signature, outgoing, + sender_name, sender_key) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + msg_type, + conversation_key, + text, + sender_timestamp, + received_at, + paths_json, + txt_type, + signature, + outgoing, + sender_name, + normalized_sender_key, + ), + ) as cursor: + rowcount = cursor.rowcount + lastrowid = cursor.lastrowid # rowcount is 0 if INSERT was ignored due to UNIQUE constraint violation - if cursor.rowcount == 0: + if rowcount == 0: return None - return cursor.lastrowid + return lastrowid @staticmethod async def add_path( @@ -142,17 +144,20 @@ class MessageRepository: if snr is not None: entry["snr"] = snr new_entry = json.dumps(entry) - await db.conn.execute( - """UPDATE messages SET paths = json_insert( - COALESCE(paths, '[]'), '$[#]', json(?) - ) WHERE id = ?""", - (new_entry, message_id), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + """UPDATE messages SET paths = json_insert( + COALESCE(paths, '[]'), '$[#]', json(?) 
+ ) WHERE id = ?""", + (new_entry, message_id), + ): + pass - # Read back the full list for the return value - cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,)) - row = await cursor.fetchone() + # Read back the full list for the return value, same transaction. + async with conn.execute( + "SELECT paths FROM messages WHERE id = ?", (message_id,) + ) as cursor: + row = await cursor.fetchone() if not row or not row["paths"]: return [] @@ -171,23 +176,24 @@ class MessageRepository: only a prefix as conversation_key are updated to use the full key. """ lower_key = full_key.lower() - cursor = await db.conn.execute( - """UPDATE messages SET conversation_key = ?, - sender_key = CASE - WHEN sender_key IS NOT NULL AND length(sender_key) < 64 - AND ? LIKE sender_key || '%' - THEN ? ELSE sender_key END - WHERE type = 'PRIV' AND length(conversation_key) < 64 - AND ? LIKE conversation_key || '%' - AND ( - SELECT COUNT(*) FROM contacts - WHERE length(public_key) = 64 - AND public_key LIKE messages.conversation_key || '%' - ) = 1""", - (lower_key, lower_key, lower_key, lower_key), - ) - await db.conn.commit() - return cursor.rowcount + async with db.tx() as conn: + async with conn.execute( + """UPDATE messages SET conversation_key = ?, + sender_key = CASE + WHEN sender_key IS NOT NULL AND length(sender_key) < 64 + AND ? LIKE sender_key || '%' + THEN ? ELSE sender_key END + WHERE type = 'PRIV' AND length(conversation_key) < 64 + AND ? LIKE conversation_key || '%' + AND ( + SELECT COUNT(*) FROM contacts + WHERE length(public_key) = 64 + AND public_key LIKE messages.conversation_key || '%' + ) = 1""", + (lower_key, lower_key, lower_key, lower_key), + ) as cursor: + rowcount = cursor.rowcount + return rowcount @staticmethod async def backfill_channel_sender_key(public_key: str, name: str) -> int: @@ -197,21 +203,22 @@ class MessageRepository: any channel messages with a matching sender_name but no sender_key are updated to associate them with this contact's public key. """ - cursor = await db.conn.execute( - """UPDATE messages SET sender_key = ? - WHERE type = 'CHAN' AND sender_name = ? AND sender_key IS NULL - AND ( - SELECT COUNT(*) FROM contacts - WHERE name = ? - ) = 1 - AND EXISTS ( - SELECT 1 FROM contacts - WHERE public_key = ? AND name = ? - )""", - (public_key.lower(), name, name, public_key.lower(), name), - ) - await db.conn.commit() - return cursor.rowcount + async with db.tx() as conn: + async with conn.execute( + """UPDATE messages SET sender_key = ? + WHERE type = 'CHAN' AND sender_name = ? AND sender_key IS NULL + AND ( + SELECT COUNT(*) FROM contacts + WHERE name = ? + ) = 1 + AND EXISTS ( + SELECT 1 FROM contacts + WHERE public_key = ? AND name = ? + )""", + (public_key.lower(), name, name, public_key.lower(), name), + ) as cursor: + rowcount = cursor.rowcount + return rowcount @staticmethod def _normalize_conversation_key(conversation_key: str) -> tuple[str, str]: @@ -462,8 +469,9 @@ class MessageRepository: query += " OFFSET ?" params.append(offset) - cursor = await db.conn.execute(query, params) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() return [MessageRepository._row_to_message(row) for row in rows] @staticmethod @@ -501,51 +509,54 @@ class MessageRepository: where_sql = " AND ".join(["1=1", *where_parts]) # 1. 
Get the target message (must satisfy filters if provided) - target_cursor = await db.conn.execute( - f"SELECT {MessageRepository._message_select('messages')} " - f"FROM messages WHERE id = ? AND {where_sql}", - (message_id, *base_params), - ) - target_row = await target_cursor.fetchone() - if not target_row: - return [], False, False + async with db.readonly() as conn: + async with conn.execute( + f"SELECT {MessageRepository._message_select('messages')} " + f"FROM messages WHERE id = ? AND {where_sql}", + (message_id, *base_params), + ) as target_cursor: + target_row = await target_cursor.fetchone() + if not target_row: + return [], False, False - target = MessageRepository._row_to_message(target_row) + target = MessageRepository._row_to_message(target_row) - # 2. Get context_size+1 messages before target (DESC) - before_query = f""" - SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql} - AND (received_at < ? OR (received_at = ? AND id < ?)) - ORDER BY received_at DESC, id DESC LIMIT ? - """ - before_params = [ - *base_params, - target.received_at, - target.received_at, - target.id, - context_size + 1, - ] - before_cursor = await db.conn.execute(before_query, before_params) - before_rows = list(await before_cursor.fetchall()) + # 2. Get context_size+1 messages before target (DESC) + before_query = f""" + SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql} + AND (received_at < ? OR (received_at = ? AND id < ?)) + ORDER BY received_at DESC, id DESC LIMIT ? + """ + before_params = [ + *base_params, + target.received_at, + target.received_at, + target.id, + context_size + 1, + ] + async with conn.execute(before_query, before_params) as before_cursor: + before_rows = list(await before_cursor.fetchall()) - has_older = len(before_rows) > context_size - before_messages = [MessageRepository._row_to_message(r) for r in before_rows[:context_size]] + has_older = len(before_rows) > context_size + before_messages = [ + MessageRepository._row_to_message(r) for r in before_rows[:context_size] + ] - # 3. Get context_size+1 messages after target (ASC) - after_query = f""" - SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql} - AND (received_at > ? OR (received_at = ? AND id > ?)) - ORDER BY received_at ASC, id ASC LIMIT ? - """ - after_params = [ - *base_params, - target.received_at, - target.received_at, - target.id, - context_size + 1, - ] - after_cursor = await db.conn.execute(after_query, after_params) - after_rows = list(await after_cursor.fetchall()) + # 3. Get context_size+1 messages after target (ASC) + after_query = f""" + SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql} + AND (received_at > ? OR (received_at = ? AND id > ?)) + ORDER BY received_at ASC, id ASC LIMIT ? + """ + after_params = [ + *base_params, + target.received_at, + target.received_at, + target.id, + context_size + 1, + ] + async with conn.execute(after_query, after_params) as after_cursor: + after_rows = list(await after_cursor.fetchall()) has_newer = len(after_rows) > context_size after_messages = [MessageRepository._row_to_message(r) for r in after_rows[:context_size]] @@ -556,21 +567,29 @@ class MessageRepository: @staticmethod async def increment_ack_count(message_id: int) -> int: - """Increment ack count and return the new value.""" - cursor = await db.conn.execute( - "UPDATE messages SET acked = acked + 1 WHERE id = ? 
RETURNING acked", (message_id,) - ) - row = await cursor.fetchone() - await db.conn.commit() + """Increment ack count and return the new value. + + NOTE: ``RETURNING`` leaves the prepared statement active until the + row is fetched, so we MUST consume it inside the ``async with`` + block. Without that, the commit at the end of ``db.tx()`` fails + with ``cannot commit transaction - SQL statements in progress``. + """ + async with db.tx() as conn: + async with conn.execute( + "UPDATE messages SET acked = acked + 1 WHERE id = ? RETURNING acked", + (message_id,), + ) as cursor: + row = await cursor.fetchone() return row["acked"] if row else 1 @staticmethod async def get_ack_and_paths(message_id: int) -> tuple[int, list[MessagePath] | None]: """Get the current ack count and paths for a message.""" - cursor = await db.conn.execute( - "SELECT acked, paths FROM messages WHERE id = ?", (message_id,) - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT acked, paths FROM messages WHERE id = ?", (message_id,) + ) as cursor: + row = await cursor.fetchone() if not row: return 0, None return row["acked"], MessageRepository._parse_paths(row["paths"]) @@ -578,11 +597,12 @@ class MessageRepository: @staticmethod async def get_by_id(message_id: int) -> "Message | None": """Look up a message by its ID.""" - cursor = await db.conn.execute( - f"SELECT {MessageRepository._message_select('messages')} FROM messages WHERE id = ?", - (message_id,), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + f"SELECT {MessageRepository._message_select('messages')} FROM messages WHERE id = ?", + (message_id,), + ) as cursor: + row = await cursor.fetchone() if not row: return None @@ -591,11 +611,14 @@ class MessageRepository: @staticmethod async def delete_by_id(message_id: int) -> None: """Delete a message row by ID.""" - await db.conn.execute( - "UPDATE raw_packets SET message_id = NULL WHERE message_id = ?", (message_id,) - ) - await db.conn.execute("DELETE FROM messages WHERE id = ?", (message_id,)) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + "UPDATE raw_packets SET message_id = NULL WHERE message_id = ?", + (message_id,), + ): + pass + async with conn.execute("DELETE FROM messages WHERE id = ?", (message_id,)): + pass @staticmethod async def get_by_content( @@ -618,8 +641,9 @@ class MessageRepository: query += " AND outgoing = ?" params.append(1 if outgoing else 0) query += " ORDER BY id ASC" - cursor = await db.conn.execute(query, params) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute(query, params) as cursor: + row = await cursor.fetchone() if not row: return None @@ -653,76 +677,6 @@ class MessageRepository: ) blocked_sql = f" AND {blocked_clause}" if blocked_clause else "" - # Channel unreads - cursor = await db.conn.execute( - f""" - SELECT m.conversation_key, - COUNT(*) as unread_count, - SUM(CASE - WHEN ? 
<> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 - ELSE 0 - END) > 0 as has_mention - FROM messages m - JOIN channels c ON m.conversation_key = c.key - WHERE m.type = 'CHAN' AND m.outgoing = 0 - AND m.received_at > COALESCE(c.last_read_at, 0) - {blocked_sql} - GROUP BY m.conversation_key - """, - (mention_token or "", mention_token or "", *blocked_params), - ) - rows = await cursor.fetchall() - for row in rows: - state_key = f"channel-{row['conversation_key']}" - counts[state_key] = row["unread_count"] - if mention_token and row["has_mention"]: - mention_flags[state_key] = True - - # Contact unreads - cursor = await db.conn.execute( - f""" - SELECT m.conversation_key, - COUNT(*) as unread_count, - SUM(CASE - WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 - ELSE 0 - END) > 0 as has_mention - FROM messages m - LEFT JOIN contacts ct ON m.conversation_key = ct.public_key - WHERE m.type = 'PRIV' AND m.outgoing = 0 - AND m.received_at > COALESCE(ct.last_read_at, 0) - {blocked_sql} - GROUP BY m.conversation_key - """, - (mention_token or "", mention_token or "", *blocked_params), - ) - rows = await cursor.fetchall() - for row in rows: - state_key = f"contact-{row['conversation_key']}" - counts[state_key] = row["unread_count"] - if mention_token and row["has_mention"]: - mention_flags[state_key] = True - - cursor = await db.conn.execute( - """ - SELECT key, last_read_at - FROM channels - """ - ) - rows = await cursor.fetchall() - for row in rows: - last_read_ats[f"channel-{row['key']}"] = row["last_read_at"] - - cursor = await db.conn.execute( - """ - SELECT public_key, last_read_at - FROM contacts - """ - ) - rows = await cursor.fetchall() - for row in rows: - last_read_ats[f"contact-{row['public_key']}"] = row["last_read_at"] - # Last message times for all conversations (including read ones), # excluding blocked incoming traffic so refresh matches live WS behavior. last_time_clause, last_time_params = MessageRepository._build_blocked_incoming_clause( @@ -730,20 +684,94 @@ class MessageRepository: ) last_time_where_sql = f"WHERE {last_time_clause}" if last_time_clause else "" - cursor = await db.conn.execute( - f""" - SELECT type, conversation_key, MAX(received_at) as last_message_time - FROM messages - {last_time_where_sql} - GROUP BY type, conversation_key - """, - last_time_params, - ) - rows = await cursor.fetchall() - for row in rows: - prefix = "channel" if row["type"] == "CHAN" else "contact" - state_key = f"{prefix}-{row['conversation_key']}" - last_message_times[state_key] = row["last_message_time"] + # Single readonly acquisition for all 5 queries — they form one logical + # snapshot, and holding the lock for the batch is cheaper than acquiring + # it 5 times. + async with db.readonly() as conn: + # Channel unreads + async with conn.execute( + f""" + SELECT m.conversation_key, + COUNT(*) as unread_count, + SUM(CASE + WHEN ? 
<> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 + ELSE 0 + END) > 0 as has_mention + FROM messages m + JOIN channels c ON m.conversation_key = c.key + WHERE m.type = 'CHAN' AND m.outgoing = 0 + AND m.received_at > COALESCE(c.last_read_at, 0) + {blocked_sql} + GROUP BY m.conversation_key + """, + (mention_token or "", mention_token or "", *blocked_params), + ) as cursor: + rows = await cursor.fetchall() + for row in rows: + state_key = f"channel-{row['conversation_key']}" + counts[state_key] = row["unread_count"] + if mention_token and row["has_mention"]: + mention_flags[state_key] = True + + # Contact unreads + async with conn.execute( + f""" + SELECT m.conversation_key, + COUNT(*) as unread_count, + SUM(CASE + WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 + ELSE 0 + END) > 0 as has_mention + FROM messages m + LEFT JOIN contacts ct ON m.conversation_key = ct.public_key + WHERE m.type = 'PRIV' AND m.outgoing = 0 + AND m.received_at > COALESCE(ct.last_read_at, 0) + {blocked_sql} + GROUP BY m.conversation_key + """, + (mention_token or "", mention_token or "", *blocked_params), + ) as cursor: + rows = await cursor.fetchall() + for row in rows: + state_key = f"contact-{row['conversation_key']}" + counts[state_key] = row["unread_count"] + if mention_token and row["has_mention"]: + mention_flags[state_key] = True + + async with conn.execute( + """ + SELECT key, last_read_at + FROM channels + """ + ) as cursor: + rows = await cursor.fetchall() + for row in rows: + last_read_ats[f"channel-{row['key']}"] = row["last_read_at"] + + async with conn.execute( + """ + SELECT public_key, last_read_at + FROM contacts + """ + ) as cursor: + rows = await cursor.fetchall() + for row in rows: + last_read_ats[f"contact-{row['public_key']}"] = row["last_read_at"] + + async with conn.execute( + f""" + SELECT type, conversation_key, MAX(received_at) as last_message_time + FROM messages + {last_time_where_sql} + GROUP BY type, conversation_key + """, + last_time_params, + ) as cursor: + rows = await cursor.fetchall() + for row in rows: + prefix = "channel" if row["type"] == "CHAN" else "contact" + state_key = f"{prefix}-{row['conversation_key']}" + last_message_times[state_key] = row["last_message_time"] # Only include last_read_ats for conversations that actually have messages. 
# Without this filter, every contact heard via advertisement (even without @@ -760,41 +788,45 @@ class MessageRepository: @staticmethod async def count_dm_messages(contact_key: str) -> int: """Count total DM messages for a contact.""" - cursor = await db.conn.execute( - "SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?", - (contact_key.lower(),), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?", + (contact_key.lower(),), + ) as cursor: + row = await cursor.fetchone() return row["cnt"] if row else 0 @staticmethod async def count_channel_messages_by_sender(sender_key: str) -> int: """Count channel messages sent by a specific contact.""" - cursor = await db.conn.execute( - "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?", - (sender_key.lower(),), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?", + (sender_key.lower(),), + ) as cursor: + row = await cursor.fetchone() return row["cnt"] if row else 0 @staticmethod async def count_channel_messages_by_sender_name(sender_name: str) -> int: """Count channel messages attributed to a display name.""" - cursor = await db.conn.execute( - "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_name = ?", - (sender_name,), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_name = ?", + (sender_name,), + ) as cursor: + row = await cursor.fetchone() return row["cnt"] if row else 0 @staticmethod async def get_first_channel_message_by_sender_name(sender_name: str) -> int | None: """Get the earliest stored channel message timestamp for a display name.""" - cursor = await db.conn.execute( - "SELECT MIN(received_at) AS first_seen FROM messages WHERE type = 'CHAN' AND sender_name = ?", - (sender_name,), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT MIN(received_at) AS first_seen FROM messages WHERE type = 'CHAN' AND sender_name = ?", + (sender_name,), + ) as cursor: + row = await cursor.fetchone() return row["first_seen"] if row and row["first_seen"] is not None else None @staticmethod @@ -813,68 +845,76 @@ class MessageRepository: t_48h = now - 172800 t_7d = now - 604800 - cursor = await db.conn.execute( - """ - SELECT COUNT(*) AS all_time, - SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_1h, - SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_24h, - SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_48h, - SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_7d, - MIN(received_at) AS first_message_at, - COUNT(DISTINCT sender_key) AS unique_sender_count - FROM messages WHERE type = 'CHAN' AND conversation_key = ? - """, - (t_1h, t_24h, t_48h, t_7d, conversation_key), - ) - row = await cursor.fetchone() - assert row is not None # Aggregate query always returns a row + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT COUNT(*) AS all_time, + SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_1h, + SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_24h, + SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_48h, + SUM(CASE WHEN received_at >= ? 
THEN 1 ELSE 0 END) AS last_7d,
+                       MIN(received_at) AS first_message_at,
+                       COUNT(DISTINCT sender_key) AS unique_sender_count
+                FROM messages WHERE type = 'CHAN' AND conversation_key = ?
+                """,
+                (t_1h, t_24h, t_48h, t_7d, conversation_key),
+            ) as cursor:
+                row = await cursor.fetchone()
+                assert row is not None  # Aggregate query always returns a row
 
-        message_counts = {
-            "last_1h": row["last_1h"] or 0,
-            "last_24h": row["last_24h"] or 0,
-            "last_48h": row["last_48h"] or 0,
-            "last_7d": row["last_7d"] or 0,
-            "all_time": row["all_time"] or 0,
-        }
-
-        cursor2 = await db.conn.execute(
-            """
-            SELECT COALESCE(sender_name, sender_key, 'Unknown') AS display_name,
-                   sender_key, COUNT(*) AS cnt
-            FROM messages
-            WHERE type = 'CHAN' AND conversation_key = ?
-            AND received_at >= ? AND sender_key IS NOT NULL
-            GROUP BY sender_key ORDER BY cnt DESC LIMIT 5
-            """,
-            (conversation_key, t_24h),
-        )
-        top_rows = await cursor2.fetchall()
-        top_senders = [
-            {
-                "sender_name": r["display_name"],
-                "sender_key": r["sender_key"],
-                "message_count": r["cnt"],
+            message_counts = {
+                "last_1h": row["last_1h"] or 0,
+                "last_24h": row["last_24h"] or 0,
+                "last_48h": row["last_48h"] or 0,
+                "last_7d": row["last_7d"] or 0,
+                "all_time": row["all_time"] or 0,
             }
-            for r in top_rows
-        ]
-        # Path hash width distribution for last 24h (in-Python parse of raw packet envelopes)
-        cursor3 = await db.conn.execute(
-            """
-            SELECT rp.data FROM raw_packets rp
-            JOIN messages m ON rp.message_id = m.id
-            WHERE m.type = 'CHAN' AND m.conversation_key = ?
-            AND rp.timestamp >= ?
-            """,
-            (conversation_key, t_24h),
-        )
-        rows3 = await cursor3.fetchall()
+            async with conn.execute(
+                """
+                SELECT COALESCE(sender_name, sender_key, 'Unknown') AS display_name,
+                       sender_key, COUNT(*) AS cnt
+                FROM messages
+                WHERE type = 'CHAN' AND conversation_key = ?
+                AND received_at >= ? AND sender_key IS NOT NULL
+                GROUP BY sender_key ORDER BY cnt DESC LIMIT 5
+                """,
+                (conversation_key, t_24h),
+            ) as cursor:
+                top_rows = await cursor.fetchall()
+            top_senders = [
+                {
+                    "sender_name": r["display_name"],
+                    "sender_key": r["sender_key"],
+                    "message_count": r["cnt"],
+                }
+                for r in top_rows
+            ]
+
+            # Path hash width distribution for last 24h: fetch raw rows under
+            # the lock, then release BEFORE the CPU-bound in-Python envelope
+            # parse. Parsing can iterate thousands of rows; holding the DB
+            # lock for that whole traversal would block every other repo
+            # caller on a Pi. Keep the lock only for the fetch.
+            async with conn.execute(
+                """
+                SELECT rp.data FROM raw_packets rp
+                JOIN messages m ON rp.message_id = m.id
+                WHERE m.type = 'CHAN' AND m.conversation_key = ?
+                AND rp.timestamp >= ?
+ """, + (conversation_key, t_24h), + ) as cursor: + rows3 = await cursor.fetchall() + first_message_at = row["first_message_at"] + unique_sender_count = row["unique_sender_count"] or 0 + path_hash_width_24h = bucket_path_hash_widths(rows3) return { "message_counts": message_counts, - "first_message_at": row["first_message_at"], - "unique_sender_count": row["unique_sender_count"] or 0, + "first_message_at": first_message_at, + "unique_sender_count": unique_sender_count, "top_senders_24h": top_senders, "path_hash_width_24h": path_hash_width_24h, } @@ -882,14 +922,15 @@ class MessageRepository: @staticmethod async def count_channels_with_incoming_messages() -> int: """Count distinct channel conversations with at least one incoming message.""" - cursor = await db.conn.execute( - """ - SELECT COUNT(DISTINCT conversation_key) AS cnt - FROM messages - WHERE type = 'CHAN' AND outgoing = 0 - """ - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT COUNT(DISTINCT conversation_key) AS cnt + FROM messages + WHERE type = 'CHAN' AND outgoing = 0 + """ + ) as cursor: + row = await cursor.fetchone() return int(row["cnt"]) if row and row["cnt"] is not None else 0 @staticmethod @@ -898,20 +939,21 @@ class MessageRepository: Returns list of (channel_key, channel_name, message_count) tuples. """ - cursor = await db.conn.execute( - """ - SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, - COUNT(*) AS cnt - FROM messages m - LEFT JOIN channels c ON m.conversation_key = c.key - WHERE m.type = 'CHAN' AND m.sender_key = ? - GROUP BY m.conversation_key - ORDER BY cnt DESC - LIMIT ? - """, - (sender_key.lower(), limit), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, + COUNT(*) AS cnt + FROM messages m + LEFT JOIN channels c ON m.conversation_key = c.key + WHERE m.type = 'CHAN' AND m.sender_key = ? + GROUP BY m.conversation_key + ORDER BY cnt DESC + LIMIT ? + """, + (sender_key.lower(), limit), + ) as cursor: + rows = await cursor.fetchall() return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows] @staticmethod @@ -919,34 +961,36 @@ class MessageRepository: sender_name: str, limit: int = 5 ) -> list[tuple[str, str, int]]: """Get channels where a display name has sent the most messages.""" - cursor = await db.conn.execute( - """ - SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, - COUNT(*) AS cnt - FROM messages m - LEFT JOIN channels c ON m.conversation_key = c.key - WHERE m.type = 'CHAN' AND m.sender_name = ? - GROUP BY m.conversation_key - ORDER BY cnt DESC - LIMIT ? - """, - (sender_name, limit), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, + COUNT(*) AS cnt + FROM messages m + LEFT JOIN channels c ON m.conversation_key = c.key + WHERE m.type = 'CHAN' AND m.sender_name = ? + GROUP BY m.conversation_key + ORDER BY cnt DESC + LIMIT ? 
+ """, + (sender_name, limit), + ) as cursor: + rows = await cursor.fetchall() return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows] @staticmethod async def _get_activity_hour_buckets(where_sql: str, params: list[Any]) -> dict[int, int]: - cursor = await db.conn.execute( - f""" - SELECT received_at / 3600 AS hour_bucket, COUNT(*) AS cnt - FROM messages - WHERE {where_sql} - GROUP BY hour_bucket - """, - params, - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + f""" + SELECT received_at / 3600 AS hour_bucket, COUNT(*) AS cnt + FROM messages + WHERE {where_sql} + GROUP BY hour_bucket + """, + params, + ) as cursor: + rows = await cursor.fetchall() return {int(row["hour_bucket"]): row["cnt"] for row in rows} @staticmethod @@ -1000,16 +1044,17 @@ class MessageRepository: current_day_start = (now // 86400) * 86400 start = current_day_start - (weeks - 1) * bucket_seconds - cursor = await db.conn.execute( - f""" - SELECT (received_at - ?) / ? AS bucket_idx, COUNT(*) AS cnt - FROM messages - WHERE {where_sql} AND received_at >= ? - GROUP BY bucket_idx - """, - [start, bucket_seconds, *params, start], - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + f""" + SELECT (received_at - ?) / ? AS bucket_idx, COUNT(*) AS cnt + FROM messages + WHERE {where_sql} AND received_at >= ? + GROUP BY bucket_idx + """, + [start, bucket_seconds, *params, start], + ) as cursor: + rows = await cursor.fetchall() counts = {int(row["bucket_idx"]): row["cnt"] for row in rows} return [ diff --git a/app/repository/raw_packets.py b/app/repository/raw_packets.py index 3aded63..16ab864 100644 --- a/app/repository/raw_packets.py +++ b/app/repository/raw_packets.py @@ -34,65 +34,85 @@ class RawPacketRepository: # For malformed packets, hash the full data payload_hash = sha256(data).digest() - cursor = await db.conn.execute( - "INSERT OR IGNORE INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)", - (ts, data, payload_hash), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + "INSERT OR IGNORE INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)", + (ts, data, payload_hash), + ) as cursor: + rowcount = cursor.rowcount + lastrowid = cursor.lastrowid - if cursor.rowcount > 0: - assert cursor.lastrowid is not None - return (cursor.lastrowid, True) + if rowcount > 0: + assert lastrowid is not None + return (lastrowid, True) - # Duplicate payload — look up the existing row. - cursor = await db.conn.execute( - "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,) - ) - existing = await cursor.fetchone() + # Duplicate payload — look up the existing row (same transaction). 
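+            # Still inside the same lock/transaction as the INSERT above, so
+            # a concurrent pruner cannot delete the duplicate row between the
+            # IGNORE and this SELECT.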
+ async with conn.execute( + "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,) + ) as cursor: + existing = await cursor.fetchone() assert existing is not None return (existing["id"], False) @staticmethod async def get_undecrypted_count() -> int: """Get count of undecrypted packets (those without a linked message).""" - cursor = await db.conn.execute( - "SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL" - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL" + ) as cursor: + row = await cursor.fetchone() return row["count"] if row else 0 @staticmethod async def get_oldest_undecrypted() -> int | None: """Get timestamp of oldest undecrypted packet, or None if none exist.""" - cursor = await db.conn.execute( - "SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL" - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL" + ) as cursor: + row = await cursor.fetchone() return row["oldest"] if row and row["oldest"] is not None else None + @staticmethod + async def _stream_undecrypted_rows( + batch_size: int, + ) -> AsyncIterator[tuple[int, bytes, int]]: + """Internal: keyset-paginated scan of every undecrypted raw packet. + + Yields ``(id, data, timestamp)`` for each row across all batches. + Lock is acquired per batch only — concurrent writes can interleave + at batch boundaries rather than being blocked for the full scan. + Each batch opens a fresh cursor and consumes it fully with + ``fetchall()`` before releasing, so no prepared statement is alive + at a yield boundary. + + ``last_id`` advances per row, not per yield, so external filters + (see ``stream_undecrypted_text_messages``) that drop rows do not + cause a re-scan of skipped IDs. + """ + last_id = -1 + while True: + async with db.readonly() as conn: + async with conn.execute( + "SELECT id, data, timestamp FROM raw_packets " + "WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?", + (last_id, batch_size), + ) as cursor: + rows = await cursor.fetchall() + if not rows: + return + for row in rows: + last_id = row["id"] + yield (row["id"], bytes(row["data"]), row["timestamp"]) + @staticmethod async def stream_all_undecrypted( batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE, ) -> AsyncIterator[tuple[int, bytes, int]]: - """Yield all undecrypted packets as (id, data, timestamp) in bounded batches. - - Uses keyset pagination so each batch is a fresh query with a fully - consumed cursor — no open statement held across yield boundaries. - """ - last_id = -1 - while True: - cursor = await db.conn.execute( - "SELECT id, data, timestamp FROM raw_packets " - "WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?", - (last_id, batch_size), - ) - rows = await cursor.fetchall() - await cursor.close() - if not rows: - break - for row in rows: - last_id = row["id"] - yield (row["id"], bytes(row["data"]), row["timestamp"]) + """Yield all undecrypted packets as (id, data, timestamp) in bounded batches.""" + async for row in RawPacketRepository._stream_undecrypted_rows(batch_size): + yield row @staticmethod async def stream_undecrypted_text_messages( @@ -100,26 +120,15 @@ class RawPacketRepository: ) -> AsyncIterator[tuple[int, bytes, int]]: """Yield undecrypted TEXT_MESSAGE packets in bounded-size batches. 
- Uses keyset pagination so each batch is a fresh query with a fully - consumed cursor — no open statement held across yield boundaries. + Filters the shared scan to rows whose payload parses as a text + message. Non-matching rows still advance the keyset cursor so they + aren't re-fetched on subsequent batches. """ - last_id = -1 - while True: - cursor = await db.conn.execute( - "SELECT id, data, timestamp FROM raw_packets " - "WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?", - (last_id, batch_size), - ) - rows = await cursor.fetchall() - await cursor.close() - if not rows: - break - for row in rows: - last_id = row["id"] - data = bytes(row["data"]) - payload_type = get_packet_payload_type(data) - if payload_type == PayloadType.TEXT_MESSAGE: - yield (row["id"], data, row["timestamp"]) + async for packet_id, data, timestamp in RawPacketRepository._stream_undecrypted_rows( + batch_size + ): + if get_packet_payload_type(data) == PayloadType.TEXT_MESSAGE: + yield (packet_id, data, timestamp) @staticmethod async def count_undecrypted_text_messages( @@ -136,20 +145,22 @@ class RawPacketRepository: @staticmethod async def mark_decrypted(packet_id: int, message_id: int) -> None: """Link a raw packet to its decrypted message.""" - await db.conn.execute( - "UPDATE raw_packets SET message_id = ? WHERE id = ?", - (message_id, packet_id), - ) - await db.conn.commit() + async with db.tx() as conn: + async with conn.execute( + "UPDATE raw_packets SET message_id = ? WHERE id = ?", + (message_id, packet_id), + ): + pass @staticmethod async def get_linked_message_id(packet_id: int) -> int | None: """Return the linked message ID for a raw packet, if any.""" - cursor = await db.conn.execute( - "SELECT message_id FROM raw_packets WHERE id = ?", - (packet_id,), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT message_id FROM raw_packets WHERE id = ?", + (packet_id,), + ) as cursor: + row = await cursor.fetchone() if not row: return None return row["message_id"] @@ -157,11 +168,12 @@ class RawPacketRepository: @staticmethod async def get_by_id(packet_id: int) -> tuple[int, bytes, int, int | None] | None: """Return a raw packet row as (id, data, timestamp, message_id).""" - cursor = await db.conn.execute( - "SELECT id, data, timestamp, message_id FROM raw_packets WHERE id = ?", - (packet_id,), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT id, data, timestamp, message_id FROM raw_packets WHERE id = ?", + (packet_id,), + ) as cursor: + row = await cursor.fetchone() if not row: return None return (row["id"], bytes(row["data"]), row["timestamp"], row["message_id"]) @@ -170,16 +182,20 @@ class RawPacketRepository: async def prune_old_undecrypted(max_age_days: int) -> int: """Delete undecrypted packets older than max_age_days. 
Returns count deleted.""" cutoff = int(time.time()) - (max_age_days * 86400) - cursor = await db.conn.execute( - "DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?", - (cutoff,), - ) - await db.conn.commit() - return cursor.rowcount + async with db.tx() as conn: + async with conn.execute( + "DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?", + (cutoff,), + ) as cursor: + rowcount = cursor.rowcount + return rowcount @staticmethod async def purge_linked_to_messages() -> int: """Delete raw packets that are already linked to a stored message.""" - cursor = await db.conn.execute("DELETE FROM raw_packets WHERE message_id IS NOT NULL") - await db.conn.commit() - return cursor.rowcount + async with db.tx() as conn: + async with conn.execute( + "DELETE FROM raw_packets WHERE message_id IS NOT NULL" + ) as cursor: + rowcount = cursor.rowcount + return rowcount diff --git a/app/repository/repeater_telemetry.py b/app/repository/repeater_telemetry.py index 068a812..976d39d 100644 --- a/app/repository/repeater_telemetry.py +++ b/app/repository/repeater_telemetry.py @@ -21,51 +21,54 @@ class RepeaterTelemetryRepository: data: dict, ) -> None: """Insert a telemetry history row and prune stale entries.""" - await db.conn.execute( - """ - INSERT INTO repeater_telemetry_history - (public_key, timestamp, data) - VALUES (?, ?, ?) - """, - (public_key, timestamp, json.dumps(data)), - ) - - # Prune entries older than 30 days cutoff = int(time.time()) - _MAX_AGE_SECONDS - await db.conn.execute( - "DELETE FROM repeater_telemetry_history WHERE public_key = ? AND timestamp < ?", - (public_key, cutoff), - ) + async with db.tx() as conn: + async with conn.execute( + """ + INSERT INTO repeater_telemetry_history + (public_key, timestamp, data) + VALUES (?, ?, ?) + """, + (public_key, timestamp, json.dumps(data)), + ): + pass - # Cap at _MAX_ENTRIES_PER_REPEATER (keep newest) - await db.conn.execute( - """ - DELETE FROM repeater_telemetry_history - WHERE public_key = ? AND id NOT IN ( - SELECT id FROM repeater_telemetry_history - WHERE public_key = ? - ORDER BY timestamp DESC - LIMIT ? - ) - """, - (public_key, public_key, _MAX_ENTRIES_PER_REPEATER), - ) + # Prune entries older than 30 days + async with conn.execute( + "DELETE FROM repeater_telemetry_history WHERE public_key = ? AND timestamp < ?", + (public_key, cutoff), + ): + pass - await db.conn.commit() + # Cap at _MAX_ENTRIES_PER_REPEATER (keep newest) + async with conn.execute( + """ + DELETE FROM repeater_telemetry_history + WHERE public_key = ? AND id NOT IN ( + SELECT id FROM repeater_telemetry_history + WHERE public_key = ? + ORDER BY timestamp DESC + LIMIT ? + ) + """, + (public_key, public_key, _MAX_ENTRIES_PER_REPEATER), + ): + pass @staticmethod async def get_history(public_key: str, since_timestamp: int) -> list[dict]: """Return telemetry rows for a repeater since a given timestamp, ordered ASC.""" - cursor = await db.conn.execute( - """ - SELECT timestamp, data - FROM repeater_telemetry_history - WHERE public_key = ? AND timestamp >= ? - ORDER BY timestamp ASC - """, - (public_key, since_timestamp), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT timestamp, data + FROM repeater_telemetry_history + WHERE public_key = ? AND timestamp >= ? 
+ ORDER BY timestamp ASC + """, + (public_key, since_timestamp), + ) as cursor: + rows = await cursor.fetchall() return [ { "timestamp": row["timestamp"], @@ -77,17 +80,18 @@ class RepeaterTelemetryRepository: @staticmethod async def get_latest(public_key: str) -> dict | None: """Return the most recent telemetry row for a repeater, or None.""" - cursor = await db.conn.execute( - """ - SELECT timestamp, data - FROM repeater_telemetry_history - WHERE public_key = ? - ORDER BY timestamp DESC - LIMIT 1 - """, - (public_key,), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT timestamp, data + FROM repeater_telemetry_history + WHERE public_key = ? + ORDER BY timestamp DESC + LIMIT 1 + """, + (public_key,), + ) as cursor: + row = await cursor.fetchone() if row is None: return None return { diff --git a/app/repository/settings.py b/app/repository/settings.py index ccdb3f9..38bd087 100644 --- a/app/repository/settings.py +++ b/app/repository/settings.py @@ -3,6 +3,8 @@ import logging import time from typing import Any +import aiosqlite + from app.database import db from app.models import AppSettings from app.path_utils import bucket_path_hash_widths @@ -17,15 +19,23 @@ SECONDS_7D = 604800 class AppSettingsRepository: - """Repository for app_settings table (single-row pattern).""" + """Repository for app_settings table (single-row pattern). + + Public methods acquire the DB lock exactly once. ``toggle_*`` helpers that + need a read-modify-write do so inside a single ``db.tx()`` — the internal + ``_get_in_conn`` / ``_apply_updates`` helpers run under the caller's + already-held lock and must NEVER call ``db.tx()`` or ``db.readonly()``. + """ @staticmethod - async def get() -> AppSettings: - """Get the current app settings. + async def _get_in_conn(conn: aiosqlite.Connection) -> AppSettings: + """Load settings using an already-acquired connection. - Always returns settings - creates default row if needed (migration handles initial row). + Used by the public ``get()`` and by multi-step operations + (``toggle_blocked_key``, ``toggle_blocked_name``) to avoid re-entering + the non-reentrant DB lock. """ - cursor = await db.conn.execute( + async with conn.execute( """ SELECT max_radio_contacts, auto_decrypt_dm_on_advert, last_message_times, @@ -35,8 +45,8 @@ class AppSettingsRepository: telemetry_interval_hours FROM app_settings WHERE id = 1 """ - ) - row = await cursor.fetchone() + ) as cursor: + row = await cursor.fetchone() if not row: # Should not happen after migration, but handle gracefully @@ -119,7 +129,9 @@ class AppSettingsRepository: ) @staticmethod - async def update( + async def _apply_updates( + conn: aiosqlite.Connection, + *, max_radio_contacts: int | None = None, auto_decrypt_dm_on_advert: bool | None = None, last_message_times: dict[str, int] | None = None, @@ -132,9 +144,13 @@ class AppSettingsRepository: tracked_telemetry_repeaters: list[str] | None = None, auto_resend_channel: bool | None = None, telemetry_interval_hours: int | None = None, - ) -> AppSettings: - """Update app settings. Only provided fields are updated.""" - updates = [] + ) -> None: + """Apply field updates using an already-acquired connection. + + Emits a single UPDATE statement inside the caller's transaction. Does + NOT commit — the caller's ``db.tx()`` handles that. 
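+
+        Illustrative composition (hypothetical caller, mirroring ``update()``)::
+
+            async with db.tx() as conn:
+                await AppSettingsRepository._apply_updates(conn, advert_interval=3600)
+                fresh = await AppSettingsRepository._get_in_conn(conn)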
+ """ + updates: list[str] = [] params: list[Any] = [] if max_radio_contacts is not None: @@ -187,47 +203,101 @@ class AppSettingsRepository: if updates: query = f"UPDATE app_settings SET {', '.join(updates)} WHERE id = 1" - await db.conn.execute(query, params) - await db.conn.commit() + async with conn.execute(query, params): + pass - return await AppSettingsRepository.get() + @staticmethod + async def get() -> AppSettings: + """Get the current app settings. + + Always returns settings - creates default row if needed (migration handles initial row). + """ + async with db.readonly() as conn: + return await AppSettingsRepository._get_in_conn(conn) + + @staticmethod + async def update( + max_radio_contacts: int | None = None, + auto_decrypt_dm_on_advert: bool | None = None, + last_message_times: dict[str, int] | None = None, + advert_interval: int | None = None, + last_advert_time: int | None = None, + flood_scope: str | None = None, + blocked_keys: list[str] | None = None, + blocked_names: list[str] | None = None, + discovery_blocked_types: list[int] | None = None, + tracked_telemetry_repeaters: list[str] | None = None, + auto_resend_channel: bool | None = None, + telemetry_interval_hours: int | None = None, + ) -> AppSettings: + """Update app settings. Only provided fields are updated.""" + async with db.tx() as conn: + await AppSettingsRepository._apply_updates( + conn, + max_radio_contacts=max_radio_contacts, + auto_decrypt_dm_on_advert=auto_decrypt_dm_on_advert, + last_message_times=last_message_times, + advert_interval=advert_interval, + last_advert_time=last_advert_time, + flood_scope=flood_scope, + blocked_keys=blocked_keys, + blocked_names=blocked_names, + discovery_blocked_types=discovery_blocked_types, + tracked_telemetry_repeaters=tracked_telemetry_repeaters, + auto_resend_channel=auto_resend_channel, + telemetry_interval_hours=telemetry_interval_hours, + ) + return await AppSettingsRepository._get_in_conn(conn) @staticmethod async def toggle_blocked_key(key: str) -> AppSettings: - """Toggle a public key in the blocked list. Keys are normalized to lowercase.""" + """Toggle a public key in the blocked list. Keys are normalized to lowercase. + + Read-modify-write is atomic under a single ``db.tx()`` lock — two + concurrent toggles for the same key cannot produce an inconsistent + intermediate state. + """ normalized = key.lower() - settings = await AppSettingsRepository.get() - if normalized in settings.blocked_keys: - new_keys = [k for k in settings.blocked_keys if k != normalized] - else: - new_keys = settings.blocked_keys + [normalized] - return await AppSettingsRepository.update(blocked_keys=new_keys) + async with db.tx() as conn: + settings = await AppSettingsRepository._get_in_conn(conn) + if normalized in settings.blocked_keys: + new_keys = [k for k in settings.blocked_keys if k != normalized] + else: + new_keys = settings.blocked_keys + [normalized] + await AppSettingsRepository._apply_updates(conn, blocked_keys=new_keys) + return await AppSettingsRepository._get_in_conn(conn) @staticmethod async def toggle_blocked_name(name: str) -> AppSettings: - """Toggle a display name in the blocked list.""" - settings = await AppSettingsRepository.get() - if name in settings.blocked_names: - new_names = [n for n in settings.blocked_names if n != name] - else: - new_names = settings.blocked_names + [name] - return await AppSettingsRepository.update(blocked_names=new_names) + """Toggle a display name in the blocked list. + + Same atomicity guarantee as ``toggle_blocked_key``. 
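+        Without it, two racing toggles could both read the same list and the
+        later write would silently undo the earlier toggle (a lost update).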
+ """ + async with db.tx() as conn: + settings = await AppSettingsRepository._get_in_conn(conn) + if name in settings.blocked_names: + new_names = [n for n in settings.blocked_names if n != name] + else: + new_names = settings.blocked_names + [name] + await AppSettingsRepository._apply_updates(conn, blocked_names=new_names) + return await AppSettingsRepository._get_in_conn(conn) class StatisticsRepository: @staticmethod async def get_database_message_totals() -> dict[str, int]: """Return message totals needed by lightweight debug surfaces.""" - cursor = await db.conn.execute( - """ - SELECT - SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms, - SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages, - SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing - FROM messages - """ - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT + SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms, + SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages, + SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing + FROM messages + """ + ) as cursor: + row = await cursor.fetchone() assert row is not None return { "total_dms": row["total_dms"] or 0, @@ -240,18 +310,19 @@ class StatisticsRepository: """Get time-windowed counts for contacts/repeaters heard.""" now = int(time.time()) op = "!=" if exclude else "=" - cursor = await db.conn.execute( - f""" - SELECT - SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour, - SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours, - SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week - FROM contacts - WHERE type {op} ? AND last_seen IS NOT NULL - """, - (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + f""" + SELECT + SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour, + SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours, + SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week + FROM contacts + WHERE type {op} ? AND last_seen IS NOT NULL + """, + (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type), + ) as cursor: + row = await cursor.fetchone() assert row is not None # Aggregate query always returns a row return { "last_hour": row["last_hour"] or 0, @@ -267,24 +338,25 @@ class StatisticsRepository: the old UPPER(...) join and aggregate per known channel directly. """ now = int(time.time()) - cursor = await db.conn.execute( - """ - WITH known AS ( - SELECT conversation_key, MAX(received_at) AS last_received_at - FROM messages - WHERE type = 'CHAN' - AND conversation_key IN (SELECT key FROM channels) - GROUP BY conversation_key - ) - SELECT - SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_hour, - SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_24_hours, - SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_week - FROM known - """, - (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D), - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + """ + WITH known AS ( + SELECT conversation_key, MAX(received_at) AS last_received_at + FROM messages + WHERE type = 'CHAN' + AND conversation_key IN (SELECT key FROM channels) + GROUP BY conversation_key + ) + SELECT + SUM(CASE WHEN last_received_at >= ? 
THEN 1 ELSE 0 END) AS last_hour, + SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_24_hours, + SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_week + FROM known + """, + (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D), + ) as cursor: + row = await cursor.fetchone() assert row is not None return { "last_hour": row["last_hour"] or 0, @@ -298,92 +370,105 @@ class StatisticsRepository: now = int(time.time()) cutoff = now - SECONDS_72H # Bucket timestamps to the start of each hour - cursor = await db.conn.execute( - """ - SELECT (timestamp / 3600) * 3600 AS hour_ts, COUNT(*) AS count - FROM raw_packets - WHERE timestamp >= ? - GROUP BY hour_ts - ORDER BY hour_ts - """, - (cutoff,), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + """ + SELECT (timestamp / 3600) * 3600 AS hour_ts, COUNT(*) AS count + FROM raw_packets + WHERE timestamp >= ? + GROUP BY hour_ts + ORDER BY hour_ts + """, + (cutoff,), + ) as cursor: + rows = await cursor.fetchall() return [{"timestamp": row["hour_ts"], "count": row["count"]} for row in rows] @staticmethod async def _path_hash_width_24h() -> dict[str, int | float]: """Count parsed raw packets from the last 24h by hop hash width.""" now = int(time.time()) - cursor = await db.conn.execute( - "SELECT data FROM raw_packets WHERE timestamp >= ?", - (now - SECONDS_24H,), - ) - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + "SELECT data FROM raw_packets WHERE timestamp >= ?", + (now - SECONDS_24H,), + ) as cursor: + rows = await cursor.fetchall() return bucket_path_hash_widths(rows) @staticmethod async def get_all() -> dict: - """Aggregate all statistics from existing tables.""" + """Aggregate all statistics from existing tables. + + Each helper acquires its own lock; there's no requirement that the + whole snapshot be atomic. If we ever wanted a consistent snapshot + we'd batch all queries into a single ``db.readonly()`` and use + ``_in_conn`` helpers, but statistics are intentionally approximate. + """ now = int(time.time()) - # Top 5 busiest channels in last 24h - cursor = await db.conn.execute( - """ - SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, - COUNT(*) AS message_count - FROM messages m - LEFT JOIN channels c ON m.conversation_key = c.key - WHERE m.type = 'CHAN' AND m.received_at >= ? - GROUP BY m.conversation_key - ORDER BY COUNT(*) DESC - LIMIT 5 - """, - (now - SECONDS_24H,), - ) - rows = await cursor.fetchall() - busiest_channels_24h = [ - { - "channel_key": row["conversation_key"], - "channel_name": row["channel_name"], - "message_count": row["message_count"], - } - for row in rows - ] + async with db.readonly() as conn: + # Top 5 busiest channels in last 24h + async with conn.execute( + """ + SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, + COUNT(*) AS message_count + FROM messages m + LEFT JOIN channels c ON m.conversation_key = c.key + WHERE m.type = 'CHAN' AND m.received_at >= ? 
+ GROUP BY m.conversation_key + ORDER BY COUNT(*) DESC + LIMIT 5 + """, + (now - SECONDS_24H,), + ) as cursor: + rows = await cursor.fetchall() + busiest_channels_24h = [ + { + "channel_key": row["conversation_key"], + "channel_name": row["channel_name"], + "message_count": row["message_count"], + } + for row in rows + ] - # Entity counts - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2") - row = await cursor.fetchone() - assert row is not None - contact_count: int = row["cnt"] + # Entity counts + async with conn.execute( + "SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2" + ) as cursor: + row = await cursor.fetchone() + assert row is not None + contact_count: int = row["cnt"] - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2") - row = await cursor.fetchone() - assert row is not None - repeater_count: int = row["cnt"] + async with conn.execute( + "SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2" + ) as cursor: + row = await cursor.fetchone() + assert row is not None + repeater_count: int = row["cnt"] - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM channels") - row = await cursor.fetchone() - assert row is not None - channel_count: int = row["cnt"] + async with conn.execute("SELECT COUNT(*) AS cnt FROM channels") as cursor: + row = await cursor.fetchone() + assert row is not None + channel_count: int = row["cnt"] - # Packet split - cursor = await db.conn.execute( - """ - SELECT COUNT(*) AS total, - SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted - FROM raw_packets - """ - ) - pkt_row = await cursor.fetchone() - assert pkt_row is not None - total_packets = pkt_row["total"] or 0 - decrypted_packets = pkt_row["decrypted"] or 0 - undecrypted_packets = total_packets - decrypted_packets + # Packet split + async with conn.execute( + """ + SELECT COUNT(*) AS total, + SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted + FROM raw_packets + """ + ) as cursor: + pkt_row = await cursor.fetchone() + assert pkt_row is not None + total_packets = pkt_row["total"] or 0 + decrypted_packets = pkt_row["decrypted"] or 0 + undecrypted_packets = total_packets - decrypted_packets + # These each acquire their own lock. The snapshot isn't atomic across + # them — fine for stats, which are approximate by nature. 
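+        # (asyncio.Lock wakes waiters in FIFO order, so this burst of short
+        # reads cannot indefinitely starve a writer waiting on the lock.)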
message_totals = await StatisticsRepository.get_database_message_totals() - - # Activity windows contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True) repeaters_heard = await StatisticsRepository._activity_counts(contact_type=2) known_channels_active = await StatisticsRepository._known_channels_active() diff --git a/tests/conftest.py b/tests/conftest.py index e27cb73..794b893 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,13 +28,28 @@ def cleanup_test_db_dir(): @pytest.fixture async def test_db(): """Create an in-memory test database with schema + migrations.""" - from app.repository import channels, contacts, messages, raw_packets, settings + from app.repository import ( + channels, + contacts, + messages, + raw_packets, + repeater_telemetry, + settings, + ) from app.repository import fanout as fanout_repo db = Database(":memory:") await db.connect() - submodules = [contacts, channels, messages, raw_packets, settings, fanout_repo] + submodules = [ + contacts, + channels, + messages, + raw_packets, + settings, + fanout_repo, + repeater_telemetry, + ] originals = [(mod, mod.db) for mod in submodules] for mod in submodules: diff --git a/tests/test_packets_router.py b/tests/test_packets_router.py index 7876017..4d14418 100644 --- a/tests/test_packets_router.py +++ b/tests/test_packets_router.py @@ -322,7 +322,7 @@ class TestUndecryptedTextPacketStreaming: [], ] - async def fake_execute(*_args, **_kwargs): + def fake_execute(*_args, **_kwargs): batch = batches.pop(0) class FakeCursor: @@ -332,6 +332,16 @@ class TestUndecryptedTextPacketStreaming: async def close(self): pass + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + # aiosqlite's execute() returns a `contextmanager`-decorated + # coroutine that is both awaitable and usable as an async-with. + # Our repo code now uses `async with conn.execute(...) as cursor:`, + # so the mock just needs to return something with __aenter__/__aexit__. return FakeCursor() with patch.object(test_db.conn, "execute", side_effect=fake_execute): diff --git a/tests/test_repository.py b/tests/test_repository.py index 733337b..bc2bbbe 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -1,11 +1,12 @@ """Tests for repository layer.""" -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import patch import pytest from app.models import Contact, ContactUpsert from app.repository import ( + AppSettingsRepository, ContactAdvertPathRepository, ContactNameHistoryRepository, ContactRepository, @@ -613,37 +614,103 @@ class TestAppSettingsRepository: """Test AppSettingsRepository parsing and migration edge cases.""" @pytest.mark.asyncio - async def test_get_handles_corrupted_json_and_invalid_sort_order(self): - """Corrupted JSON fields are recovered with safe defaults.""" - mock_conn = AsyncMock() - mock_cursor = AsyncMock() - mock_cursor.fetchone = AsyncMock( - return_value={ - "max_radio_contacts": 250, - "auto_decrypt_dm_on_advert": 1, - "last_message_times": "{also-not-json", - "advert_interval": None, - "last_advert_time": None, - "flood_scope": "", - "blocked_keys": "[]", - "blocked_names": "[]", - "discovery_blocked_types": "[]", - } + async def test_get_handles_corrupted_json_and_invalid_sort_order(self, test_db): + """Corrupted JSON fields are recovered with safe defaults. + + Uses the real DB so it exercises the lock-aware path. 
We stuff
+        malformed JSON directly into the row, then verify ``get()`` recovers
+        with defaults rather than propagating a parse error.
+        """
+        await test_db.conn.execute(
+            """
+            UPDATE app_settings
+            SET max_radio_contacts = 250,
+                auto_decrypt_dm_on_advert = 1,
+                last_message_times = '{also-not-json',
+                advert_interval = NULL,
+                last_advert_time = NULL,
+                flood_scope = '',
+                blocked_keys = '[]',
+                blocked_names = '[]',
+                discovery_blocked_types = '[]'
+            WHERE id = 1
+            """
         )
-        mock_conn.execute = AsyncMock(return_value=mock_cursor)
-        mock_db = MagicMock()
-        mock_db.conn = mock_conn
+        await test_db.conn.commit()
 
-        with patch("app.repository.settings.db", mock_db):
-            from app.repository import AppSettingsRepository
-
-            settings = await AppSettingsRepository.get()
+        settings = await AppSettingsRepository.get()
 
         assert settings.max_radio_contacts == 250
         assert settings.last_message_times == {}
         assert settings.advert_interval == 0
         assert settings.last_advert_time == 0
 
+    @pytest.mark.asyncio
+    async def test_get_in_conn_tolerates_missing_columns(self):
+        """Defend against partial migrations where columns added by later
+        migrations are absent from the row.
+
+        Real DBs can't produce this state (schema init + migrations always
+        run to the latest version on startup), but hand-rolled snapshots,
+        external DB tools, or interrupted migrations might. The
+        ``KeyError``-catching branches in ``_get_in_conn`` exist specifically
+        to guarantee graceful degradation.
+
+        We test these directly by mocking the connection boundary with a
+        dict-backed row that mimics a pre-migration snapshot missing:
+        - ``tracked_telemetry_repeaters`` (migration 53)
+        - ``auto_resend_channel`` (migration 54)
+        - ``telemetry_interval_hours`` (migration 57)
+        """
+        from unittest.mock import MagicMock
+
+        from app.telemetry_interval import DEFAULT_TELEMETRY_INTERVAL_HOURS
+
+        # A real ``sqlite3.Row`` raises IndexError, not KeyError, for an
+        # unknown column name. The branches under test guard ``KeyError``,
+        # so a dict-backed row is the right stand-in: ``dict.__getitem__``
+        # already raises ``KeyError`` for absent keys.
+ class PartialRow(dict): + def keys(self): # pragma: no cover - aiosqlite.Row compat + return super().keys() + + partial_row = PartialRow( + { + "max_radio_contacts": 123, + "auto_decrypt_dm_on_advert": 1, + "last_message_times": "{}", + "advert_interval": 0, + "last_advert_time": 0, + "flood_scope": "", + "blocked_keys": "[]", + "blocked_names": "[]", + "discovery_blocked_types": "[]", + # intentionally missing: tracked_telemetry_repeaters, + # auto_resend_channel, telemetry_interval_hours + } + ) + + class FakeCursor: + async def fetchone(self): + return partial_row + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + mock_conn = MagicMock() + mock_conn.execute = MagicMock(return_value=FakeCursor()) + + settings = await AppSettingsRepository._get_in_conn(mock_conn) + + assert settings.max_radio_contacts == 123 + # Missing-column defaults kick in: + assert settings.tracked_telemetry_repeaters == [] + assert settings.auto_resend_channel is False + assert settings.telemetry_interval_hours == DEFAULT_TELEMETRY_INTERVAL_HOURS + class TestMessageRepositoryGetById: """Test MessageRepository.get_by_id method.""" diff --git a/tests/test_statistics.py b/tests/test_statistics.py index 31910e3..6db352e 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -2,7 +2,7 @@ import time from types import SimpleNamespace -from unittest.mock import AsyncMock, patch +from unittest.mock import patch import pytest @@ -353,13 +353,21 @@ class TestPathHashWidthStats: @pytest.mark.asyncio async def test_path_hash_width_scan_fetches_all_then_buckets(self, test_db): - """Hash-width stats should fetchall() then bucket synchronously.""" + """Hash-width stats should fetchall() then bucket synchronously. - fake_rows = [{"data": b"a"}, {"data": b"b"}, {"data": b"c"}] + Uses real DB rows + a patched parser so it exercises the lock-aware + readonly path. Mocking ``conn.execute`` on the pre-refactor code no + longer reflects the actual call pattern (we use ``async with``). + """ - class FakeCursor: - async def fetchall(self): - return fake_rows + now = int(time.time()) + # Seed three raw packets in the last 24h with arbitrary distinguishing bytes. + for i, data in enumerate((b"a", b"b", b"c")): + await test_db.conn.execute( + "INSERT INTO raw_packets (timestamp, data) VALUES (?, ?)", + (now - (i + 1), data), + ) + await test_db.conn.commit() def fake_parse(raw_packet: bytes): hash_sizes = { @@ -372,10 +380,7 @@ class TestPathHashWidthStats: return None return SimpleNamespace(hash_size=hash_size) - with ( - patch.object(test_db.conn, "execute", new=AsyncMock(return_value=FakeCursor())), - patch("app.path_utils.parse_packet_envelope", side_effect=fake_parse), - ): + with patch("app.path_utils.parse_packet_envelope", side_effect=fake_parse): breakdown = await StatisticsRepository._path_hash_width_24h() assert breakdown["total_packets"] == 3
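+
+
+# Illustrative addition, not part of the original change: an end-to-end
+# regression sketch for the commit-vs-open-cursor bug this refactor fixes.
+# Assumes the shared `test_db` fixture and ChannelRepository as used above.
+@pytest.mark.asyncio
+async def test_concurrent_reads_and_writes_commit_cleanly(test_db):
+    """Interleaved readers and writers on the shared connection must not
+    raise 'cannot commit transaction - SQL statements in progress'."""
+    import asyncio
+
+    from app.repository import ChannelRepository
+
+    key = "AB" * 16  # 32-char hex channel key
+    await ChannelRepository.upsert(key, "stress")
+
+    async def write(i: int) -> None:
+        await ChannelRepository.upsert(key, f"stress-{i}")
+
+    async def read() -> None:
+        await ChannelRepository.get_all()
+
+    # Pre-refactor, overlapping fetches and commits on db.conn made this
+    # flaky; with tx()/readonly() serialization it must always pass.
+    await asyncio.gather(*(write(i) for i in range(20)), *(read() for _ in range(20)))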