mirror of https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git, synced 2026-05-04 20:43:03 +02:00
Work on some more concurrency fixes re: locks and context managers. Poking at #179.
@@ -1,4 +1,7 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path

import aiosqlite
@@ -165,9 +168,74 @@ CREATE INDEX IF NOT EXISTS idx_repeater_telemetry_pk_ts


class Database:
    """Single-connection aiosqlite wrapper with coroutine-level serialization.

    Why the lock: aiosqlite runs one ``sqlite3.Connection`` on a background
    worker thread and serializes statement execution there. But SQLite's
    ``COMMIT`` fails with ``OperationalError: cannot commit transaction -
    SQL statements in progress`` whenever *any* cursor on the connection has
    a live prepared statement (a ``SELECT`` that returned ``SQLITE_ROW`` but
    hasn't been fully consumed or closed). Under concurrent coroutines, one
    task's in-flight ``fetchone()`` can still be in ``SQLITE_ROW`` state when
    another task's ``commit()`` runs on the worker — triggering the error.

    Fix: all DB work goes through ``tx()`` (writes) or ``readonly()`` (reads),
    both of which acquire ``self._lock``. The lock is non-reentrant (asyncio
    default) by design — nested ``tx()`` calls are a bug. Repository methods
    that compose multiple operations factor the raw SQL into private helpers
    that take a ``conn`` and don't lock; the public method acquires the lock
    once and calls those helpers.

    Why reads are also locked: a read in ``SQLITE_ROW`` state is precisely
    the live statement that breaks a concurrent writer's commit.
    Single-connection aiosqlite cannot safely overlap reads and writes. If we
    ever split reader/writer connections in the future, ``readonly()``
    becomes the seam to point at the reader pool.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path
        self._connection: aiosqlite.Connection | None = None
        self._lock = asyncio.Lock()

    @asynccontextmanager
    async def tx(self) -> AsyncIterator[aiosqlite.Connection]:
        """Acquire the connection for a write transaction.

        Commits on clean exit, rolls back on exception. Callers MUST close
        every cursor opened inside the block (use ``async with conn.execute(...)
        as cursor:``) so no prepared statement is alive when commit runs.

        The lock serializes concurrent writers AND ensures no reader's cursor
        is alive during the commit. Nested calls will deadlock — factor shared
        SQL into helpers that accept ``conn`` and do not re-enter ``tx()``.
        """
        async with self._lock:
            if self._connection is None:
                raise RuntimeError("Database not connected")
            conn = self._connection
            try:
                yield conn
            except BaseException:
                await conn.rollback()
                raise
            else:
                await conn.commit()

    @asynccontextmanager
    async def readonly(self) -> AsyncIterator[aiosqlite.Connection]:
        """Acquire the connection for a read. No commit, no rollback.

        Locked for the same reason writes are: on a single connection, an
        active read statement blocks a concurrent writer's commit. Callers
        MUST fully consume or close cursors before the block exits (use
        ``async with conn.execute(...) as cursor:`` + ``fetchall`` /
        ``fetchone``; avoid holding a cursor across ``await`` on other IO).
        """
        async with self._lock:
            if self._connection is None:
                raise RuntimeError("Database not connected")
            yield self._connection

    async def connect(self) -> None:
        logger.info("Connecting to database at %s", self.db_path)
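Editor's note: a minimal usage sketch of the two context managers (not part of the commit). It assumes connect() establishes the underlying connection (its body is truncated in this hunk) and uses the channels table from the repositories below; the key value is a placeholder.

import asyncio

from app.database import Database

async def demo() -> None:
    db = Database("demo.db")
    await db.connect()

    # Write path: tx() commits on clean exit and rolls back on exception.
    # Read cursor attributes inside the block, before the statement closes.
    async with db.tx() as conn:
        async with conn.execute(
            "UPDATE channels SET favorite = 1 WHERE key = ?", ("ABC123",)
        ) as cursor:
            updated = cursor.rowcount > 0

    # Read path: same lock; fully consume the cursor before the block exits
    # so no prepared statement is alive when another task commits.
    async with db.readonly() as conn:
        async with conn.execute("SELECT key FROM channels") as cursor:
            rows = await cursor.fetchall()

asyncio.run(demo())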

@@ -8,31 +8,33 @@ class ChannelRepository:
    @staticmethod
    async def upsert(key: str, name: str, is_hashtag: bool = False, on_radio: bool = False) -> None:
        """Upsert a channel. Key is 32-char hex string."""
        await db.conn.execute(
            """
            INSERT INTO channels (key, name, is_hashtag, on_radio, flood_scope_override)
            VALUES (?, ?, ?, ?, NULL)
            ON CONFLICT(key) DO UPDATE SET
                name = excluded.name,
                is_hashtag = excluded.is_hashtag,
                on_radio = excluded.on_radio
            """,
            (key.upper(), name, is_hashtag, on_radio),
        )
        await db.conn.commit()
        async with db.tx() as conn:
            async with conn.execute(
                """
                INSERT INTO channels (key, name, is_hashtag, on_radio, flood_scope_override)
                VALUES (?, ?, ?, ?, NULL)
                ON CONFLICT(key) DO UPDATE SET
                    name = excluded.name,
                    is_hashtag = excluded.is_hashtag,
                    on_radio = excluded.on_radio
                """,
                (key.upper(), name, is_hashtag, on_radio),
            ):
                pass

    @staticmethod
    async def get_by_key(key: str) -> Channel | None:
        """Get a channel by its key (32-char hex string)."""
        cursor = await db.conn.execute(
            """
            SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite
            FROM channels
            WHERE key = ?
            """,
            (key.upper(),),
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                """
                SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite
                FROM channels
                WHERE key = ?
                """,
                (key.upper(),),
            ) as cursor:
                row = await cursor.fetchone()
        if row:
            return Channel(
                key=row["key"],

@@ -48,14 +50,15 @@ class ChannelRepository:
    @staticmethod
    async def get_all() -> list[Channel]:
        cursor = await db.conn.execute(
            """
            SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite
            FROM channels
            ORDER BY name
            """
        )
        rows = await cursor.fetchall()
        async with db.readonly() as conn:
            async with conn.execute(
                """
                SELECT key, name, is_hashtag, on_radio, flood_scope_override, path_hash_mode_override, last_read_at, favorite
                FROM channels
                ORDER BY name
                """
            ) as cursor:
                rows = await cursor.fetchall()
        return [
            Channel(
                key=row["key"],

@@ -73,21 +76,23 @@ class ChannelRepository:
    @staticmethod
    async def set_favorite(key: str, value: bool) -> bool:
        """Set or clear the favorite flag for a channel. Returns True if row was found."""
        cursor = await db.conn.execute(
            "UPDATE channels SET favorite = ? WHERE key = ?",
            (1 if value else 0, key.upper()),
        )
        await db.conn.commit()
        return cursor.rowcount > 0
        async with db.tx() as conn:
            async with conn.execute(
                "UPDATE channels SET favorite = ? WHERE key = ?",
                (1 if value else 0, key.upper()),
            ) as cursor:
                rowcount = cursor.rowcount
        return rowcount > 0

    @staticmethod
    async def delete(key: str) -> None:
        """Delete a channel by key."""
        await db.conn.execute(
            "DELETE FROM channels WHERE key = ?",
            (key.upper(),),
        )
        await db.conn.commit()
        async with db.tx() as conn:
            async with conn.execute(
                "DELETE FROM channels WHERE key = ?",
                (key.upper(),),
            ):
                pass

    @staticmethod
    async def update_last_read_at(key: str, timestamp: int | None = None) -> bool:

@@ -96,35 +101,39 @@ class ChannelRepository:
        Returns True if a row was updated, False if channel not found.
        """
        ts = timestamp if timestamp is not None else int(time.time())
        cursor = await db.conn.execute(
            "UPDATE channels SET last_read_at = ? WHERE key = ?",
            (ts, key.upper()),
        )
        await db.conn.commit()
        return cursor.rowcount > 0
        async with db.tx() as conn:
            async with conn.execute(
                "UPDATE channels SET last_read_at = ? WHERE key = ?",
                (ts, key.upper()),
            ) as cursor:
                rowcount = cursor.rowcount
        return rowcount > 0

    @staticmethod
    async def update_flood_scope_override(key: str, flood_scope_override: str | None) -> bool:
        """Set or clear a channel's flood-scope override."""
        cursor = await db.conn.execute(
            "UPDATE channels SET flood_scope_override = ? WHERE key = ?",
            (flood_scope_override, key.upper()),
        )
        await db.conn.commit()
        return cursor.rowcount > 0
        async with db.tx() as conn:
            async with conn.execute(
                "UPDATE channels SET flood_scope_override = ? WHERE key = ?",
                (flood_scope_override, key.upper()),
            ) as cursor:
                rowcount = cursor.rowcount
        return rowcount > 0

    @staticmethod
    async def update_path_hash_mode_override(key: str, path_hash_mode_override: int | None) -> bool:
        """Set or clear a channel's path hash mode override."""
        cursor = await db.conn.execute(
            "UPDATE channels SET path_hash_mode_override = ? WHERE key = ?",
            (path_hash_mode_override, key.upper()),
        )
        await db.conn.commit()
        return cursor.rowcount > 0
        async with db.tx() as conn:
            async with conn.execute(
                "UPDATE channels SET path_hash_mode_override = ? WHERE key = ?",
                (path_hash_mode_override, key.upper()),
            ) as cursor:
                rowcount = cursor.rowcount
        return rowcount > 0

    @staticmethod
    async def mark_all_read(timestamp: int) -> None:
        """Mark all channels as read at the given timestamp."""
        await db.conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,))
        await db.conn.commit()
        async with db.tx() as conn:
            async with conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,)):
                pass

File diff suppressed because it is too large
@@ -6,6 +6,8 @@ import time
import uuid
from typing import Any

import aiosqlite

from app.database import db

logger = logging.getLogger(__name__)

@@ -31,26 +33,37 @@ def _row_to_dict(row: Any) -> dict[str, Any]:
    return result


async def _get_in_conn(conn: aiosqlite.Connection, config_id: str) -> dict[str, Any] | None:
    """Fetch a config using an already-acquired connection.

    Used by ``create`` and ``update`` to return the freshly-written row
    without re-entering the non-reentrant DB lock.
    """
    async with conn.execute("SELECT * FROM fanout_configs WHERE id = ?", (config_id,)) as cursor:
        row = await cursor.fetchone()
    if row is None:
        return None
    return _row_to_dict(row)
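# --- Editor's sketch (not part of the commit) --------------------------------
# The composition pattern the Database docstring prescribes, in miniature:
# the public method enters tx() exactly once, and helpers such as
# _get_in_conn() accept the already-held conn. Calling a method that
# re-enters tx() or readonly() here would deadlock on the non-reentrant
# lock. `rename` is a hypothetical method name.
async def rename(config_id: str, name: str) -> dict[str, Any] | None:
    async with db.tx() as conn:
        async with conn.execute(
            "UPDATE fanout_configs SET name = ? WHERE id = ?",
            (name, config_id),
        ):
            pass
        return await _get_in_conn(conn, config_id)
# ------------------------------------------------------------------------------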


class FanoutConfigRepository:
    """CRUD operations for fanout_configs table."""

    @staticmethod
    async def get_all() -> list[dict[str, Any]]:
        """Get all fanout configs ordered by sort_order."""
        cursor = await db.conn.execute(
            "SELECT * FROM fanout_configs ORDER BY sort_order, created_at"
        )
        rows = await cursor.fetchall()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT * FROM fanout_configs ORDER BY sort_order, created_at"
            ) as cursor:
                rows = await cursor.fetchall()
        return [_row_to_dict(row) for row in rows]

    @staticmethod
    async def get(config_id: str) -> dict[str, Any] | None:
        """Get a single fanout config by ID."""
        cursor = await db.conn.execute("SELECT * FROM fanout_configs WHERE id = ?", (config_id,))
        row = await cursor.fetchone()
        if row is None:
            return None
        return _row_to_dict(row)
        async with db.readonly() as conn:
            return await _get_in_conn(conn, config_id)

    @staticmethod
    async def create(

@@ -65,39 +78,41 @@ class FanoutConfigRepository:
        new_id = config_id or str(uuid.uuid4())
        now = int(time.time())

        # Get next sort_order
        cursor = await db.conn.execute(
            "SELECT COALESCE(MAX(sort_order), -1) + 1 FROM fanout_configs"
        )
        row = await cursor.fetchone()
        sort_order = row[0] if row else 0
        async with db.tx() as conn:
            # Determine next sort_order under the same lock as the insert,
            # so two concurrent ``create()`` calls cannot collide.
            async with conn.execute(
                "SELECT COALESCE(MAX(sort_order), -1) + 1 FROM fanout_configs"
            ) as cursor:
                row = await cursor.fetchone()
            sort_order = row[0] if row else 0

        await db.conn.execute(
            """
            INSERT INTO fanout_configs (id, type, name, enabled, config, scope, sort_order, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                new_id,
                config_type,
                name,
                1 if enabled else 0,
                json.dumps(config),
                json.dumps(scope),
                sort_order,
                now,
            ),
        )
        await db.conn.commit()
            async with conn.execute(
                """
                INSERT INTO fanout_configs (id, type, name, enabled, config, scope, sort_order, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    new_id,
                    config_type,
                    name,
                    1 if enabled else 0,
                    json.dumps(config),
                    json.dumps(scope),
                    sort_order,
                    now,
                ),
            ):
                pass

        result = await FanoutConfigRepository.get(new_id)
            result = await _get_in_conn(conn, new_id)
        assert result is not None
        return result

    @staticmethod
    async def update(config_id: str, **fields: Any) -> dict[str, Any] | None:
        """Update a fanout config. Only provided fields are updated."""
        updates = []
        updates: list[str] = []
        params: list[Any] = []

        for field in ("name", "enabled", "config", "scope", "sort_order"):

@@ -115,23 +130,25 @@ class FanoutConfigRepository:

        params.append(config_id)
        query = f"UPDATE fanout_configs SET {', '.join(updates)} WHERE id = ?"
        await db.conn.execute(query, params)
        await db.conn.commit()

        return await FanoutConfigRepository.get(config_id)
        async with db.tx() as conn:
            async with conn.execute(query, params):
                pass
            return await _get_in_conn(conn, config_id)

    @staticmethod
    async def delete(config_id: str) -> None:
        """Delete a fanout config."""
        await db.conn.execute("DELETE FROM fanout_configs WHERE id = ?", (config_id,))
        await db.conn.commit()
        async with db.tx() as conn:
            async with conn.execute("DELETE FROM fanout_configs WHERE id = ?", (config_id,)):
                pass
        _configs_cache.pop(config_id, None)

    @staticmethod
    async def get_enabled() -> list[dict[str, Any]]:
        """Get all enabled fanout configs."""
        cursor = await db.conn.execute(
            "SELECT * FROM fanout_configs WHERE enabled = 1 ORDER BY sort_order, created_at"
        )
        rows = await cursor.fetchall()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT * FROM fanout_configs WHERE enabled = 1 ORDER BY sort_order, created_at"
            ) as cursor:
                rows = await cursor.fetchall()
        return [_row_to_dict(row) for row in rows]

@@ -89,32 +89,34 @@ class MessageRepository:
        # Normalize sender_key to lowercase so queries can match without LOWER().
        normalized_sender_key = sender_key.lower() if sender_key else sender_key

        cursor = await db.conn.execute(
            """
            INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp,
                                            received_at, paths, txt_type, signature, outgoing,
                                            sender_name, sender_key)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                msg_type,
                conversation_key,
                text,
                sender_timestamp,
                received_at,
                paths_json,
                txt_type,
                signature,
                outgoing,
                sender_name,
                normalized_sender_key,
            ),
        )
        await db.conn.commit()
        async with db.tx() as conn:
            async with conn.execute(
                """
                INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp,
                                                received_at, paths, txt_type, signature, outgoing,
                                                sender_name, sender_key)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    msg_type,
                    conversation_key,
                    text,
                    sender_timestamp,
                    received_at,
                    paths_json,
                    txt_type,
                    signature,
                    outgoing,
                    sender_name,
                    normalized_sender_key,
                ),
            ) as cursor:
                rowcount = cursor.rowcount
                lastrowid = cursor.lastrowid
        # rowcount is 0 if INSERT was ignored due to UNIQUE constraint violation
        if cursor.rowcount == 0:
        if rowcount == 0:
            return None
        return cursor.lastrowid
        return lastrowid

    @staticmethod
    async def add_path(

@@ -142,17 +144,20 @@ class MessageRepository:
        if snr is not None:
            entry["snr"] = snr
        new_entry = json.dumps(entry)
        await db.conn.execute(
            """UPDATE messages SET paths = json_insert(
                COALESCE(paths, '[]'), '$[#]', json(?)
            ) WHERE id = ?""",
            (new_entry, message_id),
        )
        await db.conn.commit()
        async with db.tx() as conn:
            async with conn.execute(
                """UPDATE messages SET paths = json_insert(
                    COALESCE(paths, '[]'), '$[#]', json(?)
                ) WHERE id = ?""",
                (new_entry, message_id),
            ):
                pass

        # Read back the full list for the return value
        cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
        row = await cursor.fetchone()
            # Read back the full list for the return value, same transaction.
            async with conn.execute(
                "SELECT paths FROM messages WHERE id = ?", (message_id,)
            ) as cursor:
                row = await cursor.fetchone()
        if not row or not row["paths"]:
            return []

@@ -171,23 +176,24 @@ class MessageRepository:
        only a prefix as conversation_key are updated to use the full key.
        """
        lower_key = full_key.lower()
        cursor = await db.conn.execute(
            """UPDATE messages SET conversation_key = ?,
                   sender_key = CASE
                       WHEN sender_key IS NOT NULL AND length(sender_key) < 64
                            AND ? LIKE sender_key || '%'
                       THEN ? ELSE sender_key END
               WHERE type = 'PRIV' AND length(conversation_key) < 64
                 AND ? LIKE conversation_key || '%'
                 AND (
                     SELECT COUNT(*) FROM contacts
                     WHERE length(public_key) = 64
                       AND public_key LIKE messages.conversation_key || '%'
                 ) = 1""",
            (lower_key, lower_key, lower_key, lower_key),
        )
        await db.conn.commit()
        return cursor.rowcount
        async with db.tx() as conn:
            async with conn.execute(
                """UPDATE messages SET conversation_key = ?,
                       sender_key = CASE
                           WHEN sender_key IS NOT NULL AND length(sender_key) < 64
                                AND ? LIKE sender_key || '%'
                           THEN ? ELSE sender_key END
                   WHERE type = 'PRIV' AND length(conversation_key) < 64
                     AND ? LIKE conversation_key || '%'
                     AND (
                         SELECT COUNT(*) FROM contacts
                         WHERE length(public_key) = 64
                           AND public_key LIKE messages.conversation_key || '%'
                     ) = 1""",
                (lower_key, lower_key, lower_key, lower_key),
            ) as cursor:
                rowcount = cursor.rowcount
        return rowcount

    @staticmethod
    async def backfill_channel_sender_key(public_key: str, name: str) -> int:

@@ -197,21 +203,22 @@ class MessageRepository:
        any channel messages with a matching sender_name but no sender_key
        are updated to associate them with this contact's public key.
        """
        cursor = await db.conn.execute(
            """UPDATE messages SET sender_key = ?
               WHERE type = 'CHAN' AND sender_name = ? AND sender_key IS NULL
                 AND (
                     SELECT COUNT(*) FROM contacts
                     WHERE name = ?
                 ) = 1
                 AND EXISTS (
                     SELECT 1 FROM contacts
                     WHERE public_key = ? AND name = ?
                 )""",
            (public_key.lower(), name, name, public_key.lower(), name),
        )
        await db.conn.commit()
        return cursor.rowcount
        async with db.tx() as conn:
            async with conn.execute(
                """UPDATE messages SET sender_key = ?
                   WHERE type = 'CHAN' AND sender_name = ? AND sender_key IS NULL
                     AND (
                         SELECT COUNT(*) FROM contacts
                         WHERE name = ?
                     ) = 1
                     AND EXISTS (
                         SELECT 1 FROM contacts
                         WHERE public_key = ? AND name = ?
                     )""",
                (public_key.lower(), name, name, public_key.lower(), name),
            ) as cursor:
                rowcount = cursor.rowcount
        return rowcount

    @staticmethod
    def _normalize_conversation_key(conversation_key: str) -> tuple[str, str]:

@@ -462,8 +469,9 @@ class MessageRepository:
        query += " OFFSET ?"
        params.append(offset)

        cursor = await db.conn.execute(query, params)
        rows = await cursor.fetchall()
        async with db.readonly() as conn:
            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()
        return [MessageRepository._row_to_message(row) for row in rows]

    @staticmethod

@@ -501,51 +509,54 @@ class MessageRepository:
        where_sql = " AND ".join(["1=1", *where_parts])

        # 1. Get the target message (must satisfy filters if provided)
        target_cursor = await db.conn.execute(
            f"SELECT {MessageRepository._message_select('messages')} "
            f"FROM messages WHERE id = ? AND {where_sql}",
            (message_id, *base_params),
        )
        target_row = await target_cursor.fetchone()
        if not target_row:
            return [], False, False
        async with db.readonly() as conn:
            async with conn.execute(
                f"SELECT {MessageRepository._message_select('messages')} "
                f"FROM messages WHERE id = ? AND {where_sql}",
                (message_id, *base_params),
            ) as target_cursor:
                target_row = await target_cursor.fetchone()
            if not target_row:
                return [], False, False

        target = MessageRepository._row_to_message(target_row)
            target = MessageRepository._row_to_message(target_row)

        # 2. Get context_size+1 messages before target (DESC)
        before_query = f"""
            SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql}
            AND (received_at < ? OR (received_at = ? AND id < ?))
            ORDER BY received_at DESC, id DESC LIMIT ?
        """
        before_params = [
            *base_params,
            target.received_at,
            target.received_at,
            target.id,
            context_size + 1,
        ]
        before_cursor = await db.conn.execute(before_query, before_params)
        before_rows = list(await before_cursor.fetchall())
            # 2. Get context_size+1 messages before target (DESC)
            before_query = f"""
                SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql}
                AND (received_at < ? OR (received_at = ? AND id < ?))
                ORDER BY received_at DESC, id DESC LIMIT ?
            """
            before_params = [
                *base_params,
                target.received_at,
                target.received_at,
                target.id,
                context_size + 1,
            ]
            async with conn.execute(before_query, before_params) as before_cursor:
                before_rows = list(await before_cursor.fetchall())

        has_older = len(before_rows) > context_size
        before_messages = [MessageRepository._row_to_message(r) for r in before_rows[:context_size]]
            has_older = len(before_rows) > context_size
            before_messages = [
                MessageRepository._row_to_message(r) for r in before_rows[:context_size]
            ]

        # 3. Get context_size+1 messages after target (ASC)
        after_query = f"""
            SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql}
            AND (received_at > ? OR (received_at = ? AND id > ?))
            ORDER BY received_at ASC, id ASC LIMIT ?
        """
        after_params = [
            *base_params,
            target.received_at,
            target.received_at,
            target.id,
            context_size + 1,
        ]
        after_cursor = await db.conn.execute(after_query, after_params)
        after_rows = list(await after_cursor.fetchall())
            # 3. Get context_size+1 messages after target (ASC)
            after_query = f"""
                SELECT {MessageRepository._message_select("messages")} FROM messages WHERE {where_sql}
                AND (received_at > ? OR (received_at = ? AND id > ?))
                ORDER BY received_at ASC, id ASC LIMIT ?
            """
            after_params = [
                *base_params,
                target.received_at,
                target.received_at,
                target.id,
                context_size + 1,
            ]
            async with conn.execute(after_query, after_params) as after_cursor:
                after_rows = list(await after_cursor.fetchall())

        has_newer = len(after_rows) > context_size
        after_messages = [MessageRepository._row_to_message(r) for r in after_rows[:context_size]]

@@ -556,21 +567,29 @@ class MessageRepository:

    @staticmethod
    async def increment_ack_count(message_id: int) -> int:
        """Increment ack count and return the new value."""
        cursor = await db.conn.execute(
            "UPDATE messages SET acked = acked + 1 WHERE id = ? RETURNING acked", (message_id,)
        )
        row = await cursor.fetchone()
        await db.conn.commit()
        """Increment ack count and return the new value.

        NOTE: ``RETURNING`` leaves the prepared statement active until the
        row is fetched, so we MUST consume it inside the ``async with``
        block. Without that, the commit at the end of ``db.tx()`` fails
        with ``cannot commit transaction - SQL statements in progress``.
        """
        async with db.tx() as conn:
            async with conn.execute(
                "UPDATE messages SET acked = acked + 1 WHERE id = ? RETURNING acked",
                (message_id,),
            ) as cursor:
                row = await cursor.fetchone()
        return row["acked"] if row else 1
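# --- Editor's sketch (not part of the commit) --------------------------------
# The shape of the bug the NOTE above warns about, per that docstring:
# fetching the RETURNING row outside the cursor's context manager leaves
# the statement alive while db.tx() commits on exit, which is exactly the
# "SQL statements in progress" failure. `message_id` is a placeholder.
#
#     async with db.tx() as conn:
#         cursor = await conn.execute(
#             "UPDATE messages SET acked = acked + 1 WHERE id = ? RETURNING acked",
#             (message_id,),
#         )
#     row = await cursor.fetchone()  # too late: the statement was still open
#                                    # when tx() tried to commit
# ------------------------------------------------------------------------------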

    @staticmethod
    async def get_ack_and_paths(message_id: int) -> tuple[int, list[MessagePath] | None]:
        """Get the current ack count and paths for a message."""
        cursor = await db.conn.execute(
            "SELECT acked, paths FROM messages WHERE id = ?", (message_id,)
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT acked, paths FROM messages WHERE id = ?", (message_id,)
            ) as cursor:
                row = await cursor.fetchone()
        if not row:
            return 0, None
        return row["acked"], MessageRepository._parse_paths(row["paths"])

@@ -578,11 +597,12 @@ class MessageRepository:
    @staticmethod
    async def get_by_id(message_id: int) -> "Message | None":
        """Look up a message by its ID."""
        cursor = await db.conn.execute(
            f"SELECT {MessageRepository._message_select('messages')} FROM messages WHERE id = ?",
            (message_id,),
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                f"SELECT {MessageRepository._message_select('messages')} FROM messages WHERE id = ?",
                (message_id,),
            ) as cursor:
                row = await cursor.fetchone()
        if not row:
            return None

@@ -591,11 +611,14 @@ class MessageRepository:
    @staticmethod
    async def delete_by_id(message_id: int) -> None:
        """Delete a message row by ID."""
        await db.conn.execute(
            "UPDATE raw_packets SET message_id = NULL WHERE message_id = ?", (message_id,)
        )
        await db.conn.execute("DELETE FROM messages WHERE id = ?", (message_id,))
        await db.conn.commit()
        async with db.tx() as conn:
            async with conn.execute(
                "UPDATE raw_packets SET message_id = NULL WHERE message_id = ?",
                (message_id,),
            ):
                pass
            async with conn.execute("DELETE FROM messages WHERE id = ?", (message_id,)):
                pass

    @staticmethod
    async def get_by_content(

@@ -618,8 +641,9 @@ class MessageRepository:
            query += " AND outgoing = ?"
            params.append(1 if outgoing else 0)
        query += " ORDER BY id ASC"
        cursor = await db.conn.execute(query, params)
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(query, params) as cursor:
                row = await cursor.fetchone()
        if not row:
            return None

@@ -653,76 +677,6 @@ class MessageRepository:
        )
        blocked_sql = f" AND {blocked_clause}" if blocked_clause else ""

        # Channel unreads
        cursor = await db.conn.execute(
            f"""
            SELECT m.conversation_key,
                   COUNT(*) as unread_count,
                   SUM(CASE
                       WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
                       ELSE 0
                   END) > 0 as has_mention
            FROM messages m
            JOIN channels c ON m.conversation_key = c.key
            WHERE m.type = 'CHAN' AND m.outgoing = 0
              AND m.received_at > COALESCE(c.last_read_at, 0)
              {blocked_sql}
            GROUP BY m.conversation_key
            """,
            (mention_token or "", mention_token or "", *blocked_params),
        )
        rows = await cursor.fetchall()
        for row in rows:
            state_key = f"channel-{row['conversation_key']}"
            counts[state_key] = row["unread_count"]
            if mention_token and row["has_mention"]:
                mention_flags[state_key] = True

        # Contact unreads
        cursor = await db.conn.execute(
            f"""
            SELECT m.conversation_key,
                   COUNT(*) as unread_count,
                   SUM(CASE
                       WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
                       ELSE 0
                   END) > 0 as has_mention
            FROM messages m
            LEFT JOIN contacts ct ON m.conversation_key = ct.public_key
            WHERE m.type = 'PRIV' AND m.outgoing = 0
              AND m.received_at > COALESCE(ct.last_read_at, 0)
              {blocked_sql}
            GROUP BY m.conversation_key
            """,
            (mention_token or "", mention_token or "", *blocked_params),
        )
        rows = await cursor.fetchall()
        for row in rows:
            state_key = f"contact-{row['conversation_key']}"
            counts[state_key] = row["unread_count"]
            if mention_token and row["has_mention"]:
                mention_flags[state_key] = True

        cursor = await db.conn.execute(
            """
            SELECT key, last_read_at
            FROM channels
            """
        )
        rows = await cursor.fetchall()
        for row in rows:
            last_read_ats[f"channel-{row['key']}"] = row["last_read_at"]

        cursor = await db.conn.execute(
            """
            SELECT public_key, last_read_at
            FROM contacts
            """
        )
        rows = await cursor.fetchall()
        for row in rows:
            last_read_ats[f"contact-{row['public_key']}"] = row["last_read_at"]

        # Last message times for all conversations (including read ones),
        # excluding blocked incoming traffic so refresh matches live WS behavior.
        last_time_clause, last_time_params = MessageRepository._build_blocked_incoming_clause(

@@ -730,20 +684,94 @@ class MessageRepository:
        )
        last_time_where_sql = f"WHERE {last_time_clause}" if last_time_clause else ""

        cursor = await db.conn.execute(
            f"""
            SELECT type, conversation_key, MAX(received_at) as last_message_time
            FROM messages
            {last_time_where_sql}
            GROUP BY type, conversation_key
            """,
            last_time_params,
        )
        rows = await cursor.fetchall()
        for row in rows:
            prefix = "channel" if row["type"] == "CHAN" else "contact"
            state_key = f"{prefix}-{row['conversation_key']}"
            last_message_times[state_key] = row["last_message_time"]
        # Single readonly acquisition for all 5 queries — they form one logical
        # snapshot, and holding the lock for the batch is cheaper than acquiring
        # it 5 times.
        async with db.readonly() as conn:
            # Channel unreads
            async with conn.execute(
                f"""
                SELECT m.conversation_key,
                       COUNT(*) as unread_count,
                       SUM(CASE
                           WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
                           ELSE 0
                       END) > 0 as has_mention
                FROM messages m
                JOIN channels c ON m.conversation_key = c.key
                WHERE m.type = 'CHAN' AND m.outgoing = 0
                  AND m.received_at > COALESCE(c.last_read_at, 0)
                  {blocked_sql}
                GROUP BY m.conversation_key
                """,
                (mention_token or "", mention_token or "", *blocked_params),
            ) as cursor:
                rows = await cursor.fetchall()
            for row in rows:
                state_key = f"channel-{row['conversation_key']}"
                counts[state_key] = row["unread_count"]
                if mention_token and row["has_mention"]:
                    mention_flags[state_key] = True

            # Contact unreads
            async with conn.execute(
                f"""
                SELECT m.conversation_key,
                       COUNT(*) as unread_count,
                       SUM(CASE
                           WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
                           ELSE 0
                       END) > 0 as has_mention
                FROM messages m
                LEFT JOIN contacts ct ON m.conversation_key = ct.public_key
                WHERE m.type = 'PRIV' AND m.outgoing = 0
                  AND m.received_at > COALESCE(ct.last_read_at, 0)
                  {blocked_sql}
                GROUP BY m.conversation_key
                """,
                (mention_token or "", mention_token or "", *blocked_params),
            ) as cursor:
                rows = await cursor.fetchall()
            for row in rows:
                state_key = f"contact-{row['conversation_key']}"
                counts[state_key] = row["unread_count"]
                if mention_token and row["has_mention"]:
                    mention_flags[state_key] = True

            async with conn.execute(
                """
                SELECT key, last_read_at
                FROM channels
                """
            ) as cursor:
                rows = await cursor.fetchall()
            for row in rows:
                last_read_ats[f"channel-{row['key']}"] = row["last_read_at"]

            async with conn.execute(
                """
                SELECT public_key, last_read_at
                FROM contacts
                """
            ) as cursor:
                rows = await cursor.fetchall()
            for row in rows:
                last_read_ats[f"contact-{row['public_key']}"] = row["last_read_at"]

            async with conn.execute(
                f"""
                SELECT type, conversation_key, MAX(received_at) as last_message_time
                FROM messages
                {last_time_where_sql}
                GROUP BY type, conversation_key
                """,
                last_time_params,
            ) as cursor:
                rows = await cursor.fetchall()
            for row in rows:
                prefix = "channel" if row["type"] == "CHAN" else "contact"
                state_key = f"{prefix}-{row['conversation_key']}"
                last_message_times[state_key] = row["last_message_time"]
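# --- Editor's note (not part of the commit) -----------------------------------
# The batching trade-off the comment above describes, in miniature: one
# readonly() acquisition gives all five queries a single consistent snapshot
# and pays the lock handoff once.
#
#     async with db.readonly() as conn:
#         ...run all five queries on conn...
#
# Five separate `async with db.readonly()` blocks would let a writer commit
# between queries and skew unread counts against last-read timestamps.
# -------------------------------------------------------------------------------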

        # Only include last_read_ats for conversations that actually have messages.
        # Without this filter, every contact heard via advertisement (even without

@@ -760,41 +788,45 @@ class MessageRepository:
    @staticmethod
    async def count_dm_messages(contact_key: str) -> int:
        """Count total DM messages for a contact."""
        cursor = await db.conn.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?",
            (contact_key.lower(),),
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?",
                (contact_key.lower(),),
            ) as cursor:
                row = await cursor.fetchone()
        return row["cnt"] if row else 0

    @staticmethod
    async def count_channel_messages_by_sender(sender_key: str) -> int:
        """Count channel messages sent by a specific contact."""
        cursor = await db.conn.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?",
            (sender_key.lower(),),
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?",
                (sender_key.lower(),),
            ) as cursor:
                row = await cursor.fetchone()
        return row["cnt"] if row else 0

    @staticmethod
    async def count_channel_messages_by_sender_name(sender_name: str) -> int:
        """Count channel messages attributed to a display name."""
        cursor = await db.conn.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_name = ?",
            (sender_name,),
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_name = ?",
                (sender_name,),
            ) as cursor:
                row = await cursor.fetchone()
        return row["cnt"] if row else 0

    @staticmethod
    async def get_first_channel_message_by_sender_name(sender_name: str) -> int | None:
        """Get the earliest stored channel message timestamp for a display name."""
        cursor = await db.conn.execute(
            "SELECT MIN(received_at) AS first_seen FROM messages WHERE type = 'CHAN' AND sender_name = ?",
            (sender_name,),
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT MIN(received_at) AS first_seen FROM messages WHERE type = 'CHAN' AND sender_name = ?",
                (sender_name,),
            ) as cursor:
                row = await cursor.fetchone()
        return row["first_seen"] if row and row["first_seen"] is not None else None

    @staticmethod

@@ -813,68 +845,76 @@ class MessageRepository:
        t_48h = now - 172800
        t_7d = now - 604800

        cursor = await db.conn.execute(
            """
            SELECT COUNT(*) AS all_time,
                   SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_1h,
                   SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_24h,
                   SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_48h,
                   SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_7d,
                   MIN(received_at) AS first_message_at,
                   COUNT(DISTINCT sender_key) AS unique_sender_count
            FROM messages WHERE type = 'CHAN' AND conversation_key = ?
            """,
            (t_1h, t_24h, t_48h, t_7d, conversation_key),
        )
        row = await cursor.fetchone()
        assert row is not None  # Aggregate query always returns a row
        async with db.readonly() as conn:
            async with conn.execute(
                """
                SELECT COUNT(*) AS all_time,
                       SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_1h,
                       SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_24h,
                       SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_48h,
                       SUM(CASE WHEN received_at >= ? THEN 1 ELSE 0 END) AS last_7d,
                       MIN(received_at) AS first_message_at,
                       COUNT(DISTINCT sender_key) AS unique_sender_count
                FROM messages WHERE type = 'CHAN' AND conversation_key = ?
                """,
                (t_1h, t_24h, t_48h, t_7d, conversation_key),
            ) as cursor:
                row = await cursor.fetchone()
            assert row is not None  # Aggregate query always returns a row

        message_counts = {
            "last_1h": row["last_1h"] or 0,
            "last_24h": row["last_24h"] or 0,
            "last_48h": row["last_48h"] or 0,
            "last_7d": row["last_7d"] or 0,
            "all_time": row["all_time"] or 0,
        }

        cursor2 = await db.conn.execute(
            """
            SELECT COALESCE(sender_name, sender_key, 'Unknown') AS display_name,
                   sender_key, COUNT(*) AS cnt
            FROM messages
            WHERE type = 'CHAN' AND conversation_key = ?
              AND received_at >= ? AND sender_key IS NOT NULL
            GROUP BY sender_key ORDER BY cnt DESC LIMIT 5
            """,
            (conversation_key, t_24h),
        )
        top_rows = await cursor2.fetchall()
        top_senders = [
            {
                "sender_name": r["display_name"],
                "sender_key": r["sender_key"],
                "message_count": r["cnt"],
            }
            for r in top_rows
        ]
            message_counts = {
                "last_1h": row["last_1h"] or 0,
                "last_24h": row["last_24h"] or 0,
                "last_48h": row["last_48h"] or 0,
                "last_7d": row["last_7d"] or 0,
                "all_time": row["all_time"] or 0,
            }

        # Path hash width distribution for last 24h (in-Python parse of raw packet envelopes)
        cursor3 = await db.conn.execute(
            """
            SELECT rp.data FROM raw_packets rp
            JOIN messages m ON rp.message_id = m.id
            WHERE m.type = 'CHAN' AND m.conversation_key = ?
              AND rp.timestamp >= ?
            """,
            (conversation_key, t_24h),
        )
        rows3 = await cursor3.fetchall()
            async with conn.execute(
                """
                SELECT COALESCE(sender_name, sender_key, 'Unknown') AS display_name,
                       sender_key, COUNT(*) AS cnt
                FROM messages
                WHERE type = 'CHAN' AND conversation_key = ?
                  AND received_at >= ? AND sender_key IS NOT NULL
                GROUP BY sender_key ORDER BY cnt DESC LIMIT 5
                """,
                (conversation_key, t_24h),
            ) as cursor:
                top_rows = await cursor.fetchall()
            top_senders = [
                {
                    "sender_name": r["display_name"],
                    "sender_key": r["sender_key"],
                    "message_count": r["cnt"],
                }
                for r in top_rows
            ]

            # Path hash width distribution for last 24h: fetch raw rows under
            # the lock, then release BEFORE the CPU-bound in-Python envelope
            # parse. Parsing can iterate thousands of rows and previously held
            # the DB lock for the whole traversal — blocking every other repo
            # caller on a Pi. Keep the lock only for the fetch.
            async with conn.execute(
                """
                SELECT rp.data FROM raw_packets rp
                JOIN messages m ON rp.message_id = m.id
                WHERE m.type = 'CHAN' AND m.conversation_key = ?
                  AND rp.timestamp >= ?
                """,
                (conversation_key, t_24h),
            ) as cursor:
                rows3 = await cursor.fetchall()
            first_message_at = row["first_message_at"]
            unique_sender_count = row["unique_sender_count"] or 0

        path_hash_width_24h = bucket_path_hash_widths(rows3)

        return {
            "message_counts": message_counts,
            "first_message_at": row["first_message_at"],
            "unique_sender_count": row["unique_sender_count"] or 0,
            "first_message_at": first_message_at,
            "unique_sender_count": unique_sender_count,
            "top_senders_24h": top_senders,
            "path_hash_width_24h": path_hash_width_24h,
        }
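# --- Editor's note (not part of the commit) -----------------------------------
# Pattern used above for the path-hash histogram: fetchall() runs while the
# lock is held, but the CPU-bound bucket_path_hash_widths(rows3) parse runs
# only after readonly() has released it, so other coroutines' DB work is not
# blocked behind a pure-Python loop over thousands of rows.
# -------------------------------------------------------------------------------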

@@ -882,14 +922,15 @@ class MessageRepository:
    @staticmethod
    async def count_channels_with_incoming_messages() -> int:
        """Count distinct channel conversations with at least one incoming message."""
        cursor = await db.conn.execute(
            """
            SELECT COUNT(DISTINCT conversation_key) AS cnt
            FROM messages
            WHERE type = 'CHAN' AND outgoing = 0
            """
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                """
                SELECT COUNT(DISTINCT conversation_key) AS cnt
                FROM messages
                WHERE type = 'CHAN' AND outgoing = 0
                """
            ) as cursor:
                row = await cursor.fetchone()
        return int(row["cnt"]) if row and row["cnt"] is not None else 0

    @staticmethod

@@ -898,20 +939,21 @@ class MessageRepository:

        Returns list of (channel_key, channel_name, message_count) tuples.
        """
        cursor = await db.conn.execute(
            """
            SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
                   COUNT(*) AS cnt
            FROM messages m
            LEFT JOIN channels c ON m.conversation_key = c.key
            WHERE m.type = 'CHAN' AND m.sender_key = ?
            GROUP BY m.conversation_key
            ORDER BY cnt DESC
            LIMIT ?
            """,
            (sender_key.lower(), limit),
        )
        rows = await cursor.fetchall()
        async with db.readonly() as conn:
            async with conn.execute(
                """
                SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
                       COUNT(*) AS cnt
                FROM messages m
                LEFT JOIN channels c ON m.conversation_key = c.key
                WHERE m.type = 'CHAN' AND m.sender_key = ?
                GROUP BY m.conversation_key
                ORDER BY cnt DESC
                LIMIT ?
                """,
                (sender_key.lower(), limit),
            ) as cursor:
                rows = await cursor.fetchall()
        return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows]

    @staticmethod

@@ -919,34 +961,36 @@ class MessageRepository:
        sender_name: str, limit: int = 5
    ) -> list[tuple[str, str, int]]:
        """Get channels where a display name has sent the most messages."""
        cursor = await db.conn.execute(
            """
            SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
                   COUNT(*) AS cnt
            FROM messages m
            LEFT JOIN channels c ON m.conversation_key = c.key
            WHERE m.type = 'CHAN' AND m.sender_name = ?
            GROUP BY m.conversation_key
            ORDER BY cnt DESC
            LIMIT ?
            """,
            (sender_name, limit),
        )
        rows = await cursor.fetchall()
        async with db.readonly() as conn:
            async with conn.execute(
                """
                SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
                       COUNT(*) AS cnt
                FROM messages m
                LEFT JOIN channels c ON m.conversation_key = c.key
                WHERE m.type = 'CHAN' AND m.sender_name = ?
                GROUP BY m.conversation_key
                ORDER BY cnt DESC
                LIMIT ?
                """,
                (sender_name, limit),
            ) as cursor:
                rows = await cursor.fetchall()
        return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows]

    @staticmethod
    async def _get_activity_hour_buckets(where_sql: str, params: list[Any]) -> dict[int, int]:
        cursor = await db.conn.execute(
            f"""
            SELECT received_at / 3600 AS hour_bucket, COUNT(*) AS cnt
            FROM messages
            WHERE {where_sql}
            GROUP BY hour_bucket
            """,
            params,
        )
        rows = await cursor.fetchall()
        async with db.readonly() as conn:
            async with conn.execute(
                f"""
                SELECT received_at / 3600 AS hour_bucket, COUNT(*) AS cnt
                FROM messages
                WHERE {where_sql}
                GROUP BY hour_bucket
                """,
                params,
            ) as cursor:
                rows = await cursor.fetchall()
        return {int(row["hour_bucket"]): row["cnt"] for row in rows}

    @staticmethod

@@ -1000,16 +1044,17 @@ class MessageRepository:
        current_day_start = (now // 86400) * 86400
        start = current_day_start - (weeks - 1) * bucket_seconds

        cursor = await db.conn.execute(
            f"""
            SELECT (received_at - ?) / ? AS bucket_idx, COUNT(*) AS cnt
            FROM messages
            WHERE {where_sql} AND received_at >= ?
            GROUP BY bucket_idx
            """,
            [start, bucket_seconds, *params, start],
        )
        rows = await cursor.fetchall()
        async with db.readonly() as conn:
            async with conn.execute(
                f"""
                SELECT (received_at - ?) / ? AS bucket_idx, COUNT(*) AS cnt
                FROM messages
                WHERE {where_sql} AND received_at >= ?
                GROUP BY bucket_idx
                """,
                [start, bucket_seconds, *params, start],
            ) as cursor:
                rows = await cursor.fetchall()
        counts = {int(row["bucket_idx"]): row["cnt"] for row in rows}

        return [

@@ -34,65 +34,85 @@ class RawPacketRepository:
        # For malformed packets, hash the full data
        payload_hash = sha256(data).digest()

        cursor = await db.conn.execute(
            "INSERT OR IGNORE INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)",
            (ts, data, payload_hash),
        )
        await db.conn.commit()
        async with db.tx() as conn:
            async with conn.execute(
                "INSERT OR IGNORE INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)",
                (ts, data, payload_hash),
            ) as cursor:
                rowcount = cursor.rowcount
                lastrowid = cursor.lastrowid

        if cursor.rowcount > 0:
            assert cursor.lastrowid is not None
            return (cursor.lastrowid, True)
            if rowcount > 0:
                assert lastrowid is not None
                return (lastrowid, True)

        # Duplicate payload — look up the existing row.
        cursor = await db.conn.execute(
            "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,)
        )
        existing = await cursor.fetchone()
            # Duplicate payload — look up the existing row (same transaction).
            async with conn.execute(
                "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,)
            ) as cursor:
                existing = await cursor.fetchone()
        assert existing is not None
        return (existing["id"], False)

    @staticmethod
    async def get_undecrypted_count() -> int:
        """Get count of undecrypted packets (those without a linked message)."""
        cursor = await db.conn.execute(
            "SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL"
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL"
            ) as cursor:
                row = await cursor.fetchone()
        return row["count"] if row else 0

    @staticmethod
    async def get_oldest_undecrypted() -> int | None:
        """Get timestamp of oldest undecrypted packet, or None if none exist."""
        cursor = await db.conn.execute(
            "SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL"
        )
        row = await cursor.fetchone()
        async with db.readonly() as conn:
            async with conn.execute(
                "SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL"
            ) as cursor:
                row = await cursor.fetchone()
        return row["oldest"] if row and row["oldest"] is not None else None

    @staticmethod
    async def _stream_undecrypted_rows(
        batch_size: int,
    ) -> AsyncIterator[tuple[int, bytes, int]]:
        """Internal: keyset-paginated scan of every undecrypted raw packet.

        Yields ``(id, data, timestamp)`` for each row across all batches.
        Lock is acquired per batch only — concurrent writes can interleave
        at batch boundaries rather than being blocked for the full scan.
        Each batch opens a fresh cursor and consumes it fully with
        ``fetchall()`` before releasing, so no prepared statement is alive
        at a yield boundary.

        ``last_id`` advances per row, not per yield, so external filters
        (see ``stream_undecrypted_text_messages``) that drop rows do not
        cause a re-scan of skipped IDs.
        """
        last_id = -1
        while True:
            async with db.readonly() as conn:
                async with conn.execute(
                    "SELECT id, data, timestamp FROM raw_packets "
                    "WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?",
                    (last_id, batch_size),
                ) as cursor:
                    rows = await cursor.fetchall()
            if not rows:
                return
            for row in rows:
                last_id = row["id"]
                yield (row["id"], bytes(row["data"]), row["timestamp"])

    @staticmethod
    async def stream_all_undecrypted(
        batch_size: int = UNDECRYPTED_PACKET_BATCH_SIZE,
    ) -> AsyncIterator[tuple[int, bytes, int]]:
        """Yield all undecrypted packets as (id, data, timestamp) in bounded batches.

        Uses keyset pagination so each batch is a fresh query with a fully
        consumed cursor — no open statement held across yield boundaries.
        """
        last_id = -1
        while True:
            cursor = await db.conn.execute(
                "SELECT id, data, timestamp FROM raw_packets "
                "WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?",
                (last_id, batch_size),
            )
            rows = await cursor.fetchall()
            await cursor.close()
            if not rows:
                break
            for row in rows:
                last_id = row["id"]
                yield (row["id"], bytes(row["data"]), row["timestamp"])
        """Yield all undecrypted packets as (id, data, timestamp) in bounded batches."""
        async for row in RawPacketRepository._stream_undecrypted_rows(batch_size):
            yield row

    @staticmethod
    async def stream_undecrypted_text_messages(

@@ -100,26 +120,15 @@ class RawPacketRepository:
    ) -> AsyncIterator[tuple[int, bytes, int]]:
        """Yield undecrypted TEXT_MESSAGE packets in bounded-size batches.

        Uses keyset pagination so each batch is a fresh query with a fully
        consumed cursor — no open statement held across yield boundaries.
        Filters the shared scan to rows whose payload parses as a text
        message. Non-matching rows still advance the keyset cursor so they
        aren't re-fetched on subsequent batches.
        """
        last_id = -1
        while True:
            cursor = await db.conn.execute(
                "SELECT id, data, timestamp FROM raw_packets "
                "WHERE message_id IS NULL AND id > ? ORDER BY id ASC LIMIT ?",
                (last_id, batch_size),
            )
            rows = await cursor.fetchall()
            await cursor.close()
            if not rows:
                break
            for row in rows:
                last_id = row["id"]
                data = bytes(row["data"])
                payload_type = get_packet_payload_type(data)
                if payload_type == PayloadType.TEXT_MESSAGE:
                    yield (row["id"], data, row["timestamp"])
        async for packet_id, data, timestamp in RawPacketRepository._stream_undecrypted_rows(
            batch_size
        ):
            if get_packet_payload_type(data) == PayloadType.TEXT_MESSAGE:
                yield (packet_id, data, timestamp)
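# --- Editor's sketch (not part of the commit) ---------------------------------
# A hypothetical consumer of the shared keyset scan: retry decryption over
# every undecrypted TEXT_MESSAGE packet. Because _stream_undecrypted_rows()
# holds the DB lock only per batch, writers such as mark_decrypted() can
# interleave between batches. try_decrypt() is a made-up helper, and the
# default batch size is assumed.
async def retry_text_message_decryption() -> int:
    recovered = 0
    async for packet_id, data, ts in RawPacketRepository.stream_undecrypted_text_messages():
        message_id = await try_decrypt(data, ts)  # hypothetical decoder
        if message_id is not None:
            await RawPacketRepository.mark_decrypted(packet_id, message_id)
            recovered += 1
    return recovered
# -------------------------------------------------------------------------------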
|
||||
|
||||
@staticmethod
|
||||
async def count_undecrypted_text_messages(
|
||||
@@ -136,20 +145,22 @@ class RawPacketRepository:
|
||||
@staticmethod
|
||||
async def mark_decrypted(packet_id: int, message_id: int) -> None:
|
||||
"""Link a raw packet to its decrypted message."""
|
||||
await db.conn.execute(
|
||||
"UPDATE raw_packets SET message_id = ? WHERE id = ?",
|
||||
(message_id, packet_id),
|
||||
)
|
||||
await db.conn.commit()
|
||||
async with db.tx() as conn:
|
||||
async with conn.execute(
|
||||
"UPDATE raw_packets SET message_id = ? WHERE id = ?",
|
||||
(message_id, packet_id),
|
||||
):
|
||||
pass
|
||||
|
     @staticmethod
     async def get_linked_message_id(packet_id: int) -> int | None:
         """Return the linked message ID for a raw packet, if any."""
-        cursor = await db.conn.execute(
-            "SELECT message_id FROM raw_packets WHERE id = ?",
-            (packet_id,),
-        )
-        row = await cursor.fetchone()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                "SELECT message_id FROM raw_packets WHERE id = ?",
+                (packet_id,),
+            ) as cursor:
+                row = await cursor.fetchone()
         if not row:
             return None
         return row["message_id"]
@@ -157,11 +168,12 @@ class RawPacketRepository:
     @staticmethod
     async def get_by_id(packet_id: int) -> tuple[int, bytes, int, int | None] | None:
         """Return a raw packet row as (id, data, timestamp, message_id)."""
-        cursor = await db.conn.execute(
-            "SELECT id, data, timestamp, message_id FROM raw_packets WHERE id = ?",
-            (packet_id,),
-        )
-        row = await cursor.fetchone()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                "SELECT id, data, timestamp, message_id FROM raw_packets WHERE id = ?",
+                (packet_id,),
+            ) as cursor:
+                row = await cursor.fetchone()
         if not row:
             return None
         return (row["id"], bytes(row["data"]), row["timestamp"], row["message_id"])
@@ -170,16 +182,20 @@ class RawPacketRepository:
     async def prune_old_undecrypted(max_age_days: int) -> int:
         """Delete undecrypted packets older than max_age_days. Returns count deleted."""
         cutoff = int(time.time()) - (max_age_days * 86400)
-        cursor = await db.conn.execute(
-            "DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?",
-            (cutoff,),
-        )
-        await db.conn.commit()
-        return cursor.rowcount
+        async with db.tx() as conn:
+            async with conn.execute(
+                "DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?",
+                (cutoff,),
+            ) as cursor:
+                rowcount = cursor.rowcount
+        return rowcount

     @staticmethod
     async def purge_linked_to_messages() -> int:
         """Delete raw packets that are already linked to a stored message."""
-        cursor = await db.conn.execute("DELETE FROM raw_packets WHERE message_id IS NOT NULL")
-        await db.conn.commit()
-        return cursor.rowcount
+        async with db.tx() as conn:
+            async with conn.execute(
+                "DELETE FROM raw_packets WHERE message_id IS NOT NULL"
+            ) as cursor:
+                rowcount = cursor.rowcount
+        return rowcount
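A detail worth noting in both rewrites: `cursor.rowcount` is captured inside the `async with` block while the cursor is still open, and the function returns only after the `db.tx()` block has committed. A condensed sketch of the shape, assuming the `db.tx()` helper from this commit and a hypothetical `cache` table:

```python
async def delete_expired(cutoff: int) -> int:
    async with db.tx() as conn:                      # lock held; commit on clean exit
        async with conn.execute(
            "DELETE FROM cache WHERE ts < ?", (cutoff,)
        ) as cursor:
            deleted = cursor.rowcount                # read while the cursor is open
    return deleted                                   # returned after the commit ran
```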
@@ -21,51 +21,54 @@ class RepeaterTelemetryRepository:
         data: dict,
     ) -> None:
         """Insert a telemetry history row and prune stale entries."""
-        await db.conn.execute(
-            """
-            INSERT INTO repeater_telemetry_history
-                (public_key, timestamp, data)
-            VALUES (?, ?, ?)
-            """,
-            (public_key, timestamp, json.dumps(data)),
-        )

-        # Prune entries older than 30 days
         cutoff = int(time.time()) - _MAX_AGE_SECONDS
-        await db.conn.execute(
-            "DELETE FROM repeater_telemetry_history WHERE public_key = ? AND timestamp < ?",
-            (public_key, cutoff),
-        )
+        async with db.tx() as conn:
+            async with conn.execute(
+                """
+                INSERT INTO repeater_telemetry_history
+                    (public_key, timestamp, data)
+                VALUES (?, ?, ?)
+                """,
+                (public_key, timestamp, json.dumps(data)),
+            ):
+                pass

-        # Cap at _MAX_ENTRIES_PER_REPEATER (keep newest)
-        await db.conn.execute(
-            """
-            DELETE FROM repeater_telemetry_history
-            WHERE public_key = ? AND id NOT IN (
-                SELECT id FROM repeater_telemetry_history
-                WHERE public_key = ?
-                ORDER BY timestamp DESC
-                LIMIT ?
-            )
-            """,
-            (public_key, public_key, _MAX_ENTRIES_PER_REPEATER),
-        )
+            # Prune entries older than 30 days
+            async with conn.execute(
+                "DELETE FROM repeater_telemetry_history WHERE public_key = ? AND timestamp < ?",
+                (public_key, cutoff),
+            ):
+                pass

-        await db.conn.commit()
+            # Cap at _MAX_ENTRIES_PER_REPEATER (keep newest)
+            async with conn.execute(
+                """
+                DELETE FROM repeater_telemetry_history
+                WHERE public_key = ? AND id NOT IN (
+                    SELECT id FROM repeater_telemetry_history
+                    WHERE public_key = ?
+                    ORDER BY timestamp DESC
+                    LIMIT ?
+                )
+                """,
+                (public_key, public_key, _MAX_ENTRIES_PER_REPEATER),
+            ):
+                pass
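The cap query's shape (delete everything for a key except the ids chosen by an `ORDER BY ... LIMIT` subquery) is easy to verify standalone:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE history (id INTEGER PRIMARY KEY, key TEXT, ts INTEGER)")
conn.executemany("INSERT INTO history (key, ts) VALUES (?, ?)", [("pk", t) for t in range(10)])

MAX_ENTRIES = 3
conn.execute(
    """
    DELETE FROM history
    WHERE key = ? AND id NOT IN (
        SELECT id FROM history WHERE key = ? ORDER BY ts DESC LIMIT ?
    )
    """,
    ("pk", "pk", MAX_ENTRIES),
)
# Only the newest MAX_ENTRIES timestamps survive.
assert [r[0] for r in conn.execute("SELECT ts FROM history ORDER BY ts")] == [7, 8, 9]
```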
     @staticmethod
     async def get_history(public_key: str, since_timestamp: int) -> list[dict]:
         """Return telemetry rows for a repeater since a given timestamp, ordered ASC."""
-        cursor = await db.conn.execute(
-            """
-            SELECT timestamp, data
-            FROM repeater_telemetry_history
-            WHERE public_key = ? AND timestamp >= ?
-            ORDER BY timestamp ASC
-            """,
-            (public_key, since_timestamp),
-        )
-        rows = await cursor.fetchall()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                """
+                SELECT timestamp, data
+                FROM repeater_telemetry_history
+                WHERE public_key = ? AND timestamp >= ?
+                ORDER BY timestamp ASC
+                """,
+                (public_key, since_timestamp),
+            ) as cursor:
+                rows = await cursor.fetchall()
         return [
             {
                 "timestamp": row["timestamp"],
@@ -77,17 +80,18 @@ class RepeaterTelemetryRepository:
     @staticmethod
     async def get_latest(public_key: str) -> dict | None:
         """Return the most recent telemetry row for a repeater, or None."""
-        cursor = await db.conn.execute(
-            """
-            SELECT timestamp, data
-            FROM repeater_telemetry_history
-            WHERE public_key = ?
-            ORDER BY timestamp DESC
-            LIMIT 1
-            """,
-            (public_key,),
-        )
-        row = await cursor.fetchone()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                """
+                SELECT timestamp, data
+                FROM repeater_telemetry_history
+                WHERE public_key = ?
+                ORDER BY timestamp DESC
+                LIMIT 1
+                """,
+                (public_key,),
+            ) as cursor:
+                row = await cursor.fetchone()
         if row is None:
             return None
         return {
@@ -3,6 +3,8 @@ import logging
 import time
 from typing import Any

+import aiosqlite
+
 from app.database import db
 from app.models import AppSettings
 from app.path_utils import bucket_path_hash_widths
@@ -17,15 +19,23 @@ SECONDS_7D = 604800


 class AppSettingsRepository:
-    """Repository for app_settings table (single-row pattern)."""
+    """Repository for app_settings table (single-row pattern).
+
+    Public methods acquire the DB lock exactly once. ``toggle_*`` helpers that
+    need a read-modify-write do so inside a single ``db.tx()`` — the internal
+    ``_get_in_conn`` / ``_apply_updates`` helpers run under the caller's
+    already-held lock and must NEVER call ``db.tx()`` or ``db.readonly()``.
+    """
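The "acquire once, pass the connection down" rule exists because `asyncio.Lock` is not reentrant. A tiny self-contained illustration of the discipline; names here are illustrative, not from this codebase:

```python
import asyncio

_lock = asyncio.Lock()
_state = {"flag": False}

async def public_toggle() -> None:
    async with _lock:              # public method: lock taken exactly once
        value = _read_locked()     # helpers assume the lock is already held
        _write_locked(not value)

def _read_locked() -> bool:
    # Must NOT acquire _lock itself: re-acquiring a non-reentrant asyncio.Lock
    # from the same task waits forever on the holder, i.e. a deadlock.
    return _state["flag"]

def _write_locked(value: bool) -> None:
    _state["flag"] = value

asyncio.run(public_toggle())
assert _state["flag"] is True
```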
     @staticmethod
-    async def get() -> AppSettings:
-        """Get the current app settings.
+    async def _get_in_conn(conn: aiosqlite.Connection) -> AppSettings:
+        """Load settings using an already-acquired connection.

-        Always returns settings - creates default row if needed (migration handles initial row).
+        Used by the public ``get()`` and by multi-step operations
+        (``toggle_blocked_key``, ``toggle_blocked_name``) to avoid re-entering
+        the non-reentrant DB lock.
         """
-        cursor = await db.conn.execute(
+        async with conn.execute(
             """
             SELECT max_radio_contacts, auto_decrypt_dm_on_advert,
                    last_message_times,
@@ -35,8 +45,8 @@ class AppSettingsRepository:
                    telemetry_interval_hours
             FROM app_settings WHERE id = 1
             """
-        )
-        row = await cursor.fetchone()
+        ) as cursor:
+            row = await cursor.fetchone()

         if not row:
             # Should not happen after migration, but handle gracefully
@@ -119,7 +129,9 @@ class AppSettingsRepository:
         )

     @staticmethod
-    async def update(
+    async def _apply_updates(
+        conn: aiosqlite.Connection,
+        *,
         max_radio_contacts: int | None = None,
         auto_decrypt_dm_on_advert: bool | None = None,
         last_message_times: dict[str, int] | None = None,
@@ -132,9 +144,13 @@ class AppSettingsRepository:
         tracked_telemetry_repeaters: list[str] | None = None,
         auto_resend_channel: bool | None = None,
         telemetry_interval_hours: int | None = None,
-    ) -> AppSettings:
-        """Update app settings. Only provided fields are updated."""
-        updates = []
+    ) -> None:
+        """Apply field updates using an already-acquired connection.
+
+        Emits a single UPDATE statement inside the caller's transaction. Does
+        NOT commit — the caller's ``db.tx()`` handles that.
+        """
+        updates: list[str] = []
         params: list[Any] = []

         if max_radio_contacts is not None:
@@ -187,47 +203,101 @@ class AppSettingsRepository:

         if updates:
             query = f"UPDATE app_settings SET {', '.join(updates)} WHERE id = 1"
-            await db.conn.execute(query, params)
-            await db.conn.commit()
+            async with conn.execute(query, params):
+                pass

-        return await AppSettingsRepository.get()
+    @staticmethod
+    async def get() -> AppSettings:
+        """Get the current app settings.
+
+        Always returns settings - creates default row if needed (migration handles initial row).
+        """
+        async with db.readonly() as conn:
+            return await AppSettingsRepository._get_in_conn(conn)
+
+    @staticmethod
+    async def update(
+        max_radio_contacts: int | None = None,
+        auto_decrypt_dm_on_advert: bool | None = None,
+        last_message_times: dict[str, int] | None = None,
+        advert_interval: int | None = None,
+        last_advert_time: int | None = None,
+        flood_scope: str | None = None,
+        blocked_keys: list[str] | None = None,
+        blocked_names: list[str] | None = None,
+        discovery_blocked_types: list[int] | None = None,
+        tracked_telemetry_repeaters: list[str] | None = None,
+        auto_resend_channel: bool | None = None,
+        telemetry_interval_hours: int | None = None,
+    ) -> AppSettings:
+        """Update app settings. Only provided fields are updated."""
+        async with db.tx() as conn:
+            await AppSettingsRepository._apply_updates(
+                conn,
+                max_radio_contacts=max_radio_contacts,
+                auto_decrypt_dm_on_advert=auto_decrypt_dm_on_advert,
+                last_message_times=last_message_times,
+                advert_interval=advert_interval,
+                last_advert_time=last_advert_time,
+                flood_scope=flood_scope,
+                blocked_keys=blocked_keys,
+                blocked_names=blocked_names,
+                discovery_blocked_types=discovery_blocked_types,
+                tracked_telemetry_repeaters=tracked_telemetry_repeaters,
+                auto_resend_channel=auto_resend_channel,
+                telemetry_interval_hours=telemetry_interval_hours,
+            )
+            return await AppSettingsRepository._get_in_conn(conn)
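The `updates`/`params` pair built by `_apply_updates` is the standard only-provided-fields UPDATE pattern. A compact synchronous sketch with hypothetical columns; interpolating column names into the SQL is safe only because they come from a fixed keyword list, never from user input:

```python
import sqlite3

def apply_updates(conn: sqlite3.Connection, **fields) -> None:
    updates: list[str] = []
    params: list = []
    for column, value in fields.items():       # keys come from a fixed kwarg list
        if value is not None:
            updates.append(f"{column} = ?")    # trusted column names only
            params.append(value)               # values always bound via ?
    if updates:
        conn.execute(f"UPDATE app_settings SET {', '.join(updates)} WHERE id = 1", params)
```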
     @staticmethod
     async def toggle_blocked_key(key: str) -> AppSettings:
-        """Toggle a public key in the blocked list. Keys are normalized to lowercase."""
+        """Toggle a public key in the blocked list. Keys are normalized to lowercase.
+
+        Read-modify-write is atomic under a single ``db.tx()`` lock — two
+        concurrent toggles for the same key cannot produce an inconsistent
+        intermediate state.
+        """
         normalized = key.lower()
-        settings = await AppSettingsRepository.get()
-        if normalized in settings.blocked_keys:
-            new_keys = [k for k in settings.blocked_keys if k != normalized]
-        else:
-            new_keys = settings.blocked_keys + [normalized]
-        return await AppSettingsRepository.update(blocked_keys=new_keys)
+        async with db.tx() as conn:
+            settings = await AppSettingsRepository._get_in_conn(conn)
+            if normalized in settings.blocked_keys:
+                new_keys = [k for k in settings.blocked_keys if k != normalized]
+            else:
+                new_keys = settings.blocked_keys + [normalized]
+            await AppSettingsRepository._apply_updates(conn, blocked_keys=new_keys)
+            return await AppSettingsRepository._get_in_conn(conn)

     @staticmethod
     async def toggle_blocked_name(name: str) -> AppSettings:
-        """Toggle a display name in the blocked list."""
-        settings = await AppSettingsRepository.get()
-        if name in settings.blocked_names:
-            new_names = [n for n in settings.blocked_names if n != name]
-        else:
-            new_names = settings.blocked_names + [name]
-        return await AppSettingsRepository.update(blocked_names=new_names)
+        """Toggle a display name in the blocked list.
+
+        Same atomicity guarantee as ``toggle_blocked_key``.
+        """
+        async with db.tx() as conn:
+            settings = await AppSettingsRepository._get_in_conn(conn)
+            if name in settings.blocked_names:
+                new_names = [n for n in settings.blocked_names if n != name]
+            else:
+                new_names = settings.blocked_names + [name]
+            await AppSettingsRepository._apply_updates(conn, blocked_names=new_names)
+            return await AppSettingsRepository._get_in_conn(conn)

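As a consequence of the single-`tx()` read-modify-write, racing toggles serialize cleanly. An illustrative property check (not from the test suite, and assuming the key starts unblocked): two concurrent toggles of the same key must net out to a no-op rather than a lost update.

```python
import asyncio

async def toggles_are_serializable() -> None:
    await asyncio.gather(
        AppSettingsRepository.toggle_blocked_key("AB12"),
        AppSettingsRepository.toggle_blocked_key("AB12"),
    )
    settings = await AppSettingsRepository.get()
    # One toggle adds "ab12" (normalized), the other removes it; the lock rules
    # out both tasks reading "not blocked" and both appending.
    assert "ab12" not in settings.blocked_keys
```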

 class StatisticsRepository:
     @staticmethod
     async def get_database_message_totals() -> dict[str, int]:
         """Return message totals needed by lightweight debug surfaces."""
-        cursor = await db.conn.execute(
-            """
-            SELECT
-                SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms,
-                SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages,
-                SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing
-            FROM messages
-            """
-        )
-        row = await cursor.fetchone()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                """
+                SELECT
+                    SUM(CASE WHEN type = 'PRIV' THEN 1 ELSE 0 END) AS total_dms,
+                    SUM(CASE WHEN type = 'CHAN' THEN 1 ELSE 0 END) AS total_channel_messages,
+                    SUM(CASE WHEN outgoing = 1 THEN 1 ELSE 0 END) AS total_outgoing
+                FROM messages
+                """
+            ) as cursor:
+                row = await cursor.fetchone()
         assert row is not None
         return {
             "total_dms": row["total_dms"] or 0,
@@ -240,18 +310,19 @@ class StatisticsRepository:
         """Get time-windowed counts for contacts/repeaters heard."""
         now = int(time.time())
         op = "!=" if exclude else "="
-        cursor = await db.conn.execute(
-            f"""
-            SELECT
-                SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour,
-                SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours,
-                SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week
-            FROM contacts
-            WHERE type {op} ? AND last_seen IS NOT NULL
-            """,
-            (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type),
-        )
-        row = await cursor.fetchone()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                f"""
+                SELECT
+                    SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour,
+                    SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours,
+                    SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week
+                FROM contacts
+                WHERE type {op} ? AND last_seen IS NOT NULL
+                """,
+                (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type),
+            ) as cursor:
+                row = await cursor.fetchone()
         assert row is not None  # Aggregate query always returns a row
         return {
             "last_hour": row["last_hour"] or 0,
@@ -267,24 +338,25 @@ class StatisticsRepository:
         the old UPPER(...) join and aggregate per known channel directly.
         """
         now = int(time.time())
-        cursor = await db.conn.execute(
-            """
-            WITH known AS (
-                SELECT conversation_key, MAX(received_at) AS last_received_at
-                FROM messages
-                WHERE type = 'CHAN'
-                  AND conversation_key IN (SELECT key FROM channels)
-                GROUP BY conversation_key
-            )
-            SELECT
-                SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_hour,
-                SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_24_hours,
-                SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_week
-            FROM known
-            """,
-            (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D),
-        )
-        row = await cursor.fetchone()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                """
+                WITH known AS (
+                    SELECT conversation_key, MAX(received_at) AS last_received_at
+                    FROM messages
+                    WHERE type = 'CHAN'
+                      AND conversation_key IN (SELECT key FROM channels)
+                    GROUP BY conversation_key
+                )
+                SELECT
+                    SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_hour,
+                    SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_24_hours,
+                    SUM(CASE WHEN last_received_at >= ? THEN 1 ELSE 0 END) AS last_week
+                FROM known
+                """,
+                (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D),
+            ) as cursor:
+                row = await cursor.fetchone()
         assert row is not None
         return {
             "last_hour": row["last_hour"] or 0,
@@ -298,92 +370,105 @@ class StatisticsRepository:
         now = int(time.time())
         cutoff = now - SECONDS_72H
         # Bucket timestamps to the start of each hour
-        cursor = await db.conn.execute(
-            """
-            SELECT (timestamp / 3600) * 3600 AS hour_ts, COUNT(*) AS count
-            FROM raw_packets
-            WHERE timestamp >= ?
-            GROUP BY hour_ts
-            ORDER BY hour_ts
-            """,
-            (cutoff,),
-        )
-        rows = await cursor.fetchall()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                """
+                SELECT (timestamp / 3600) * 3600 AS hour_ts, COUNT(*) AS count
+                FROM raw_packets
+                WHERE timestamp >= ?
+                GROUP BY hour_ts
+                ORDER BY hour_ts
+                """,
+                (cutoff,),
+            ) as cursor:
+                rows = await cursor.fetchall()
         return [{"timestamp": row["hour_ts"], "count": row["count"]} for row in rows]

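SQLite's `/` truncates on integer operands, so `(timestamp / 3600) * 3600` snaps a Unix timestamp to the start of its hour. The same arithmetic in Python, with a concrete value:

```python
ts = 1_700_001_234
# SQLite's integer '/' truncates; Python's '//' floors. For positive Unix
# timestamps the two agree, so this mirrors the SQL exactly.
hour_ts = (ts // 3600) * 3600
assert hour_ts == 1_699_999_200          # start of the hour containing ts
assert hour_ts <= ts < hour_ts + 3600    # every ts maps into exactly one bucket
```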
     @staticmethod
     async def _path_hash_width_24h() -> dict[str, int | float]:
         """Count parsed raw packets from the last 24h by hop hash width."""
         now = int(time.time())
-        cursor = await db.conn.execute(
-            "SELECT data FROM raw_packets WHERE timestamp >= ?",
-            (now - SECONDS_24H,),
-        )
-        rows = await cursor.fetchall()
+        async with db.readonly() as conn:
+            async with conn.execute(
+                "SELECT data FROM raw_packets WHERE timestamp >= ?",
+                (now - SECONDS_24H,),
+            ) as cursor:
+                rows = await cursor.fetchall()
         return bucket_path_hash_widths(rows)

     @staticmethod
     async def get_all() -> dict:
-        """Aggregate all statistics from existing tables."""
+        """Aggregate all statistics from existing tables.
+
+        Each helper acquires its own lock; there's no requirement that the
+        whole snapshot be atomic. If we ever wanted a consistent snapshot
+        we'd batch all queries into a single ``db.readonly()`` and use
+        ``_in_conn`` helpers, but statistics are intentionally approximate.
+        """
         now = int(time.time())

-        # Top 5 busiest channels in last 24h
-        cursor = await db.conn.execute(
-            """
-            SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
-                   COUNT(*) AS message_count
-            FROM messages m
-            LEFT JOIN channels c ON m.conversation_key = c.key
-            WHERE m.type = 'CHAN' AND m.received_at >= ?
-            GROUP BY m.conversation_key
-            ORDER BY COUNT(*) DESC
-            LIMIT 5
-            """,
-            (now - SECONDS_24H,),
-        )
-        rows = await cursor.fetchall()
-        busiest_channels_24h = [
-            {
-                "channel_key": row["conversation_key"],
-                "channel_name": row["channel_name"],
-                "message_count": row["message_count"],
-            }
-            for row in rows
-        ]
+        async with db.readonly() as conn:
+            # Top 5 busiest channels in last 24h
+            async with conn.execute(
+                """
+                SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
+                       COUNT(*) AS message_count
+                FROM messages m
+                LEFT JOIN channels c ON m.conversation_key = c.key
+                WHERE m.type = 'CHAN' AND m.received_at >= ?
+                GROUP BY m.conversation_key
+                ORDER BY COUNT(*) DESC
+                LIMIT 5
+                """,
+                (now - SECONDS_24H,),
+            ) as cursor:
+                rows = await cursor.fetchall()
+            busiest_channels_24h = [
+                {
+                    "channel_key": row["conversation_key"],
+                    "channel_name": row["channel_name"],
+                    "message_count": row["message_count"],
+                }
+                for row in rows
+            ]

-        # Entity counts
-        cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2")
-        row = await cursor.fetchone()
-        assert row is not None
-        contact_count: int = row["cnt"]
+            # Entity counts
+            async with conn.execute(
+                "SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2"
+            ) as cursor:
+                row = await cursor.fetchone()
+            assert row is not None
+            contact_count: int = row["cnt"]

-        cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2")
-        row = await cursor.fetchone()
-        assert row is not None
-        repeater_count: int = row["cnt"]
+            async with conn.execute(
+                "SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2"
+            ) as cursor:
+                row = await cursor.fetchone()
+            assert row is not None
+            repeater_count: int = row["cnt"]

-        cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM channels")
-        row = await cursor.fetchone()
-        assert row is not None
-        channel_count: int = row["cnt"]
+            async with conn.execute("SELECT COUNT(*) AS cnt FROM channels") as cursor:
+                row = await cursor.fetchone()
+            assert row is not None
+            channel_count: int = row["cnt"]

-        # Packet split
-        cursor = await db.conn.execute(
-            """
-            SELECT COUNT(*) AS total,
-                   SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted
-            FROM raw_packets
-            """
-        )
-        pkt_row = await cursor.fetchone()
-        assert pkt_row is not None
-        total_packets = pkt_row["total"] or 0
-        decrypted_packets = pkt_row["decrypted"] or 0
-        undecrypted_packets = total_packets - decrypted_packets
+            # Packet split
+            async with conn.execute(
+                """
+                SELECT COUNT(*) AS total,
+                       SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted
+                FROM raw_packets
+                """
+            ) as cursor:
+                pkt_row = await cursor.fetchone()
+            assert pkt_row is not None
+            total_packets = pkt_row["total"] or 0
+            decrypted_packets = pkt_row["decrypted"] or 0
+            undecrypted_packets = total_packets - decrypted_packets

+        # These each acquire their own lock. The snapshot isn't atomic across
+        # them — fine for stats, which are approximate by nature.
         message_totals = await StatisticsRepository.get_database_message_totals()

         # Activity windows
         contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True)
         repeaters_heard = await StatisticsRepository._activity_counts(contact_type=2)
         known_channels_active = await StatisticsRepository._known_channels_active()

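For contrast, the consistent-snapshot variant the docstring alludes to would look roughly like the sketch below. It is purely hypothetical: the `_in_conn` helper names are assumed, on the pattern of `_get_in_conn`, and would run raw SQL on the passed connection without ever touching the lock.

```python
async def get_all_consistent() -> dict:
    """Hypothetical: one lock acquisition, one point-in-time snapshot."""
    async with db.readonly() as conn:
        # Both reads happen under the same lock hold, so no writer's commit
        # can land between them. These helper names are assumptions, not
        # methods that exist in this commit.
        message_totals = await StatisticsRepository._message_totals_in_conn(conn)
        repeaters_heard = await StatisticsRepository._activity_counts_in_conn(
            conn, contact_type=2
        )
    return {"message_totals": message_totals, "repeaters_heard": repeaters_heard}
```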
@@ -28,13 +28,28 @@ def cleanup_test_db_dir():
 @pytest.fixture
 async def test_db():
     """Create an in-memory test database with schema + migrations."""
-    from app.repository import channels, contacts, messages, raw_packets, settings
+    from app.repository import (
+        channels,
+        contacts,
+        messages,
+        raw_packets,
+        repeater_telemetry,
+        settings,
+    )
     from app.repository import fanout as fanout_repo

     db = Database(":memory:")
     await db.connect()

-    submodules = [contacts, channels, messages, raw_packets, settings, fanout_repo]
+    submodules = [
+        contacts,
+        channels,
+        messages,
+        raw_packets,
+        settings,
+        fanout_repo,
+        repeater_telemetry,
+    ]
     originals = [(mod, mod.db) for mod in submodules]

     for mod in submodules:

@@ -322,7 +322,7 @@ class TestUndecryptedTextPacketStreaming:
             [],
         ]

-        async def fake_execute(*_args, **_kwargs):
+        def fake_execute(*_args, **_kwargs):
             batch = batches.pop(0)

             class FakeCursor:
@@ -332,6 +332,16 @@ class TestUndecryptedTextPacketStreaming:
                 async def close(self):
                     pass

+                async def __aenter__(self):
+                    return self
+
+                async def __aexit__(self, exc_type, exc, tb):
+                    return None
+
+            # aiosqlite's execute() returns a `contextmanager`-decorated
+            # coroutine that is both awaitable and usable as an async-with.
+            # Our repo code now uses `async with conn.execute(...) as cursor:`,
+            # so the mock just needs to return something with __aenter__/__aexit__.
             return FakeCursor()

         with patch.object(test_db.conn, "execute", side_effect=fake_execute):

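The comment block above is the crux of the mock: aiosqlite's `execute()` result works both as `cursor = await conn.execute(...)` and as `async with conn.execute(...) as cursor:`. A generic sketch of how one object can support both protocols (not aiosqlite's actual implementation):

```python
class DualModeResult:
    """Awaitable AND async-context-manager wrapper around a cursor."""

    def __init__(self, cursor):
        self._cursor = cursor

    def __await__(self):
        # `cursor = await conn.execute(...)` path: resolve to the cursor.
        async def _resolve():
            return self._cursor
        return _resolve().__await__()

    async def __aenter__(self):
        # `async with conn.execute(...) as cursor:` path.
        return self._cursor

    async def __aexit__(self, exc_type, exc, tb):
        close = getattr(self._cursor, "close", None)
        if close is not None:
            await close()  # async-with use closes the cursor automatically
        return None
```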
@@ -1,11 +1,12 @@
 """Tests for repository layer."""

-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import patch

 import pytest

 from app.models import Contact, ContactUpsert
 from app.repository import (
+    AppSettingsRepository,
     ContactAdvertPathRepository,
     ContactNameHistoryRepository,
     ContactRepository,
@@ -613,37 +614,103 @@ class TestAppSettingsRepository:
     """Test AppSettingsRepository parsing and migration edge cases."""

     @pytest.mark.asyncio
-    async def test_get_handles_corrupted_json_and_invalid_sort_order(self):
-        """Corrupted JSON fields are recovered with safe defaults."""
-        mock_conn = AsyncMock()
-        mock_cursor = AsyncMock()
-        mock_cursor.fetchone = AsyncMock(
-            return_value={
-                "max_radio_contacts": 250,
-                "auto_decrypt_dm_on_advert": 1,
-                "last_message_times": "{also-not-json",
-                "advert_interval": None,
-                "last_advert_time": None,
-                "flood_scope": "",
-                "blocked_keys": "[]",
-                "blocked_names": "[]",
-                "discovery_blocked_types": "[]",
-            }
+    async def test_get_handles_corrupted_json_and_invalid_sort_order(self, test_db):
+        """Corrupted JSON fields are recovered with safe defaults.
+
+        Uses the real DB so it exercises the lock-aware path. We stuff
+        malformed JSON directly into the row, then verify ``get()`` recovers
+        with defaults rather than propagating a parse error.
+        """
+        await test_db.conn.execute(
+            """
+            UPDATE app_settings
+            SET max_radio_contacts = 250,
+                auto_decrypt_dm_on_advert = 1,
+                last_message_times = '{also-not-json',
+                advert_interval = NULL,
+                last_advert_time = NULL,
+                flood_scope = '',
+                blocked_keys = '[]',
+                blocked_names = '[]',
+                discovery_blocked_types = '[]'
+            WHERE id = 1
+            """
         )
-        mock_conn.execute = AsyncMock(return_value=mock_cursor)
-        mock_db = MagicMock()
-        mock_db.conn = mock_conn
+        await test_db.conn.commit()

-        with patch("app.repository.settings.db", mock_db):
-            from app.repository import AppSettingsRepository
-
-            settings = await AppSettingsRepository.get()
+        settings = await AppSettingsRepository.get()

         assert settings.max_radio_contacts == 250
         assert settings.last_message_times == {}
         assert settings.advert_interval == 0
         assert settings.last_advert_time == 0

+    @pytest.mark.asyncio
+    async def test_get_in_conn_tolerates_missing_columns(self):
+        """Defend against partial migrations where columns added by later
+        migrations are absent from the row.
+
+        Real DBs can't produce this state (schema init + migrations always
+        run to the latest version on startup), but hand-rolled snapshots,
+        external DB tools, or interrupted migrations might. The
+        ``KeyError``-catching branches in ``_get_in_conn`` exist specifically
+        to guarantee graceful degradation.
+
+        We test these directly by mocking the connection boundary with a
+        dict-backed row that mimics a pre-migration snapshot missing:
+        - ``tracked_telemetry_repeaters`` (migration 53)
+        - ``auto_resend_channel`` (migration 54)
+        - ``telemetry_interval_hours`` (migration 57)
+        """
+        from unittest.mock import MagicMock
+
+        from app.telemetry_interval import DEFAULT_TELEMETRY_INTERVAL_HOURS
+
+        # sqlite3.Row raises KeyError for missing columns when accessed by
+        # name, which is what we want to simulate. We mimic that here with a
+        # dict-backed object whose __getitem__ raises KeyError for absent
+        # keys (dict.__getitem__ already does this).
+        class PartialRow(dict):
+            def keys(self):  # pragma: no cover - aiosqlite.Row compat
+                return super().keys()
+
+        partial_row = PartialRow(
+            {
+                "max_radio_contacts": 123,
+                "auto_decrypt_dm_on_advert": 1,
+                "last_message_times": "{}",
+                "advert_interval": 0,
+                "last_advert_time": 0,
+                "flood_scope": "",
+                "blocked_keys": "[]",
+                "blocked_names": "[]",
+                "discovery_blocked_types": "[]",
+                # intentionally missing: tracked_telemetry_repeaters,
+                # auto_resend_channel, telemetry_interval_hours
+            }
+        )
+
+        class FakeCursor:
+            async def fetchone(self):
+                return partial_row
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return None
+
+        mock_conn = MagicMock()
+        mock_conn.execute = MagicMock(return_value=FakeCursor())
+
+        settings = await AppSettingsRepository._get_in_conn(mock_conn)
+
+        assert settings.max_radio_contacts == 123
+        # Missing-column defaults kick in:
+        assert settings.tracked_telemetry_repeaters == []
+        assert settings.auto_resend_channel is False
+        assert settings.telemetry_interval_hours == DEFAULT_TELEMETRY_INTERVAL_HOURS


 class TestMessageRepositoryGetById:
     """Test MessageRepository.get_by_id method."""

@@ -2,7 +2,7 @@

 import time
 from types import SimpleNamespace
-from unittest.mock import AsyncMock, patch
+from unittest.mock import patch

 import pytest

@@ -353,13 +353,21 @@ class TestPathHashWidthStats:

     @pytest.mark.asyncio
     async def test_path_hash_width_scan_fetches_all_then_buckets(self, test_db):
-        """Hash-width stats should fetchall() then bucket synchronously."""
-
-        fake_rows = [{"data": b"a"}, {"data": b"b"}, {"data": b"c"}]
+        """Hash-width stats should fetchall() then bucket synchronously.
+
+        Uses real DB rows + a patched parser so it exercises the lock-aware
+        readonly path. Mocking ``conn.execute`` on the pre-refactor code no
+        longer reflects the actual call pattern (we use ``async with``).
+        """

-        class FakeCursor:
-            async def fetchall(self):
-                return fake_rows
+        now = int(time.time())
+        # Seed three raw packets in the last 24h with arbitrary distinguishing bytes.
+        for i, data in enumerate((b"a", b"b", b"c")):
+            await test_db.conn.execute(
+                "INSERT INTO raw_packets (timestamp, data) VALUES (?, ?)",
+                (now - (i + 1), data),
+            )
+        await test_db.conn.commit()

         def fake_parse(raw_packet: bytes):
             hash_sizes = {
@@ -372,10 +380,7 @@ class TestPathHashWidthStats:
                 return None
             return SimpleNamespace(hash_size=hash_size)

-        with (
-            patch.object(test_db.conn, "execute", new=AsyncMock(return_value=FakeCursor())),
-            patch("app.path_utils.parse_packet_envelope", side_effect=fake_parse),
-        ):
+        with patch("app.path_utils.parse_packet_envelope", side_effect=fake_parse):
             breakdown = await StatisticsRepository._path_hash_width_24h()

         assert breakdown["total_packets"] == 3