Work on some more concurrency fixes re: locks and context managers. Poking at #179.

This commit is contained in:
Jack Kingsman
2026-04-16 18:04:56 -07:00
parent 4b69ec4519
commit 4783da8f3e
12 changed files with 1529 additions and 1152 deletions

View File

@@ -1,4 +1,7 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
import aiosqlite
@@ -165,9 +168,74 @@ CREATE INDEX IF NOT EXISTS idx_repeater_telemetry_pk_ts
class Database:
"""Single-connection aiosqlite wrapper with coroutine-level serialization.
Why the lock: aiosqlite runs one ``sqlite3.Connection`` on a background
worker thread and serializes statement execution there. But SQLite's
``COMMIT`` fails with ``OperationalError: cannot commit transaction -
SQL statements in progress`` whenever *any* cursor on the connection has
a live prepared statement (a ``SELECT`` that returned ``SQLITE_ROW`` but
hasn't been fully consumed or closed). Under concurrent coroutines, one
task's in-flight ``fetchone()`` can still be in ``SQLITE_ROW`` state when
another task's ``commit()`` runs on the worker — triggering the error.
Fix: all DB work goes through ``tx()`` (writes) or ``readonly()`` (reads),
both of which acquire ``self._lock``. The lock is non-reentrant (asyncio
default) by design — nested ``tx()`` calls are a bug. Repository methods
that compose multiple operations factor the raw SQL into private helpers
that take a ``conn`` and don't lock; the public method acquires the lock
once and calls those helpers.
Why reads are also locked: a read sitting in ``SQLITE_ROW`` state is
precisely the live statement that breaks a concurrent writer's commit.
Single-connection aiosqlite cannot safely
overlap reads and writes. If we ever split reader/writer connections in
the future, ``readonly()`` becomes the seam to point at the reader pool.
"""
def __init__(self, db_path: str):
self.db_path = db_path
self._connection: aiosqlite.Connection | None = None
self._lock = asyncio.Lock()
@asynccontextmanager
async def tx(self) -> AsyncIterator[aiosqlite.Connection]:
"""Acquire the connection for a write transaction.
Commits on clean exit, rolls back on exception. Callers MUST close
every cursor opened inside the block (use ``async with conn.execute(...)
as cursor:``) so no prepared statement is alive when commit runs.
The lock serializes concurrent writers AND ensures no reader's cursor
is alive during the commit. Nested calls will deadlock — factor shared
SQL into helpers that accept ``conn`` and do not re-enter ``tx()``.
"""
async with self._lock:
if self._connection is None:
raise RuntimeError("Database not connected")
conn = self._connection
try:
yield conn
except BaseException:
await conn.rollback()
raise
else:
await conn.commit()
@asynccontextmanager
async def readonly(self) -> AsyncIterator[aiosqlite.Connection]:
"""Acquire the connection for a read. No commit, no rollback.
Locked for the same reason writes are: on a single connection, an
active read statement blocks a concurrent writer's commit. Callers
MUST fully consume or close cursors before the block exits (use
``async with conn.execute(...) as cursor:`` + ``fetchall`` /
``fetchone``; avoid holding a cursor across ``await`` on other IO).
"""
async with self._lock:
if self._connection is None:
raise RuntimeError("Database not connected")
yield self._connection
async def connect(self) -> None:
logger.info("Connecting to database at %s", self.db_path)

View File

@@ -8,31 +8,33 @@ class ChannelRepository:
@staticmethod
async def upsert(key: str, name: str, is_hashtag: bool = False, on_radio: bool = False) -> None:
"""Upsert a channel. Key is 32-char hex string."""
await db.conn.execute(
"""
INSERT INTO channels (key, name, is_hashtag, on_radio, flood_scope_override)
VALUES (?, ?, ?, ?, NULL)
ON CONFLICT(key) DO UPDATE SET
name = excluded.name,
is_hashtag = excluded.is_hashtag,
on_radio = excluded.on_radio
""",
(key.upper(), name, is_hashtag, on_radio),
)
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute(
"""
INSERT INTO channels (key, name, is_hashtag, on_radio, flood_scope_override)
VALUES (?, ?, ?, ?, NULL)
ON CONFLICT(key) DO UPDATE SET
name = excluded.name,
is_hashtag = excluded.is_hashtag,
on_radio = excluded.on_radio
""",
(key.upper(), name, is_hashtag, on_radio),
):
pass
@staticmethod
async def get_by_key(key: str) -> Channel | None:
"""Get a channel by its key (32-char hex string)."""
cursor = await db.conn.execute(
"""
SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite
FROM channels
WHERE key = ?
""",
(key.upper(),),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite
FROM channels
WHERE key = ?
""",
(key.upper(),),
) as cursor:
row = await cursor.fetchone()
if row:
return Channel(
key=row["key"],
@@ -48,14 +50,15 @@ class ChannelRepository:
@staticmethod
async def get_all() -> list[Channel]:
cursor = await db.conn.execute(
"""
SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite
FROM channels
ORDER BY name
"""
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite
FROM channels
ORDER BY name
"""
) as cursor:
rows = await cursor.fetchall()
return [
Channel(
key=row["key"],
@@ -73,21 +76,23 @@ class ChannelRepository:
@staticmethod
async def set_favorite(key: str, value: bool) -> bool:
"""Set or clear the favorite flag for a channel. Returns True if row was found."""
cursor = await db.conn.execute(
"UPDATE channels SET favorite = ? WHERE key = ?",
(1 if value else 0, key.upper()),
)
await db.conn.commit()
return cursor.rowcount > 0
async with db.tx() as conn:
async with conn.execute(
"UPDATE channels SET favorite = ? WHERE key = ?",
(1 if value else 0, key.upper()),
) as cursor:
rowcount = cursor.rowcount
return rowcount > 0
@staticmethod
async def delete(key: str) -> None:
"""Delete a channel by key."""
await db.conn.execute(
"DELETE FROM channels WHERE key = ?",
(key.upper(),),
)
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute(
"DELETE FROM channels WHERE key = ?",
(key.upper(),),
):
pass
@staticmethod
async def update_last_read_at(key: str, timestamp: int | None = None) -> bool:
@@ -96,35 +101,39 @@ class ChannelRepository:
Returns True if a row was updated, False if channel not found.
"""
ts = timestamp if timestamp is not None else int(time.time())
cursor = await db.conn.execute(
"UPDATE channels SET last_read_at = ? WHERE key = ?",
(ts, key.upper()),
)
await db.conn.commit()
return cursor.rowcount > 0
async with db.tx() as conn:
async with conn.execute(
"UPDATE channels SET last_read_at = ? WHERE key = ?",
(ts, key.upper()),
) as cursor:
rowcount = cursor.rowcount
return rowcount > 0
@staticmethod
async def update_flood_scope_override(key: str, flood_scope_override: str | None) -> bool:
"""Set or clear a channel's flood-scope override."""
cursor = await db.conn.execute(
"UPDATE channels SET flood_scope_override = ? WHERE key = ?",
(flood_scope_override, key.upper()),
)
await db.conn.commit()
return cursor.rowcount > 0
async with db.tx() as conn:
async with conn.execute(
"UPDATE channels SET flood_scope_override = ? WHERE key = ?",
(flood_scope_override, key.upper()),
) as cursor:
rowcount = cursor.rowcount
return rowcount > 0
@staticmethod
async def update_path_hash_mode_override(key: str, path_hash_mode_override: int | None) -> bool:
"""Set or clear a channel's path hash mode override."""
cursor = await db.conn.execute(
"UPDATE channels SET path_hash_mode_override = ? WHERE key = ?",
(path_hash_mode_override, key.upper()),
)
await db.conn.commit()
return cursor.rowcount > 0
async with db.tx() as conn:
async with conn.execute(
"UPDATE channels SET path_hash_mode_override = ? WHERE key = ?",
(path_hash_mode_override, key.upper()),
) as cursor:
rowcount = cursor.rowcount
return rowcount > 0
@staticmethod
async def mark_all_read(timestamp: int) -> None:
"""Mark all channels as read at the given timestamp."""
await db.conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,))
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,)):
pass
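A note on the bare ``pass`` bodies above: using ``conn.execute(...)`` as an async context manager closes the cursor on exit, which finalizes the prepared statement before ``tx()`` commits. If the idiom grates, a small helper could hide it; this is a hypothetical sketch, not something this commit adds:

from collections.abc import Sequence
from typing import Any

import aiosqlite

async def exec_closed(
    conn: aiosqlite.Connection, sql: str, params: Sequence[Any] = ()
) -> int:
    """Execute a statement, close its cursor immediately, return rowcount."""
    async with conn.execute(sql, params) as cursor:
        return cursor.rowcount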

File diff suppressed because it is too large

View File

@@ -6,6 +6,8 @@ import time
import uuid
from typing import Any
import aiosqlite
from app.database import db
logger = logging.getLogger(__name__)
@@ -31,26 +33,37 @@ def _row_to_dict(row: Any) -> dict[str, Any]:
return result
async def _get_in_conn(conn: aiosqlite.Connection, config_id: str) -> dict[str, Any] | None:
"""Fetch a config using an already-acquired connection.
Used by ``create`` and ``update`` to return the freshly-written row
without re-entering the non-reentrant DB lock.
"""
async with conn.execute("SELECT * FROM fanout_configs WHERE id = ?", (config_id,)) as cursor:
row = await cursor.fetchone()
if row is None:
return None
return _row_to_dict(row)
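The helper exists because ``db.tx()`` guards a plain ``asyncio.Lock``: a coroutine that holds the lock and then calls another locked entry point hangs. A minimal reproduction of the hazard, purely illustrative:

import asyncio

lock = asyncio.Lock()

async def inner() -> None:
    async with lock:  # waits on a lock the caller already holds
        pass

async def outer() -> None:
    async with lock:
        await inner()  # deadlock: never completes

# asyncio.run(outer())  # hangs forever; hence helpers that take `conn`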
class FanoutConfigRepository:
"""CRUD operations for fanout_configs table."""
@staticmethod
async def get_all() -> list[dict[str, Any]]:
"""Get all fanout configs ordered by sort_order."""
cursor = await db.conn.execute(
"SELECT * FROM fanout_configs ORDER BY sort_order, created_at"
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
"SELECT * FROM fanout_configs ORDER BY sort_order, created_at"
) as cursor:
rows = await cursor.fetchall()
return [_row_to_dict(row) for row in rows]
@staticmethod
async def get(config_id: str) -> dict[str, Any] | None:
"""Get a single fanout config by ID."""
cursor = await db.conn.execute("SELECT * FROM fanout_configs WHERE id = ?", (config_id,))
row = await cursor.fetchone()
if row is None:
return None
return _row_to_dict(row)
async with db.readonly() as conn:
return await _get_in_conn(conn, config_id)
@staticmethod
async def create(
@@ -65,39 +78,41 @@ class FanoutConfigRepository:
new_id = config_id or str(uuid.uuid4())
now = int(time.time())
# Get next sort_order
cursor = await db.conn.execute(
"SELECT COALESCE(MAX(sort_order), -1) + 1 FROM fanout_configs"
)
row = await cursor.fetchone()
sort_order = row[0] if row else 0
async with db.tx() as conn:
# Determine next sort_order under the same lock as the insert,
# so two concurrent ``create()`` calls cannot collide.
async with conn.execute(
"SELECT COALESCE(MAX(sort_order), -1) + 1 FROM fanout_configs"
) as cursor:
row = await cursor.fetchone()
sort_order = row[0] if row else 0
await db.conn.execute(
"""
INSERT INTO fanout_configs (id, type, name, enabled, config, scope, sort_order, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
new_id,
config_type,
name,
1 if enabled else 0,
json.dumps(config),
json.dumps(scope),
sort_order,
now,
),
)
await db.conn.commit()
async with conn.execute(
"""
INSERT INTO fanout_configs (id, type, name, enabled, config, scope, sort_order, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
new_id,
config_type,
name,
1 if enabled else 0,
json.dumps(config),
json.dumps(scope),
sort_order,
now,
),
):
pass
result = await FanoutConfigRepository.get(new_id)
result = await _get_in_conn(conn, new_id)
assert result is not None
return result
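The lock matters for the MAX-then-INSERT pair: two unlocked ``create()`` calls can both read the same MAX and collide on ``sort_order``. A toy reproduction with a list standing in for the table (illustrative only):

import asyncio

rows: list[int] = []  # stands in for fanout_configs.sort_order values

async def create_unlocked() -> None:
    next_order = (max(rows) if rows else -1) + 1  # SELECT COALESCE(MAX(...), -1) + 1
    await asyncio.sleep(0)  # any await lets another coroutine interleave here
    rows.append(next_order)  # INSERT

async def main() -> None:
    await asyncio.gather(create_unlocked(), create_unlocked())
    print(rows)  # [0, 0]: both calls computed the same sort_order

asyncio.run(main())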
@staticmethod
async def update(config_id: str, **fields: Any) -> dict[str, Any] | None:
"""Update a fanout config. Only provided fields are updated."""
updates = []
updates: list[str] = []
params: list[Any] = []
for field in ("name", "enabled", "config", "scope", "sort_order"):
@@ -115,23 +130,25 @@ class FanoutConfigRepository:
params.append(config_id)
query = f"UPDATE fanout_configs SET {', '.join(updates)} WHERE id = ?"
await db.conn.execute(query, params)
await db.conn.commit()
return await FanoutConfigRepository.get(config_id)
async with db.tx() as conn:
async with conn.execute(query, params):
pass
return await _get_in_conn(conn, config_id)
@staticmethod
async def delete(config_id: str) -> None:
"""Delete a fanout config."""
await db.conn.execute("DELETE FROM fanout_configs WHERE id = ?", (config_id,))
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute("DELETE FROM fanout_configs WHERE id = ?", (config_id,)):
pass
_configs_cache.pop(config_id, None)
@staticmethod
async def get_enabled() -> list[dict[str, Any]]:
"""Get all enabled fanout configs."""
cursor = await db.conn.execute(
"SELECT * FROM fanout_configs WHERE enabled = 1 ORDER BY sort_order, created_at"
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
"SELECT * FROM fanout_configs WHERE enabled = 1 ORDER BY sort_order, created_at"
) as cursor:
rows = await cursor.fetchall()
return [_row_to_dict(row) for row in rows]

View File

@@ -89,32 +89,34 @@ class MessageRepository:
# Normalize sender_key to lowercase so queries can match without LOWER().
normalized_sender_key = sender_key.lower() if sender_key else sender_key
cursor = await db.conn.execute(
"""
INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp,
received_at, paths, txt_type, signature, outgoing,
sender_name, sender_key)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
msg_type,
conversation_key,
text,
sender_timestamp,
received_at,
paths_json,
txt_type,
signature,
outgoing,
sender_name,
normalized_sender_key,
),
)
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute(
"""
INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp,
received_at, paths, txt_type, signature, outgoing,
sender_name, sender_key)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
msg_type,
conversation_key,
text,
sender_timestamp,
received_at,
paths_json,
txt_type,
signature,
outgoing,
sender_name,
normalized_sender_key,
),
) as cursor:
rowcount = cursor.rowcount
lastrowid = cursor.lastrowid
# rowcount is 0 if INSERT was ignored due to UNIQUE constraint violation
if cursor.rowcount == 0:
if rowcount == 0:
return None
return cursor.lastrowid
return lastrowid
@staticmethod
async def add_path(
@@ -142,17 +144,20 @@ class MessageRepository:
if snr is not None:
entry["snr"] = snr
new_entry = json.dumps(entry)
await db.conn.execute(
"""UPDATE messages SET paths = json_insert(
COALESCE(paths, '[]'), '$[#]', json(?)
) WHERE id = ?""",
(new_entry, message_id),
)
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute(
"""UPDATE messages SET paths = json_insert(
COALESCE(paths, '[]'), '$[#]', json(?)
) WHERE id = ?""",
(new_entry, message_id),
):
pass
# Read back the full list for the return value
cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
row = await cursor.fetchone()
# Read back the full list for the return value, same transaction.
async with conn.execute(
"SELECT paths FROM messages WHERE id = ?", (message_id,)
) as cursor:
row = await cursor.fetchone()
if not row or not row["paths"]:
return []
@@ -171,23 +176,24 @@ class MessageRepository:
only a prefix as conversation_key are updated to use the full key.
"""
lower_key = full_key.lower()
cursor = await db.conn.execute(
"""UPDATE messages SET conversation_key = ?,
sender_key = CASE
WHEN sender_key IS NOT NULL AND length(sender_key) < 64
AND ? LIKE sender_key || '%'
THEN ? ELSE sender_key END
WHERE type = 'PRIV' AND length(conversation_key) < 64
AND ? LIKE conversation_key || '%'
AND (
SELECT COUNT(*) FROM contacts
WHERE length(public_key) = 64
AND public_key LIKE messages.conversation_key || '%'
) = 1""",
(lower_key, lower_key, lower_key, lower_key),
)
await db.conn.commit()
return cursor.rowcount
async with db.tx() as conn:
async with conn.execute(
"""UPDATE messages SET conversation_key = ?,
sender_key = CASE
WHEN sender_key IS NOT NULL AND length(sender_key) < 64
AND ? LIKE sender_key || '%'
THEN ? ELSE sender_key END
WHERE type = 'PRIV' AND length(conversation_key) < 64
AND ? LIKE conversation_key || '%'
AND (
SELECT COUNT(*) FROM contacts
WHERE length(public_key) = 64
AND public_key LIKE messages.conversation_key || '%'
) = 1""",
(lower_key, lower_key, lower_key, lower_key),
) as cursor:
rowcount = cursor.rowcount
return rowcount
@staticmethod
async def backfill_channel_sender_key(public_key: str, name: str) -> int:
@@ -197,21 +203,22 @@ class MessageRepository:
any channel messages with a matching sender_name but no sender_key
are updated to associate them with this contact's public key.
"""
cursor = await db.conn.execute(
"""UPDATE messages SET sender_key = ?
WHERE type = 'CHAN' AND sender_name = ? AND sender_key IS NULL
AND (
SELECT COUNT(*) FROM contacts
WHERE name = ?
) = 1
AND EXISTS (
SELECT 1 FROM contacts
WHERE public_key = ? AND name = ?
)""",
(public_key.lower(), name, name, public_key.lower(), name),
)
await db.conn.commit()
return cursor.rowcount
async with db.tx() as conn:
async with conn.execute(
"""UPDATE messages SET sender_key = ?
WHERE type = 'CHAN' AND sender_name = ? AND sender_key IS NULL
AND (
SELECT COUNT(*) FROM contacts
WHERE name = ?
) = 1
AND EXISTS (
SELECT 1 FROM contacts
WHERE public_key = ? AND name = ?
)""",
(public_key.lower(), name, name, public_key.lower(), name),
) as cursor:
rowcount = cursor.rowcount
return rowcount
@staticmethod
def _normalize_conversation_key(conversation_key: str) -> tuple[str, str]:
@@ -462,8 +469,9 @@ class MessageRepository:
query += " OFFSET ?"
params.append(offset)
cursor = await db.conn.execute(query, params)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(query, params) as cursor:
rows = await cursor.fetchall()
return [MessageRepository._row_to_message(row) for row in rows]
@staticmethod
@@ -501,51 +509,54 @@ class MessageRepository:
where_sql = " AND ".join(["1=1", *where_parts])
# 1. Get the target message (must satisfy filters if provided)
target_cursor = await db.conn.execute(
f"SELECT {MessageRepository._message_select('messages')} "
f"FROM messages WHERE id = ? AND {where_sql}",
(message_id, *base_params),
)
target_row = await target_cursor.fetchone()
if not target_row:
return [], False, False
async with db.readonly() as conn:
async with conn.execute(
f"SELECT {MessageRepository._message_select('messages')} "
f"FROM messages WHERE id = ? AND {where_sql}",
(message_id, *base_params),
) as target_cursor:
target_row = await target_cursor.fetchone()
if not target_row:
return [], False, False
target = MessageRepository._row_to_message(target_row)
target = MessageRepository._row_to_message(target_row)
# 2. Get context_size+1 messages before target (DESC)
before_query = f"""
SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql}
AND (received_at < ? OR (received_at = ? AND id < ?))
ORDER BY received_at DESC, id DESC LIMIT ?
"""
before_params = [
*base_params,
target.received_at,
target.received_at,
target.id,
context_size + 1,
]
before_cursor = await db.conn.execute(before_query, before_params)
before_rows = list(await before_cursor.fetchall())
# 2. Get context_size+1 messages before target (DESC)
before_query = f"""
SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql}
AND (received_at < ? OR (received_at = ? AND id < ?))
ORDER BY received_at DESC, id DESC LIMIT ?
"""
before_params = [
*base_params,
target.received_at,
target.received_at,
target.id,
context_size + 1,
]
async with conn.execute(before_query, before_params) as before_cursor:
before_rows = list(await before_cursor.fetchall())
has_older = len(before_rows) > context_size
before_messages = [MessageRepository._row_to_message(r) for r in before_rows[:context_size]]
has_older = len(before_rows) > context_size
before_messages = [
MessageRepository._row_to_message(r) for r in before_rows[:context_size]
]
# 3. Get context_size+1 messages after target (ASC)
after_query = f"""
SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql}
AND (received_at > ? OR (received_at = ? AND id > ?))
ORDER BY received_at ASC, id ASC LIMIT ?
"""
after_params = [
*base_params,
target.received_at,
target.received_at,
target.id,
context_size + 1,
]
after_cursor = await db.conn.execute(after_query, after_params)
after_rows = list(await after_cursor.fetchall())
# 3. Get context_size+1 messages after target (ASC)
after_query = f"""
SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql}
AND (received_at > ? OR (received_at = ? AND id > ?))
ORDER BY received_at ASC, id ASC LIMIT ?
"""
after_params = [
*base_params,
target.received_at,
target.received_at,
target.id,
context_size + 1,
]
async with conn.execute(after_query, after_params) as after_cursor:
after_rows = list(await after_cursor.fetchall())
has_newer = len(after_rows) > context_size
after_messages = [MessageRepository._row_to_message(r) for r in after_rows[:context_size]]
@@ -556,21 +567,29 @@ class MessageRepository:
@staticmethod
async def increment_ack_count(message_id: int) -> int:
"""Increment ack count and return the new value."""
cursor = await db.conn.execute(
"UPDATE messages SET acked = acked + 1 WHERE id = ? RETURNING acked", (message_id,)
)
row = await cursor.fetchone()
await db.conn.commit()
"""Increment ack count and return the new value.
NOTE: ``RETURNING`` leaves the prepared statement active until the
row is fetched, so we MUST consume it inside the ``async with``
block. Without that, the commit at the end of ``db.tx()`` fails
with ``cannot commit transaction - SQL statements in progress``.
"""
async with db.tx() as conn:
async with conn.execute(
"UPDATE messages SET acked = acked + 1 WHERE id = ? RETURNING acked",
(message_id,),
) as cursor:
row = await cursor.fetchone()
return row["acked"] if row else 1
@staticmethod
async def get_ack_and_paths(message_id: int) -> tuple[int, list[MessagePath] | None]:
"""Get the current ack count and paths for a message."""
cursor = await db.conn.execute(
"SELECT acked, paths FROM messages WHERE id = ?", (message_id,)
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT acked, paths FROM messages WHERE id = ?", (message_id,)
) as cursor:
row = await cursor.fetchone()
if not row:
return 0, None
return row["acked"], MessageRepository._parse_paths(row["paths"])
@@ -578,11 +597,12 @@ class MessageRepository:
@staticmethod
async def get_by_id(message_id: int) -> "Message | None":
"""Look up a message by its ID."""
cursor = await db.conn.execute(
f"SELECT {MessageRepository._message_select('messages')} FROM messages WHERE id = ?",
(message_id,),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
f"SELECT {MessageRepository._message_select('messages')} FROM messages WHERE id = ?",
(message_id,),
) as cursor:
row = await cursor.fetchone()
if not row:
return None
@@ -591,11 +611,14 @@ class MessageRepository:
@staticmethod
async def delete_by_id(message_id: int) -> None:
"""Delete a message row by ID."""
await db.conn.execute(
"UPDATE raw_packets SET message_id = NULL WHERE message_id = ?", (message_id,)
)
await db.conn.execute("DELETE FROM messages WHERE id = ?", (message_id,))
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute(
"UPDATE raw_packets SET message_id = NULL WHERE message_id = ?",
(message_id,),
):
pass
async with conn.execute("DELETE FROM messages WHERE id = ?", (message_id,)):
pass
@staticmethod
async def get_by_content(
@@ -618,8 +641,9 @@ class MessageRepository:
query += " AND outgoing = ?"
params.append(1 if outgoing else 0)
query += " ORDER BY id ASC"
cursor = await db.conn.execute(query, params)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(query, params) as cursor:
row = await cursor.fetchone()
if not row:
return None
@@ -653,76 +677,6 @@ class MessageRepository:
)
blocked_sql = f" AND {blocked_clause}" if blocked_clause else ""
# Channel unreads
cursor = await db.conn.execute(
f"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
SUM(CASE
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
ELSE 0
END) > 0 as has_mention
FROM messages m
JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.outgoing = 0
AND m.received_at > COALESCE(c.last_read_at, 0)
{blocked_sql}
GROUP BY m.conversation_key
""",
(mention_token or "", mention_token or "", *blocked_params),
)
rows = await cursor.fetchall()
for row in rows:
state_key = f"channel-{row['conversation_key']}"
counts[state_key] = row["unread_count"]
if mention_token and row["has_mention"]:
mention_flags[state_key] = True
# Contact unreads
cursor = await db.conn.execute(
f"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
SUM(CASE
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
ELSE 0
END) > 0 as has_mention
FROM messages m
LEFT JOIN contacts ct ON m.conversation_key = ct.public_key
WHERE m.type = 'PRIV' AND m.outgoing = 0
AND m.received_at > COALESCE(ct.last_read_at, 0)
{blocked_sql}
GROUP BY m.conversation_key
""",
(mention_token or "", mention_token or "", *blocked_params),
)
rows = await cursor.fetchall()
for row in rows:
state_key = f"contact-{row['conversation_key']}"
counts[state_key] = row["unread_count"]
if mention_token and row["has_mention"]:
mention_flags[state_key] = True
cursor = await db.conn.execute(
"""
SELECT key, last_read_at
FROM channels
"""
)
rows = await cursor.fetchall()
for row in rows:
last_read_ats[f"channel-{row['key']}"] = row["last_read_at"]
cursor = await db.conn.execute(
"""
SELECT public_key, last_read_at
FROM contacts
"""
)
rows = await cursor.fetchall()
for row in rows:
last_read_ats[f"contact-{row['public_key']}"] = row["last_read_at"]
# Last message times for all conversations (including read ones),
# excluding blocked incoming traffic so refresh matches live WS behavior.
last_time_clause, last_time_params = MessageRepository._build_blocked_incoming_clause(
@@ -730,20 +684,94 @@ class MessageRepository:
)
last_time_where_sql = f"WHERE {last_time_clause}" if last_time_clause else ""
cursor = await db.conn.execute(
f"""
SELECT type, conversation_key, MAX(received_at) as last_message_time
FROM messages
{last_time_where_sql}
GROUP BY type, conversation_key
""",
last_time_params,
)
rows = await cursor.fetchall()
for row in rows:
prefix = "channel" if row["type"] == "CHAN" else "contact"
state_key = f"{prefix}-{row['conversation_key']}"
last_message_times[state_key] = row["last_message_time"]
# Single readonly acquisition for all 5 queries — they form one logical
# snapshot, and holding the lock for the batch is cheaper than acquiring
# it 5 times.
async with db.readonly() as conn:
# Channel unreads
async with conn.execute(
f"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
SUM(CASE
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
ELSE 0
END) > 0 as has_mention
FROM messages m
JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.outgoing = 0
AND m.received_at > COALESCE(c.last_read_at, 0)
{blocked_sql}
GROUP BY m.conversation_key
""",
(mention_token or "", mention_token or "", *blocked_params),
) as cursor:
rows = await cursor.fetchall()
for row in rows:
state_key = f"channel-{row['conversation_key']}"
counts[state_key] = row["unread_count"]
if mention_token and row["has_mention"]:
mention_flags[state_key] = True
# Contact unreads
async with conn.execute(
f"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
SUM(CASE
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
ELSE 0
END) > 0 as has_mention
FROM messages m
LEFT JOIN contacts ct ON m.conversation_key = ct.public_key
WHERE m.type = 'PRIV' AND m.outgoing = 0
AND m.received_at > COALESCE(ct.last_read_at, 0)
{blocked_sql}
GROUP BY m.conversation_key
""",
(mention_token or "", mention_token or "", *blocked_params),
) as cursor:
rows = await cursor.fetchall()
for row in rows:
state_key = f"contact-{row['conversation_key']}"
counts[state_key] = row["unread_count"]
if mention_token and row["has_mention"]:
mention_flags[state_key] = True
async with conn.execute(
"""
SELECT key, last_read_at
FROM channels
"""
) as cursor:
rows = await cursor.fetchall()
for row in rows:
last_read_ats[f"channel-{row['key']}"] = row["last_read_at"]
async with conn.execute(
"""
SELECT public_key, last_read_at
FROM contacts
"""
) as cursor:
rows = await cursor.fetchall()
for row in rows:
last_read_ats[f"contact-{row['public_key']}"] = row["last_read_at"]
async with conn.execute(
f"""
SELECT type, conversation_key, MAX(received_at) as last_message_time
FROM messages
{last_time_where_sql}
GROUP BY type, conversation_key
""",
last_time_params,
) as cursor:
rows = await cursor.fetchall()
for row in rows:
prefix = "channel" if row["type"] == "CHAN" else "contact"
state_key = f"{prefix}-{row['conversation_key']}"
last_message_times[state_key] = row["last_message_time"]
# Only include last_read_ats for conversations that actually have messages.
# Without this filter, every contact heard via advertisement (even without
@@ -760,41 +788,45 @@ class MessageRepository:
@staticmethod
async def count_dm_messages(contact_key: str) -> int:
"""Count total DM messages for a contact."""
cursor = await db.conn.execute(
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?",
(contact_key.lower(),),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?",
(contact_key.lower(),),
) as cursor:
row = await cursor.fetchone()
return row["cnt"] if row else 0
@staticmethod
async def count_channel_messages_by_sender(sender_key: str) -> int:
"""Count channel messages sent by a specific contact."""
cursor = await db.conn.execute(
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?",
(sender_key.lower(),),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?",
(sender_key.lower(),),
) as cursor:
row = await cursor.fetchone()
return row["cnt"] if row else 0
@staticmethod
async def count_channel_messages_by_sender_name(sender_name: str) -> int:
"""Count channel messages attributed to a display name."""
cursor = await db.conn.execute(
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_name = ?",
(sender_name,),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_name = ?",
(sender_name,),
) as cursor:
row = await cursor.fetchone()
return row["cnt"] if row else 0
@staticmethod
async def get_first_channel_message_by_sender_name(sender_name: str) -> int | None:
"""Get the earliest stored channel message timestamp for a display name."""
cursor = await db.conn.execute(
"SELECT MIN(received_at) AS first_seen FROM messages WHERE type = 'CHAN' AND sender_name = ?",
(sender_name,),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT MIN(received_at) AS first_seen FROM messages WHERE type = 'CHAN' AND sender_name = ?",
(sender_name,),
) as cursor:
row = await cursor.fetchone()
return row["first_seen"] if row and row["first_seen"] is not None else None
@staticmethod
@@ -813,68 +845,76 @@ class MessageRepository:
t_48h = now - 172800
t_7d = now - 604800
cursor = await db.conn.execute(
"""
SELECT COUNT(*) AS all_time,
SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_1h,
SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_24h,
SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_48h,
SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_7d,
MIN(received_at) AS first_message_at,
COUNT(DISTINCT sender_key) AS unique_sender_count
FROM messages WHERE type = 'CHAN' AND conversation_key = ?
""",
(t_1h, t_24h, t_48h, t_7d, conversation_key),
)
row = await cursor.fetchone()
assert row is not None # Aggregate query always returns a row
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT COUNT(*) AS all_time,
SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_1h,
SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_24h,
SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_48h,
SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_7d,
MIN(received_at) AS first_message_at,
COUNT(DISTINCT sender_key) AS unique_sender_count
FROM messages WHERE type = 'CHAN' AND conversation_key = ?
""",
(t_1h, t_24h, t_48h, t_7d, conversation_key),
) as cursor:
row = await cursor.fetchone()
assert row is not None # Aggregate query always returns a row
message_counts = {
"last_1h": row["last_1h"] or 0,
"last_24h": row["last_24h"] or 0,
"last_48h": row["last_48h"] or 0,
"last_7d": row["last_7d"] or 0,
"all_time": row["all_time"] or 0,
}
cursor2 = await db.conn.execute(
"""
SELECT COALESCE(sender_name, sender_key, 'Unknown') AS display_name,
sender_key, COUNT(*) AS cnt
FROM messages
WHERE type = 'CHAN' AND conversation_key = ?
AND received_at >= ? AND sender_key IS NOT NULL
GROUP BY sender_key ORDER BY cnt DESC LIMIT 5
""",
(conversation_key, t_24h),
)
top_rows = await cursor2.fetchall()
top_senders = [
{
"sender_name": r["display_name"],
"sender_key": r["sender_key"],
"message_count": r["cnt"],
message_counts = {
"last_1h": row["last_1h"] or 0,
"last_24h": row["last_24h"] or 0,
"last_48h": row["last_48h"] or 0,
"last_7d": row["last_7d"] or 0,
"all_time": row["all_time"] or 0,
}
for r in top_rows
]
# Path hash width distribution for last 24h (in-Python parse of raw packet envelopes)
cursor3 = await db.conn.execute(
"""
SELECT rp.data FROM raw_packets rp
JOIN messages m ON rp.message_id = m.id
WHERE m.type = 'CHAN' AND m.conversation_key = ?
AND rp.timestamp >= ?
""",
(conversation_key, t_24h),
)
rows3 = await cursor3.fetchall()
async with conn.execute(
"""
SELECT COALESCE(sender_name, sender_key, 'Unknown') AS display_name,
sender_key, COUNT(*) AS cnt
FROM messages
WHERE type = 'CHAN' AND conversation_key = ?
AND received_at >= ? AND sender_key IS NOT NULL
GROUP BY sender_key ORDER BY cnt DESC LIMIT 5
""",
(conversation_key, t_24h),
) as cursor:
top_rows = await cursor.fetchall()
top_senders = [
{
"sender_name": r["display_name"],
"sender_key": r["sender_key"],
"message_count": r["cnt"],
}
for r in top_rows
]
# Path hash width distribution for last 24h: fetch raw rows under
# the lock, then release BEFORE the CPU-bound in-Python envelope
# parse. Parsing can iterate thousands of rows and previously held
# the DB lock for the whole traversal — blocking every other repo
# caller on a Pi. Keep the lock only for the fetch.
async with conn.execute(
"""
SELECT rp.data FROM raw_packets rp
JOIN messages m ON rp.message_id = m.id
WHERE m.type = 'CHAN' AND m.conversation_key = ?
AND rp.timestamp >= ?
""",
(conversation_key, t_24h),
) as cursor:
rows3 = await cursor.fetchall()
first_message_at = row["first_message_at"]
unique_sender_count = row["unique_sender_count"] or 0
path_hash_width_24h = bucket_path_hash_widths(rows3)
return {
"message_counts": message_counts,
"first_message_at": row["first_message_at"],
"unique_sender_count": row["unique_sender_count"] or 0,
"first_message_at": first_message_at,
"unique_sender_count": unique_sender_count,
"top_senders_24h": top_senders,
"path_hash_width_24h": path_hash_width_24h,
}
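The shape above generalizes: hold the lock only for the fetch, then run CPU-bound work on the plain rows after release. A schematic sketch; ``parse`` and ``widths_since`` are illustrative stand-ins, not names from this commit:

from typing import Any

from app.database import db  # import as used elsewhere in this diff

def parse(rows: list[Any]) -> dict[str, int]:
    """Stand-in for the CPU-bound envelope parse."""
    return {"rows": len(rows)}

async def widths_since(conversation_key: str, cutoff: int) -> dict[str, int]:
    async with db.readonly() as conn:
        async with conn.execute(
            "SELECT rp.data FROM raw_packets rp "
            "JOIN messages m ON rp.message_id = m.id "
            "WHERE m.type = 'CHAN' AND m.conversation_key = ? "
            "AND rp.timestamp >= ?",
            (conversation_key, cutoff),
        ) as cursor:
            rows = await cursor.fetchall()  # lock held only for this fetch
    return parse(list(rows))  # lock released before the heavy loop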
@@ -882,14 +922,15 @@ class MessageRepository:
@staticmethod
async def count_channels_with_incoming_messages() -> int:
"""Count distinct channel conversations with at least one incoming message."""
cursor = await db.conn.execute(
"""
SELECT COUNT(DISTINCT conversation_key) AS cnt
FROM messages
WHERE type = 'CHAN' AND outgoing = 0
"""
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT COUNT(DISTINCT conversation_key) AS cnt
FROM messages
WHERE type = 'CHAN' AND outgoing = 0
"""
) as cursor:
row = await cursor.fetchone()
return int(row["cnt"]) if row and row["cnt"] is not None else 0
@staticmethod
@@ -898,20 +939,21 @@ class MessageRepository:
Returns list of (channel_key, channel_name, message_count) tuples.
"""
cursor = await db.conn.execute(
"""
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
COUNT(*) AS cnt
FROM messages m
LEFT JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.sender_key = ?
GROUP BY m.conversation_key
ORDER BY cnt DESC
LIMIT ?
""",
(sender_key.lower(), limit),
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
COUNT(*) AS cnt
FROM messages m
LEFT JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.sender_key = ?
GROUP BY m.conversation_key
ORDER BY cnt DESC
LIMIT ?
""",
(sender_key.lower(), limit),
) as cursor:
rows = await cursor.fetchall()
return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows]
@staticmethod
@@ -919,34 +961,36 @@ class MessageRepository:
sender_name: str, limit: int = 5
) -> list[tuple[str, str, int]]:
"""Get channels where a display name has sent the most messages."""
cursor = await db.conn.execute(
"""
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
COUNT(*) AS cnt
FROM messages m
LEFT JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.sender_name = ?
GROUP BY m.conversation_key
ORDER BY cnt DESC
LIMIT ?
""",
(sender_name, limit),
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
COUNT(*) AS cnt
FROM messages m
LEFT JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.sender_name = ?
GROUP BY m.conversation_key
ORDER BY cnt DESC
LIMIT ?
""",
(sender_name, limit),
) as cursor:
rows = await cursor.fetchall()
return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows]
@staticmethod
async def _get_activity_hour_buckets(where_sql: str, params: list[Any]) -> dict[int, int]:
cursor = await db.conn.execute(
f"""
SELECT received_at / 3600 AS hour_bucket, COUNT(*) AS cnt
FROM messages
WHERE {where_sql}
GROUP BY hour_bucket
""",
params,
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
f"""
SELECT received_at / 3600 AS hour_bucket, COUNT(*) AS cnt
FROM messages
WHERE {where_sql}
GROUP BY hour_bucket
""",
params,
) as cursor:
rows = await cursor.fetchall()
return {int(row["hour_bucket"]): row["cnt"] for row in rows}
@staticmethod
@@ -1000,16 +1044,17 @@ class MessageRepository:
current_day_start = (now // 86400) * 86400
start = current_day_start - (weeks - 1) * bucket_seconds
cursor = await db.conn.execute(
f"""
SELECT (received_at - ?) / ? AS bucket_idx, COUNT(*) AS cnt
FROM messages
WHERE {where_sql} AND received_at >= ?
GROUP BY bucket_idx
""",
[start, bucket_seconds, *params, start],
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
f"""
SELECT (received_at - ?) / ? AS bucket_idx, COUNT(*) AS cnt
FROM messages
WHERE {where_sql} AND received_at >= ?
GROUP BY bucket_idx
""",
[start, bucket_seconds, *params, start],
) as cursor:
rows = await cursor.fetchall()
counts = {int(row["bucket_idx"]): row["cnt"] for row in rows}
return [

View File

@@ -34,65 +34,85 @@ class RawPacketRepository:
# For malformed packets, hash the full data
payload_hash = sha256(data).digest()
cursor = await db.conn.execute(
"INSERT OR IGNORE INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)",
(ts, data, payload_hash),
)
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute(
"INSERT OR IGNORE INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)",
(ts, data, payload_hash),
) as cursor:
rowcount = cursor.rowcount
lastrowid = cursor.lastrowid
if cursor.rowcount > 0:
assert cursor.lastrowid is not None
return (cursor.lastrowid, True)
if rowcount > 0:
assert lastrowid is not None
return (lastrowid, True)
# Duplicate payload — look up the existing row.
cursor = await db.conn.execute(
"SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,)
)
existing = await cursor.fetchone()
# Duplicate payload — look up the existing row (same transaction).
async with conn.execute(
"SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,)
) as cursor:
existing = await cursor.fetchone()
assert existing is not None
return (existing["id"], False)
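The ``rowcount`` check leans on ``INSERT OR IGNORE`` semantics: the statement reports 0 changed rows when the ``UNIQUE`` payload hash already exists. A standalone demonstration with plain ``aiosqlite`` (table name illustrative):

import asyncio

import aiosqlite

async def main() -> None:
    async with aiosqlite.connect(":memory:") as conn:
        await conn.execute(
            "CREATE TABLE pkts (id INTEGER PRIMARY KEY, payload_hash BLOB UNIQUE)"
        )
        for attempt in (1, 2):
            async with conn.execute(
                "INSERT OR IGNORE INTO pkts (payload_hash) VALUES (?)", (b"\x01",)
            ) as cursor:
                # attempt 1 prints rowcount 1 (inserted); attempt 2 prints 0
                print(attempt, cursor.rowcount)
        await conn.commit()

asyncio.run(main())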
@staticmethod
async def get_undecrypted_count() -> int:
"""Get count of undecrypted packets (those without a linked message)."""
cursor = await db.conn.execute(
"SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL"
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL"
) as cursor:
row = await cursor.fetchone()
return row["count"] if row else 0
@staticmethod
async def get_oldest_undecrypted() -> int | None:
"""Get timestamp of oldest undecrypted packet, or None if none exist."""
cursor = await db.conn.execute(
"SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL"
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL"
) as cursor:
row = await cursor.fetchone()
return row["oldest"] if row and row["oldest"] is not None else None
@staticmethod
async def _stream_undecrypted_rows(
batch_size: int,
) -> AsyncIterator[tuple[int, bytes, int]]:
"""Internal: keyset-paginated scan of every undecrypted raw packet.
Yields ``(id, data, timestamp)`` for each row across all batches.
Lock is acquired per batch only — concurrent writes can interleave
at batch boundaries rather than being blocked for the full scan.
Each batch opens a fresh cursor and consumes it fully with
``fetchall()`` before releasing, so no prepared statement is alive
at a yield boundary.
``last_id`` advances per row, not per yield, so external filters
(see ``stream_undecrypted_text_messages``) that drop rows do not
cause a re-scan of skipped IDs.
"""
last_id = -1
while True:
async with db.readonly() as conn:
async with conn.execute(
"SELECT id, data, timestamp FROM raw_packets "
"WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?",
(last_id, batch_size),
) as cursor:
rows = await cursor.fetchall()
if not rows:
return
for row in rows:
last_id = row["id"]
yield (row["id"], bytes(row["data"]), row["timestamp"])
@staticmethod
async def stream_all_undecrypted(
batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
) -> AsyncIterator[tuple[int, bytes, int]]:
"""Yield all undecrypted packets as (id, data, timestamp) in bounded batches.
Uses keyset pagination so each batch is a fresh query with a fully
consumed cursor — no open statement held across yield boundaries.
"""
last_id = -1
while True:
cursor = await db.conn.execute(
"SELECT id, data, timestamp FROM raw_packets "
"WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?",
(last_id, batch_size),
)
rows = await cursor.fetchall()
await cursor.close()
if not rows:
break
for row in rows:
last_id = row["id"]
yield (row["id"], bytes(row["data"]), row["timestamp"])
"""Yield all undecrypted packets as (id, data, timestamp) in bounded batches."""
async for row in RawPacketRepository._stream_undecrypted_rows(batch_size):
yield row
@staticmethod
async def stream_undecrypted_text_messages(
@@ -100,26 +120,15 @@ class RawPacketRepository:
) -> AsyncIterator[tuple[int, bytes, int]]:
"""Yield undecrypted TEXT_MESSAGE packets in bounded-size batches.
Uses keyset pagination so each batch is a fresh query with a fully
consumed cursor — no open statement held across yield boundaries.
Filters the shared scan to rows whose payload parses as a text
message. Non-matching rows still advance the keyset cursor so they
aren't re-fetched on subsequent batches.
"""
last_id = -1
while True:
cursor = await db.conn.execute(
"SELECT id, data, timestamp FROM raw_packets "
"WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?",
(last_id, batch_size),
)
rows = await cursor.fetchall()
await cursor.close()
if not rows:
break
for row in rows:
last_id = row["id"]
data = bytes(row["data"])
payload_type = get_packet_payload_type(data)
if payload_type == PayloadType.TEXT_MESSAGE:
yield (row["id"], data, row["timestamp"])
async for packet_id, data, timestamp in RawPacketRepository._stream_undecrypted_rows(
batch_size
):
if get_packet_payload_type(data) == PayloadType.TEXT_MESSAGE:
yield (packet_id, data, timestamp)
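Consumers just iterate; each batch boundary drops and re-acquires the DB lock, so writers interleave with a long scan. A usage sketch in which the import path and the ``try_decrypt`` hook are hypothetical:

from app.repositories.raw_packets import RawPacketRepository  # path assumed

def try_decrypt(data: bytes) -> bool:
    """Hypothetical stand-in for the real decrypt pipeline."""
    return False

async def retry_undecrypted() -> int:
    decrypted = 0
    async for packet_id, data, ts in RawPacketRepository.stream_undecrypted_text_messages():
        if try_decrypt(data):  # no DB lock is held here
            decrypted += 1
    return decrypted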
@staticmethod
async def count_undecrypted_text_messages(
@@ -136,20 +145,22 @@ class RawPacketRepository:
@staticmethod
async def mark_decrypted(packet_id: int, message_id: int) -> None:
"""Link a raw packet to its decrypted message."""
await db.conn.execute(
"UPDATE raw_packets SET message_id = ? WHERE id = ?",
(message_id, packet_id),
)
await db.conn.commit()
async with db.tx() as conn:
async with conn.execute(
"UPDATE raw_packets SET message_id = ? WHERE id = ?",
(message_id, packet_id),
):
pass
@staticmethod
async def get_linked_message_id(packet_id: int) -> int | None:
"""Return the linked message ID for a raw packet, if any."""
cursor = await db.conn.execute(
"SELECT message_id FROM raw_packets WHERE id = ?",
(packet_id,),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT message_id FROM raw_packets WHERE id = ?",
(packet_id,),
) as cursor:
row = await cursor.fetchone()
if not row:
return None
return row["message_id"]
@@ -157,11 +168,12 @@ class RawPacketRepository:
@staticmethod
async def get_by_id(packet_id: int) -> tuple[int, bytes, int, int | None] | None:
"""Return a raw packet row as (id, data, timestamp, message_id)."""
cursor = await db.conn.execute(
"SELECT id, data, timestamp, message_id FROM raw_packets WHERE id = ?",
(packet_id,),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"SELECT id, data, timestamp, message_id FROM raw_packets WHERE id = ?",
(packet_id,),
) as cursor:
row = await cursor.fetchone()
if not row:
return None
return (row["id"], bytes(row["data"]), row["timestamp"], row["message_id"])
@@ -170,16 +182,20 @@ class RawPacketRepository:
async def prune_old_undecrypted(max_age_days: int) -> int:
"""Delete undecrypted packets older than max_age_days. Returns count deleted."""
cutoff = int(time.time()) - (max_age_days * 86400)
cursor = await db.conn.execute(
"DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?",
(cutoff,),
)
await db.conn.commit()
return cursor.rowcount
async with db.tx() as conn:
async with conn.execute(
"DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?",
(cutoff,),
) as cursor:
rowcount = cursor.rowcount
return rowcount
@staticmethod
async def purge_linked_to_messages() -> int:
"""Delete raw packets that are already linked to a stored message."""
cursor = await db.conn.execute("DELETE FROM raw_packets WHERE message_id IS NOT NULL")
await db.conn.commit()
return cursor.rowcount
async with db.tx() as conn:
async with conn.execute(
"DELETE FROM raw_packets WHERE message_id IS NOT NULL"
) as cursor:
rowcount = cursor.rowcount
return rowcount

View File

@@ -21,51 +21,54 @@ class RepeaterTelemetryRepository:
data: dict,
) -> None:
"""Insert a telemetry history row and prune stale entries."""
await db.conn.execute(
"""
INSERT INTO repeater_telemetry_history
(public_key, timestamp, data)
VALUES (?, ?, ?)
""",
(public_key, timestamp, json.dumps(data)),
)
# Prune entries older than 30 days
cutoff = int(time.time()) - _MAX_AGE_SECONDS
await db.conn.execute(
"DELETE FROM repeater_telemetry_history WHERE public_key = ? AND timestamp < ?",
(public_key, cutoff),
)
async with db.tx() as conn:
async with conn.execute(
"""
INSERT INTO repeater_telemetry_history
(public_key, timestamp, data)
VALUES (?, ?, ?)
""",
(public_key, timestamp, json.dumps(data)),
):
pass
# Cap at _MAX_ENTRIES_PER_REPEATER (keep newest)
await db.conn.execute(
"""
DELETE FROM repeater_telemetry_history
WHERE public_key = ? AND id NOT IN (
SELECT id FROM repeater_telemetry_history
WHERE public_key = ?
ORDER BY timestamp DESC
LIMIT ?
)
""",
(public_key, public_key, _MAX_ENTRIES_PER_REPEATER),
)
# Prune entries older than 30 days
async with conn.execute(
"DELETE FROM repeater_telemetry_history WHERE public_key = ? AND timestamp < ?",
(public_key, cutoff),
):
pass
await db.conn.commit()
# Cap at _MAX_ENTRIES_PER_REPEATER (keep newest)
async with conn.execute(
"""
DELETE FROM repeater_telemetry_history
WHERE public_key = ? AND id NOT IN (
SELECT id FROM repeater_telemetry_history
WHERE public_key = ?
ORDER BY timestamp DESC
LIMIT ?
)
""",
(public_key, public_key, _MAX_ENTRIES_PER_REPEATER),
):
pass
@staticmethod
async def get_history(public_key: str, since_timestamp: int) -> list[dict]:
"""Return telemetry rows for a repeater since a given timestamp, ordered ASC."""
cursor = await db.conn.execute(
"""
SELECT timestamp, data
FROM repeater_telemetry_history
WHERE public_key = ? AND timestamp >= ?
ORDER BY timestamp ASC
""",
(public_key, since_timestamp),
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT timestamp, data
FROM repeater_telemetry_history
WHERE public_key = ? AND timestamp >= ?
ORDER BY timestamp ASC
""",
(public_key, since_timestamp),
) as cursor:
rows = await cursor.fetchall()
return [
{
"timestamp": row["timestamp"],
@@ -77,17 +80,18 @@ class RepeaterTelemetryRepository:
@staticmethod
async def get_latest(public_key: str) -> dict | None:
"""Return the most recent telemetry row for a repeater, or None."""
cursor = await db.conn.execute(
"""
SELECT timestamp, data
FROM repeater_telemetry_history
WHERE public_key = ?
ORDER BY timestamp DESC
LIMIT 1
""",
(public_key,),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT timestamp, data
FROM repeater_telemetry_history
WHERE public_key = ?
ORDER BY timestamp DESC
LIMIT 1
""",
(public_key,),
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
return {

View File

@@ -3,6 +3,8 @@ import logging
import time
from typing import Any
import aiosqlite
from app.database import db
from app.models import AppSettings
from app.path_utils import bucket_path_hash_widths
@@ -17,15 +19,23 @@ SECONDS_7D = 604800
class AppSettingsRepository:
"""Repository for app_settings table (single-row pattern)."""
"""Repository for app_settings table (single-row pattern).
Public methods acquire the DB lock exactly once. ``toggle_*`` helpers that
need a read-modify-write do so inside a single ``db.tx()`` — the internal
``_get_in_conn`` / ``_apply_updates`` helpers run under the caller's
already-held lock and must NEVER call ``db.tx()`` or ``db.readonly()``.
"""
@staticmethod
async def get() -> AppSettings:
"""Get the current app settings.
async def _get_in_conn(conn: aiosqlite.Connection) -> AppSettings:
"""Load settings using an already-acquired connection.
Always returns settings - creates default row if needed (migration handles initial row).
Used by the public ``get()`` and by multi-step operations
(``toggle_blocked_key``, ``toggle_blocked_name``) to avoid re-entering
the non-reentrant DB lock.
"""
cursor = await db.conn.execute(
async with conn.execute(
"""
SELECT max_radio_contacts, auto_decrypt_dm_on_advert,
last_message_times,
@@ -35,8 +45,8 @@ class AppSettingsRepository:
telemetry_interval_hours
FROM app_settings WHERE id = 1
"""
)
row = await cursor.fetchone()
) as cursor:
row = await cursor.fetchone()
if not row:
# Should not happen after migration, but handle gracefully
@@ -119,7 +129,9 @@ class AppSettingsRepository:
)
@staticmethod
async def update(
async def _apply_updates(
conn: aiosqlite.Connection,
*,
max_radio_contacts: int | None = None,
auto_decrypt_dm_on_advert: bool | None = None,
last_message_times: dict[str, int] | None = None,
@@ -132,9 +144,13 @@ class AppSettingsRepository:
tracked_telemetry_repeaters: list[str] | None = None,
auto_resend_channel: bool | None = None,
telemetry_interval_hours: int | None = None,
) -> AppSettings:
"""Update app settings. Only provided fields are updated."""
updates = []
) -> None:
"""Apply field updates using an already-acquired connection.
Emits a single UPDATE statement inside the caller's transaction. Does
NOT commit — the caller's ``db.tx()`` handles that.
"""
updates: list[str] = []
params: list[Any] = []
if max_radio_contacts is not None:
@@ -187,47 +203,101 @@ class AppSettingsRepository:
if updates:
query = f"UPDATE app_settings SET {', '.join(updates)} WHERE id = 1"
await db.conn.execute(query, params)
await db.conn.commit()
async with conn.execute(query, params):
pass
return await AppSettingsRepository.get()
@staticmethod
async def get() -> AppSettings:
"""Get the current app settings.
Always returns settings - creates default row if needed (migration handles initial row).
"""
async with db.readonly() as conn:
return await AppSettingsRepository._get_in_conn(conn)
@staticmethod
async def update(
max_radio_contacts: int | None = None,
auto_decrypt_dm_on_advert: bool | None = None,
last_message_times: dict[str, int] | None = None,
advert_interval: int | None = None,
last_advert_time: int | None = None,
flood_scope: str | None = None,
blocked_keys: list[str] | None = None,
blocked_names: list[str] | None = None,
discovery_blocked_types: list[int] | None = None,
tracked_telemetry_repeaters: list[str] | None = None,
auto_resend_channel: bool | None = None,
telemetry_interval_hours: int | None = None,
) -> AppSettings:
"""Update app settings. Only provided fields are updated."""
async with db.tx() as conn:
await AppSettingsRepository._apply_updates(
conn,
max_radio_contacts=max_radio_contacts,
auto_decrypt_dm_on_advert=auto_decrypt_dm_on_advert,
last_message_times=last_message_times,
advert_interval=advert_interval,
last_advert_time=last_advert_time,
flood_scope=flood_scope,
blocked_keys=blocked_keys,
blocked_names=blocked_names,
discovery_blocked_types=discovery_blocked_types,
tracked_telemetry_repeaters=tracked_telemetry_repeaters,
auto_resend_channel=auto_resend_channel,
telemetry_interval_hours=telemetry_interval_hours,
)
return await AppSettingsRepository._get_in_conn(conn)
@staticmethod
async def toggle_blocked_key(key: str) -> AppSettings:
"""Toggle a public key in the blocked list. Keys are normalized to lowercase."""
"""Toggle a public key in the blocked list. Keys are normalized to lowercase.
Read-modify-write is atomic under a single ``db.tx()`` lock — two
concurrent toggles for the same key cannot produce an inconsistent
intermediate state.
"""
normalized = key.lower()
settings = await AppSettingsRepository.get()
if normalized in settings.blocked_keys:
new_keys = [k for k in settings.blocked_keys if k != normalized]
else:
new_keys = settings.blocked_keys + [normalized]
return await AppSettingsRepository.update(blocked_keys=new_keys)
async with db.tx() as conn:
settings = await AppSettingsRepository._get_in_conn(conn)
if normalized in settings.blocked_keys:
new_keys = [k for k in settings.blocked_keys if k != normalized]
else:
new_keys = settings.blocked_keys + [normalized]
await AppSettingsRepository._apply_updates(conn, blocked_keys=new_keys)
return await AppSettingsRepository._get_in_conn(conn)
@staticmethod
async def toggle_blocked_name(name: str) -> AppSettings:
"""Toggle a display name in the blocked list."""
settings = await AppSettingsRepository.get()
if name in settings.blocked_names:
new_names = [n for n in settings.blocked_names if n != name]
else:
new_names = settings.blocked_names + [name]
return await AppSettingsRepository.update(blocked_names=new_names)
"""Toggle a display name in the blocked list.
Same atomicity guarantee as ``toggle_blocked_key``.
"""
async with db.tx() as conn:
settings = await AppSettingsRepository._get_in_conn(conn)
if name in settings.blocked_names:
new_names = [n for n in settings.blocked_names if n != name]
else:
new_names = settings.blocked_names + [name]
await AppSettingsRepository._apply_updates(conn, blocked_names=new_names)
return await AppSettingsRepository._get_in_conn(conn)
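The guarantee is easy to exercise: two concurrent toggles of the same key serialize on the lock, so one observes the other's write and the pair nets out to a no-op. A usage sketch; the import path is assumed and the key is illustrative:

import asyncio

from app.repositories.settings import AppSettingsRepository  # path assumed

async def demo() -> None:
    key = "ab" * 32  # 64-char hex public key, assumed not blocked initially
    await asyncio.gather(
        AppSettingsRepository.toggle_blocked_key(key),
        AppSettingsRepository.toggle_blocked_key(key),
    )
    settings = await AppSettingsRepository.get()
    assert key not in settings.blocked_keys  # toggled on, then back off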
class StatisticsRepository:
@staticmethod
async def get_database_message_totals() -> dict[str, int]:
"""Return message totals needed by lightweight debug surfaces."""
cursor = await db.conn.execute(
"""
SELECT
SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms,
SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages,
SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing
FROM messages
"""
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT
SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms,
SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages,
SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing
FROM messages
"""
) as cursor:
row = await cursor.fetchone()
assert row is not None
return {
"total_dms": row["total_dms"] or 0,
@@ -240,18 +310,19 @@ class StatisticsRepository:
"""Get time-windowed counts for contacts/repeaters heard."""
now = int(time.time())
op = "!=" if exclude else "="
cursor = await db.conn.execute(
f"""
SELECT
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour,
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours,
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week
FROM contacts
WHERE type {op} ? AND last_seen IS NOT NULL
""",
(now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
f"""
SELECT
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour,
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours,
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week
FROM contacts
WHERE type {op} ? AND last_seen IS NOT NULL
""",
(now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type),
) as cursor:
row = await cursor.fetchone()
assert row is not None # Aggregate query always returns a row
return {
"last_hour": row["last_hour"] or 0,
@@ -267,24 +338,25 @@ class StatisticsRepository:
the old UPPER(...) join and aggregate per known channel directly.
"""
now = int(time.time())
cursor = await db.conn.execute(
"""
WITH known AS (
SELECT conversation_key, MAX(received_at) AS last_received_at
FROM messages
WHERE type = 'CHAN'
AND conversation_key IN (SELECT key FROM channels)
GROUP BY conversation_key
)
SELECT
SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_hour,
SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_24_hours,
SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_week
FROM known
""",
(now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D),
)
row = await cursor.fetchone()
async with db.readonly() as conn:
async with conn.execute(
"""
WITH known AS (
SELECT conversation_key, MAX(received_at) AS last_received_at
FROM messages
WHERE type = 'CHAN'
AND conversation_key IN (SELECT key FROM channels)
GROUP BY conversation_key
)
SELECT
SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_hour,
SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_24_hours,
SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_week
FROM known
""",
(now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D),
) as cursor:
row = await cursor.fetchone()
assert row is not None
return {
"last_hour": row["last_hour"] or 0,
@@ -298,92 +370,105 @@ class StatisticsRepository:
now = int(time.time())
cutoff = now - SECONDS_72H
# Bucket timestamps to the start of each hour
cursor = await db.conn.execute(
"""
SELECT (timestamp / 3600) * 3600 AS hour_ts, COUNT(*) AS count
FROM raw_packets
WHERE timestamp >= ?
GROUP BY hour_ts
ORDER BY hour_ts
""",
(cutoff,),
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
"""
SELECT (timestamp / 3600) * 3600 AS hour_ts, COUNT(*) AS count
FROM raw_packets
WHERE timestamp >= ?
GROUP BY hour_ts
ORDER BY hour_ts
""",
(cutoff,),
) as cursor:
rows = await cursor.fetchall()
return [{"timestamp": row["hour_ts"], "count": row["count"]} for row in rows]
@staticmethod
async def _path_hash_width_24h() -> dict[str, int | float]:
"""Count parsed raw packets from the last 24h by hop hash width."""
now = int(time.time())
cursor = await db.conn.execute(
"SELECT data FROM raw_packets WHERE timestamp >= ?",
(now - SECONDS_24H,),
)
rows = await cursor.fetchall()
async with db.readonly() as conn:
async with conn.execute(
"SELECT data FROM raw_packets WHERE timestamp >= ?",
(now - SECONDS_24H,),
) as cursor:
rows = await cursor.fetchall()
return bucket_path_hash_widths(rows)
@staticmethod
async def get_all() -> dict:
"""Aggregate all statistics from existing tables."""
"""Aggregate all statistics from existing tables.
Each helper acquires its own lock; there's no requirement that the
whole snapshot be atomic. If we ever wanted a consistent snapshot
we'd batch all queries into a single ``db.readonly()`` and use
``_in_conn`` helpers, but statistics are intentionally approximate.
"""
now = int(time.time())
# Top 5 busiest channels in last 24h
cursor = await db.conn.execute(
"""
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
COUNT(*) AS message_count
FROM messages m
LEFT JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.received_at >= ?
GROUP BY m.conversation_key
ORDER BY COUNT(*) DESC
LIMIT 5
""",
(now - SECONDS_24H,),
)
rows = await cursor.fetchall()
busiest_channels_24h = [
{
"channel_key": row["conversation_key"],
"channel_name": row["channel_name"],
"message_count": row["message_count"],
}
for row in rows
]
async with db.readonly() as conn:
# Top 5 busiest channels in last 24h
async with conn.execute(
"""
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
COUNT(*) AS message_count
FROM messages m
LEFT JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.received_at >= ?
GROUP BY m.conversation_key
ORDER BY COUNT(*) DESC
LIMIT 5
""",
(now - SECONDS_24H,),
) as cursor:
rows = await cursor.fetchall()
busiest_channels_24h = [
{
"channel_key": row["conversation_key"],
"channel_name": row["channel_name"],
"message_count": row["message_count"],
}
for row in rows
]
# Entity counts
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2")
row = await cursor.fetchone()
assert row is not None
contact_count: int = row["cnt"]
# Entity counts
async with conn.execute(
"SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2"
) as cursor:
row = await cursor.fetchone()
assert row is not None
contact_count: int = row["cnt"]
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2")
row = await cursor.fetchone()
assert row is not None
repeater_count: int = row["cnt"]
async with conn.execute(
"SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2"
) as cursor:
row = await cursor.fetchone()
assert row is not None
repeater_count: int = row["cnt"]
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM channels")
row = await cursor.fetchone()
assert row is not None
channel_count: int = row["cnt"]
async with conn.execute("SELECT COUNT(*) AS cnt FROM channels") as cursor:
row = await cursor.fetchone()
assert row is not None
channel_count: int = row["cnt"]
# Packet split
cursor = await db.conn.execute(
"""
SELECT COUNT(*) AS total,
SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted
FROM raw_packets
"""
)
pkt_row = await cursor.fetchone()
assert pkt_row is not None
total_packets = pkt_row["total"] or 0
decrypted_packets = pkt_row["decrypted"] or 0
undecrypted_packets = total_packets - decrypted_packets
# Packet split
async with conn.execute(
"""
SELECT COUNT(*) AS total,
SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted
FROM raw_packets
"""
) as cursor:
pkt_row = await cursor.fetchone()
assert pkt_row is not None
total_packets = pkt_row["total"] or 0
decrypted_packets = pkt_row["decrypted"] or 0
undecrypted_packets = total_packets - decrypted_packets
# These each acquire their own lock. The snapshot isn't atomic across
# them — fine for stats, which are approximate by nature.
message_totals = await StatisticsRepository.get_database_message_totals()
# Activity windows
contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True)
repeaters_heard = await StatisticsRepository._activity_counts(contact_type=2)
known_channels_active = await StatisticsRepository._known_channels_active()
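# If a consistent snapshot were ever needed, the shape would be a single
# readonly() block feeding conn-taking helpers. Sketch only; the
# *_in_conn helper names below are hypothetical, not part of this commit:
#
#     async with db.readonly() as conn:
#         totals = await StatisticsRepository._message_totals_in_conn(conn)
#         heard = await StatisticsRepository._activity_counts_in_conn(
#             conn, contact_type=2, exclude=True
#         )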

View File

@@ -28,13 +28,28 @@ def cleanup_test_db_dir():
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
from app.repository import channels, contacts, messages, raw_packets, settings
from app.repository import (
channels,
contacts,
messages,
raw_packets,
repeater_telemetry,
settings,
)
from app.repository import fanout as fanout_repo
db = Database(":memory:")
await db.connect()
submodules = [contacts, channels, messages, raw_packets, settings, fanout_repo]
submodules = [
contacts,
channels,
messages,
raw_packets,
settings,
fanout_repo,
repeater_telemetry,
]
originals = [(mod, mod.db) for mod in submodules]
for mod in submodules:

View File

@@ -322,7 +322,7 @@ class TestUndecryptedTextPacketStreaming:
[],
]
async def fake_execute(*_args, **_kwargs):
def fake_execute(*_args, **_kwargs):
batch = batches.pop(0)
class FakeCursor:
@@ -332,6 +332,16 @@ class TestUndecryptedTextPacketStreaming:
async def close(self):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
# aiosqlite's execute() returns a `contextmanager`-decorated coroutine
# that is both awaitable and usable in an `async with` block.
# Our repo code now uses `async with conn.execute(...) as cursor:`,
# so the mock only needs to return something with __aenter__/__aexit__.
return FakeCursor()
with patch.object(test_db.conn, "execute", side_effect=fake_execute):
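# For reference, the real aiosqlite call supports both styles (minimal
# sketch, assuming an open connection with the default tuple row factory):
async def _demo_execute_styles(conn) -> None:
    cursor = await conn.execute("SELECT 1")  # awaitable style
    await cursor.close()
    async with conn.execute("SELECT 1") as cursor:  # async-with style
        row = await cursor.fetchone()
        assert row == (1,)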

View File

@@ -1,11 +1,12 @@
"""Tests for repository layer."""
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import patch
import pytest
from app.models import Contact, ContactUpsert
from app.repository import (
AppSettingsRepository,
ContactAdvertPathRepository,
ContactNameHistoryRepository,
ContactRepository,
@@ -613,37 +614,103 @@ class TestAppSettingsRepository:
"""Test AppSettingsRepository parsing and migration edge cases."""
@pytest.mark.asyncio
async def test_get_handles_corrupted_json_and_invalid_sort_order(self):
"""Corrupted JSON fields are recovered with safe defaults."""
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(
return_value={
"max_radio_contacts": 250,
"auto_decrypt_dm_on_advert": 1,
"last_message_times": "{also-not-json",
"advert_interval": None,
"last_advert_time": None,
"flood_scope": "",
"blocked_keys": "[]",
"blocked_names": "[]",
"discovery_blocked_types": "[]",
}
async def test_get_handles_corrupted_json_and_invalid_sort_order(self, test_db):
"""Corrupted JSON fields are recovered with safe defaults.
Uses the real DB so it exercises the lock-aware path. We stuff
malformed JSON directly into the row, then verify ``get()`` recovers
with defaults rather than propagating a parse error.
"""
await test_db.conn.execute(
"""
UPDATE app_settings
SET max_radio_contacts = 250,
auto_decrypt_dm_on_advert = 1,
last_message_times = '{also-not-json',
advert_interval = NULL,
last_advert_time = NULL,
flood_scope = '',
blocked_keys = '[]',
blocked_names = '[]',
discovery_blocked_types = '[]'
WHERE id = 1
"""
)
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_db = MagicMock()
mock_db.conn = mock_conn
await test_db.conn.commit()
with patch("app.repository.settings.db", mock_db):
from app.repository import AppSettingsRepository
settings = await AppSettingsRepository.get()
settings = await AppSettingsRepository.get()
assert settings.max_radio_contacts == 250
assert settings.last_message_times == {}
assert settings.advert_interval == 0
assert settings.last_advert_time == 0
@pytest.mark.asyncio
async def test_get_in_conn_tolerates_missing_columns(self):
"""Defend against partial migrations where columns added by later
migrations are absent from the row.
Real DBs can't produce this state (schema init + migrations always
run to the latest version on startup), but hand-rolled snapshots,
external DB tools, or interrupted migrations might. The
``KeyError``-catching branches in ``_get_in_conn`` exist specifically
to guarantee graceful degradation.
We test these directly by mocking the connection boundary with a
dict-backed row that mimics a pre-migration snapshot missing:
- ``tracked_telemetry_repeaters`` (migration 53)
- ``auto_resend_channel`` (migration 54)
- ``telemetry_interval_hours`` (migration 57)
"""
from unittest.mock import MagicMock
from app.telemetry_interval import DEFAULT_TELEMETRY_INTERVAL_HOURS
# _get_in_conn's recovery branches catch KeyError for columns a later
# migration would have added. A dict-backed object simulates that
# contract directly, since dict.__getitem__ raises KeyError for absent
# keys. (Note: a raw sqlite3.Row raises IndexError, not KeyError, for
# an unknown column name.)
class PartialRow(dict):
def keys(self): # pragma: no cover - aiosqlite.Row compat
return super().keys()
partial_row = PartialRow(
{
"max_radio_contacts": 123,
"auto_decrypt_dm_on_advert": 1,
"last_message_times": "{}",
"advert_interval": 0,
"last_advert_time": 0,
"flood_scope": "",
"blocked_keys": "[]",
"blocked_names": "[]",
"discovery_blocked_types": "[]",
# intentionally missing: tracked_telemetry_repeaters,
# auto_resend_channel, telemetry_interval_hours
}
)
class FakeCursor:
async def fetchone(self):
return partial_row
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
mock_conn = MagicMock()
mock_conn.execute = MagicMock(return_value=FakeCursor())
settings = await AppSettingsRepository._get_in_conn(mock_conn)
assert settings.max_radio_contacts == 123
# Missing-column defaults kick in:
assert settings.tracked_telemetry_repeaters == []
assert settings.auto_resend_channel is False
assert settings.telemetry_interval_hours == DEFAULT_TELEMETRY_INTERVAL_HOURS
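# Stdlib-only sanity check of the exception contract relied on above:
# dict access raises KeyError, while a raw sqlite3.Row raises IndexError.
def _demo_row_exceptions() -> None:
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.row_factory = sqlite3.Row
    row = conn.execute("SELECT 1 AS a").fetchone()
    try:
        row["missing"]
    except IndexError:
        pass  # sqlite3.Row: IndexError for unknown column names
    try:
        {"a": 1}["missing"]
    except KeyError:
        pass  # dict: KeyError, matching _get_in_conn's branches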
class TestMessageRepositoryGetById:
"""Test MessageRepository.get_by_id method."""

View File

@@ -2,7 +2,7 @@
import time
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from unittest.mock import patch
import pytest
@@ -353,13 +353,21 @@ class TestPathHashWidthStats:
@pytest.mark.asyncio
async def test_path_hash_width_scan_fetches_all_then_buckets(self, test_db):
"""Hash-width stats should fetchall() then bucket synchronously."""
"""Hash-width stats should fetchall() then bucket synchronously.
fake_rows = [{"data": b"a"}, {"data": b"b"}, {"data": b"c"}]
Uses real DB rows plus a patched parser so the test exercises the
lock-aware readonly path. Mocking ``conn.execute``, as the pre-refactor
test did, no longer matches the actual call pattern (the code now uses
``async with conn.execute(...)``).
"""
class FakeCursor:
async def fetchall(self):
return fake_rows
now = int(time.time())
# Seed three raw packets in the last 24h with arbitrary distinguishing bytes.
for i, data in enumerate((b"a", b"b", b"c")):
await test_db.conn.execute(
"INSERT INTO raw_packets (timestamp, data) VALUES (?, ?)",
(now - (i + 1), data),
)
await test_db.conn.commit()
def fake_parse(raw_packet: bytes):
hash_sizes = {
@@ -372,10 +380,7 @@ class TestPathHashWidthStats:
return None
return SimpleNamespace(hash_size=hash_size)
with (
patch.object(test_db.conn, "execute", new=AsyncMock(return_value=FakeCursor())),
patch("app.path_utils.parse_packet_envelope", side_effect=fake_parse),
):
with patch("app.path_utils.parse_packet_envelope", side_effect=fake_parse):
breakdown = await StatisticsRepository._path_hash_width_24h()
assert breakdown["total_packets"] == 3