diff --git a/app/main.py b/app/main.py index 57de687..c20ce1a 100644 --- a/app/main.py +++ b/app/main.py @@ -23,6 +23,7 @@ from app.routers import ( packets, radio, read_state, + repeaters, settings, statistics, ws, @@ -106,6 +107,7 @@ async def radio_disconnected_handler(request: Request, exc: RadioDisconnectedErr app.include_router(health.router, prefix="/api") app.include_router(radio.router, prefix="/api") app.include_router(contacts.router, prefix="/api") +app.include_router(repeaters.router, prefix="/api") app.include_router(channels.router, prefix="/api") app.include_router(messages.router, prefix="/api") app.include_router(packets.router, prefix="/api") diff --git a/app/repository.py b/app/repository.py deleted file mode 100644 index 170c5dd..0000000 --- a/app/repository.py +++ /dev/null @@ -1,1375 +0,0 @@ -import json -import logging -import sqlite3 -import time -from hashlib import sha256 -from typing import Any, Literal - -from app.database import db -from app.decoder import PayloadType, extract_payload, get_packet_payload_type -from app.models import ( - AppSettings, - BotConfig, - Channel, - Contact, - ContactAdvertPath, - ContactAdvertPathSummary, - ContactNameHistory, - Favorite, - Message, - MessagePath, -) - -logger = logging.getLogger(__name__) - - -SECONDS_1H = 3600 -SECONDS_24H = 86400 -SECONDS_7D = 604800 - - -class AmbiguousPublicKeyPrefixError(ValueError): - """Raised when a public key prefix matches multiple contacts.""" - - def __init__(self, prefix: str, matches: list[str]): - self.prefix = prefix.lower() - self.matches = matches - super().__init__(f"Ambiguous public key prefix '{self.prefix}'") - - -class ContactRepository: - @staticmethod - async def upsert(contact: dict[str, Any]) -> None: - await db.conn.execute( - """ - INSERT INTO contacts (public_key, name, type, flags, last_path, last_path_len, - last_advert, lat, lon, last_seen, on_radio, last_contacted, - first_seen) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(public_key) DO UPDATE SET - name = COALESCE(excluded.name, contacts.name), - type = CASE WHEN excluded.type = 0 THEN contacts.type ELSE excluded.type END, - flags = excluded.flags, - last_path = COALESCE(excluded.last_path, contacts.last_path), - last_path_len = excluded.last_path_len, - last_advert = COALESCE(excluded.last_advert, contacts.last_advert), - lat = COALESCE(excluded.lat, contacts.lat), - lon = COALESCE(excluded.lon, contacts.lon), - last_seen = excluded.last_seen, - on_radio = COALESCE(excluded.on_radio, contacts.on_radio), - last_contacted = COALESCE(excluded.last_contacted, contacts.last_contacted), - first_seen = COALESCE(contacts.first_seen, excluded.first_seen) - """, - ( - contact.get("public_key", "").lower(), - contact.get("name"), - contact.get("type", 0), - contact.get("flags", 0), - contact.get("last_path"), - contact.get("last_path_len", -1), - contact.get("last_advert"), - contact.get("lat"), - contact.get("lon"), - contact.get("last_seen", int(time.time())), - contact.get("on_radio"), - contact.get("last_contacted"), - contact.get("first_seen"), - ), - ) - await db.conn.commit() - - @staticmethod - def _row_to_contact(row) -> Contact: - """Convert a database row to a Contact model.""" - return Contact( - public_key=row["public_key"], - name=row["name"], - type=row["type"], - flags=row["flags"], - last_path=row["last_path"], - last_path_len=row["last_path_len"], - last_advert=row["last_advert"], - lat=row["lat"], - lon=row["lon"], - last_seen=row["last_seen"], - on_radio=bool(row["on_radio"]), - last_contacted=row["last_contacted"], - last_read_at=row["last_read_at"], - first_seen=row["first_seen"], - ) - - @staticmethod - async def get_by_key(public_key: str) -> Contact | None: - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE public_key = ?", (public_key.lower(),) - ) - row = await cursor.fetchone() - return ContactRepository._row_to_contact(row) if row else None - - @staticmethod - async def get_by_key_prefix(prefix: str) -> Contact | None: - """Get a contact by key prefix only if it resolves uniquely. - - Returns None when no contacts match OR when multiple contacts match - the prefix (to avoid silently selecting the wrong contact). - """ - normalized_prefix = prefix.lower() - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2", - (f"{normalized_prefix}%",), - ) - rows = list(await cursor.fetchall()) - if len(rows) != 1: - return None - return ContactRepository._row_to_contact(rows[0]) - - @staticmethod - async def _get_prefix_matches(prefix: str, limit: int = 2) -> list[Contact]: - """Get contacts matching a key prefix, up to limit.""" - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT ?", - (f"{prefix.lower()}%", limit), - ) - rows = list(await cursor.fetchall()) - return [ContactRepository._row_to_contact(row) for row in rows] - - @staticmethod - async def get_by_key_or_prefix(key_or_prefix: str) -> Contact | None: - """Get a contact by exact key match, falling back to prefix match. - - Useful when the input might be a full 64-char public key or a shorter prefix. - """ - contact = await ContactRepository.get_by_key(key_or_prefix) - if contact: - return contact - - matches = await ContactRepository._get_prefix_matches(key_or_prefix, limit=2) - if len(matches) == 1: - return matches[0] - if len(matches) > 1: - raise AmbiguousPublicKeyPrefixError( - key_or_prefix, - [m.public_key for m in matches], - ) - return None - - @staticmethod - async def get_by_name(name: str) -> list[Contact]: - """Get all contacts with the given exact name.""" - cursor = await db.conn.execute("SELECT * FROM contacts WHERE name = ?", (name,)) - rows = await cursor.fetchall() - return [ContactRepository._row_to_contact(row) for row in rows] - - @staticmethod - async def resolve_prefixes(prefixes: list[str]) -> dict[str, Contact]: - """Resolve multiple key prefixes to contacts in a single query. - - Returns a dict mapping each prefix to its Contact, only for prefixes - that resolve uniquely (exactly one match). Ambiguous or unmatched - prefixes are omitted. - """ - if not prefixes: - return {} - normalized = [p.lower() for p in prefixes] - conditions = " OR ".join(["public_key LIKE ?"] * len(normalized)) - params = [f"{p}%" for p in normalized] - cursor = await db.conn.execute(f"SELECT * FROM contacts WHERE {conditions}", params) - rows = await cursor.fetchall() - # Group by which prefix each row matches - prefix_to_rows: dict[str, list] = {p: [] for p in normalized} - for row in rows: - pk = row["public_key"] - for p in normalized: - if pk.startswith(p): - prefix_to_rows[p].append(row) - # Only include uniquely-resolved prefixes - result: dict[str, Contact] = {} - for p in normalized: - if len(prefix_to_rows[p]) == 1: - result[p] = ContactRepository._row_to_contact(prefix_to_rows[p][0]) - return result - - @staticmethod - async def get_all(limit: int = 100, offset: int = 0) -> list[Contact]: - cursor = await db.conn.execute( - "SELECT * FROM contacts ORDER BY COALESCE(name, public_key) LIMIT ? OFFSET ?", - (limit, offset), - ) - rows = await cursor.fetchall() - return [ContactRepository._row_to_contact(row) for row in rows] - - @staticmethod - async def get_recent_non_repeaters(limit: int = 200) -> list[Contact]: - """Get the most recently active non-repeater contacts. - - Orders by most recent activity (last_contacted or last_advert), - excluding repeaters (type=2). - """ - cursor = await db.conn.execute( - """ - SELECT * FROM contacts - WHERE type != 2 - ORDER BY COALESCE(last_contacted, 0) DESC, COALESCE(last_advert, 0) DESC - LIMIT ? - """, - (limit,), - ) - rows = await cursor.fetchall() - return [ContactRepository._row_to_contact(row) for row in rows] - - @staticmethod - async def update_path(public_key: str, path: str, path_len: int) -> None: - await db.conn.execute( - "UPDATE contacts SET last_path = ?, last_path_len = ?, last_seen = ? WHERE public_key = ?", - (path, path_len, int(time.time()), public_key.lower()), - ) - await db.conn.commit() - - @staticmethod - async def set_on_radio(public_key: str, on_radio: bool) -> None: - await db.conn.execute( - "UPDATE contacts SET on_radio = ? WHERE public_key = ?", - (on_radio, public_key.lower()), - ) - await db.conn.commit() - - @staticmethod - async def delete(public_key: str) -> None: - normalized = public_key.lower() - await db.conn.execute( - "DELETE FROM contact_name_history WHERE public_key = ?", (normalized,) - ) - await db.conn.execute( - "DELETE FROM contact_advert_paths WHERE public_key = ?", (normalized,) - ) - await db.conn.execute("DELETE FROM contacts WHERE public_key = ?", (normalized,)) - await db.conn.commit() - - @staticmethod - async def update_last_contacted(public_key: str, timestamp: int | None = None) -> None: - """Update the last_contacted timestamp for a contact.""" - ts = timestamp if timestamp is not None else int(time.time()) - await db.conn.execute( - "UPDATE contacts SET last_contacted = ?, last_seen = ? WHERE public_key = ?", - (ts, ts, public_key.lower()), - ) - await db.conn.commit() - - @staticmethod - async def update_last_read_at(public_key: str, timestamp: int | None = None) -> bool: - """Update the last_read_at timestamp for a contact. - - Returns True if a row was updated, False if contact not found. - """ - ts = timestamp if timestamp is not None else int(time.time()) - cursor = await db.conn.execute( - "UPDATE contacts SET last_read_at = ? WHERE public_key = ?", - (ts, public_key.lower()), - ) - await db.conn.commit() - return cursor.rowcount > 0 - - @staticmethod - async def mark_all_read(timestamp: int) -> None: - """Mark all contacts as read at the given timestamp.""" - await db.conn.execute("UPDATE contacts SET last_read_at = ?", (timestamp,)) - await db.conn.commit() - - @staticmethod - async def get_by_pubkey_first_byte(hex_byte: str) -> list[Contact]: - """Get contacts whose public key starts with the given hex byte (2 chars).""" - cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE substr(public_key, 1, 2) = ?", - (hex_byte.lower(),), - ) - rows = await cursor.fetchall() - return [ContactRepository._row_to_contact(row) for row in rows] - - -class ContactAdvertPathRepository: - """Repository for recent unique advertisement paths per contact.""" - - @staticmethod - def _row_to_path(row) -> ContactAdvertPath: - path = row["path_hex"] or "" - next_hop = path[:2].lower() if len(path) >= 2 else None - return ContactAdvertPath( - path=path, - path_len=row["path_len"], - next_hop=next_hop, - first_seen=row["first_seen"], - last_seen=row["last_seen"], - heard_count=row["heard_count"], - ) - - @staticmethod - async def record_observation( - public_key: str, - path_hex: str, - timestamp: int, - max_paths: int = 10, - ) -> None: - """ - Upsert a unique advert path observation for a contact and prune to N most recent. - """ - if max_paths < 1: - max_paths = 1 - - normalized_key = public_key.lower() - normalized_path = path_hex.lower() - path_len = len(normalized_path) // 2 - - await db.conn.execute( - """ - INSERT INTO contact_advert_paths - (public_key, path_hex, path_len, first_seen, last_seen, heard_count) - VALUES (?, ?, ?, ?, ?, 1) - ON CONFLICT(public_key, path_hex) DO UPDATE SET - last_seen = MAX(contact_advert_paths.last_seen, excluded.last_seen), - path_len = excluded.path_len, - heard_count = contact_advert_paths.heard_count + 1 - """, - (normalized_key, normalized_path, path_len, timestamp, timestamp), - ) - - # Keep only the N most recent unique paths per contact. - await db.conn.execute( - """ - DELETE FROM contact_advert_paths - WHERE public_key = ? - AND path_hex NOT IN ( - SELECT path_hex - FROM contact_advert_paths - WHERE public_key = ? - ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC - LIMIT ? - ) - """, - (normalized_key, normalized_key, max_paths), - ) - await db.conn.commit() - - @staticmethod - async def get_recent_for_contact(public_key: str, limit: int = 10) -> list[ContactAdvertPath]: - cursor = await db.conn.execute( - """ - SELECT path_hex, path_len, first_seen, last_seen, heard_count - FROM contact_advert_paths - WHERE public_key = ? - ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC - LIMIT ? - """, - (public_key.lower(), limit), - ) - rows = await cursor.fetchall() - return [ContactAdvertPathRepository._row_to_path(row) for row in rows] - - @staticmethod - async def get_recent_for_all_contacts( - limit_per_contact: int = 10, - ) -> list[ContactAdvertPathSummary]: - cursor = await db.conn.execute( - """ - SELECT public_key, path_hex, path_len, first_seen, last_seen, heard_count - FROM contact_advert_paths - ORDER BY public_key ASC, last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC - """ - ) - rows = await cursor.fetchall() - - grouped: dict[str, list[ContactAdvertPath]] = {} - for row in rows: - key = row["public_key"] - paths = grouped.get(key) - if paths is None: - paths = [] - grouped[key] = paths - if len(paths) >= limit_per_contact: - continue - paths.append(ContactAdvertPathRepository._row_to_path(row)) - - return [ - ContactAdvertPathSummary(public_key=key, paths=paths) for key, paths in grouped.items() - ] - - -class ContactNameHistoryRepository: - """Repository for contact name change history.""" - - @staticmethod - async def record_name(public_key: str, name: str, timestamp: int) -> None: - """Record a name observation. Upserts: updates last_seen if name already known.""" - await db.conn.execute( - """ - INSERT INTO contact_name_history (public_key, name, first_seen, last_seen) - VALUES (?, ?, ?, ?) - ON CONFLICT(public_key, name) DO UPDATE SET - last_seen = MAX(contact_name_history.last_seen, excluded.last_seen) - """, - (public_key.lower(), name, timestamp, timestamp), - ) - await db.conn.commit() - - @staticmethod - async def get_history(public_key: str) -> list[ContactNameHistory]: - cursor = await db.conn.execute( - """ - SELECT name, first_seen, last_seen - FROM contact_name_history - WHERE public_key = ? - ORDER BY last_seen DESC - """, - (public_key.lower(),), - ) - rows = await cursor.fetchall() - return [ - ContactNameHistory( - name=row["name"], first_seen=row["first_seen"], last_seen=row["last_seen"] - ) - for row in rows - ] - - -class ChannelRepository: - @staticmethod - async def upsert(key: str, name: str, is_hashtag: bool = False, on_radio: bool = False) -> None: - """Upsert a channel. Key is 32-char hex string.""" - await db.conn.execute( - """ - INSERT INTO channels (key, name, is_hashtag, on_radio) - VALUES (?, ?, ?, ?) - ON CONFLICT(key) DO UPDATE SET - name = excluded.name, - is_hashtag = excluded.is_hashtag, - on_radio = excluded.on_radio - """, - (key.upper(), name, is_hashtag, on_radio), - ) - await db.conn.commit() - - @staticmethod - async def get_by_key(key: str) -> Channel | None: - """Get a channel by its key (32-char hex string).""" - cursor = await db.conn.execute( - "SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels WHERE key = ?", - (key.upper(),), - ) - row = await cursor.fetchone() - if row: - return Channel( - key=row["key"], - name=row["name"], - is_hashtag=bool(row["is_hashtag"]), - on_radio=bool(row["on_radio"]), - last_read_at=row["last_read_at"], - ) - return None - - @staticmethod - async def get_all() -> list[Channel]: - cursor = await db.conn.execute( - "SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels ORDER BY name" - ) - rows = await cursor.fetchall() - return [ - Channel( - key=row["key"], - name=row["name"], - is_hashtag=bool(row["is_hashtag"]), - on_radio=bool(row["on_radio"]), - last_read_at=row["last_read_at"], - ) - for row in rows - ] - - @staticmethod - async def delete(key: str) -> None: - """Delete a channel by key.""" - await db.conn.execute( - "DELETE FROM channels WHERE key = ?", - (key.upper(),), - ) - await db.conn.commit() - - @staticmethod - async def update_last_read_at(key: str, timestamp: int | None = None) -> bool: - """Update the last_read_at timestamp for a channel. - - Returns True if a row was updated, False if channel not found. - """ - ts = timestamp if timestamp is not None else int(time.time()) - cursor = await db.conn.execute( - "UPDATE channels SET last_read_at = ? WHERE key = ?", - (ts, key.upper()), - ) - await db.conn.commit() - return cursor.rowcount > 0 - - @staticmethod - async def mark_all_read(timestamp: int) -> None: - """Mark all channels as read at the given timestamp.""" - await db.conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,)) - await db.conn.commit() - - -class MessageRepository: - @staticmethod - def _parse_paths(paths_json: str | None) -> list[MessagePath] | None: - """Parse paths JSON string to list of MessagePath objects.""" - if not paths_json: - return None - try: - paths_data = json.loads(paths_json) - return [MessagePath(**p) for p in paths_data] - except (json.JSONDecodeError, TypeError, KeyError): - return None - - @staticmethod - async def create( - msg_type: str, - text: str, - received_at: int, - conversation_key: str, - sender_timestamp: int | None = None, - path: str | None = None, - txt_type: int = 0, - signature: str | None = None, - outgoing: bool = False, - sender_name: str | None = None, - sender_key: str | None = None, - ) -> int | None: - """Create a message, returning the ID or None if duplicate. - - Uses INSERT OR IGNORE to handle the UNIQUE constraint on - (type, conversation_key, text, sender_timestamp). This prevents - duplicate messages when the same message arrives via multiple RF paths. - - The path parameter is converted to the paths JSON array format. - """ - # Convert single path to paths array format - paths_json = None - if path is not None: - paths_json = json.dumps([{"path": path, "received_at": received_at}]) - - cursor = await db.conn.execute( - """ - INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp, - received_at, paths, txt_type, signature, outgoing, - sender_name, sender_key) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - msg_type, - conversation_key, - text, - sender_timestamp, - received_at, - paths_json, - txt_type, - signature, - outgoing, - sender_name, - sender_key, - ), - ) - await db.conn.commit() - # rowcount is 0 if INSERT was ignored due to UNIQUE constraint violation - if cursor.rowcount == 0: - return None - return cursor.lastrowid - - @staticmethod - async def add_path( - message_id: int, path: str, received_at: int | None = None - ) -> list[MessagePath]: - """Add a new path to an existing message. - - This is used when a repeat/echo of a message arrives via a different route. - Returns the updated list of paths. - """ - ts = received_at if received_at is not None else int(time.time()) - - # Atomic append: use json_insert to avoid read-modify-write race when - # multiple duplicate packets arrive concurrently for the same message. - new_entry = json.dumps({"path": path, "received_at": ts}) - await db.conn.execute( - """UPDATE messages SET paths = json_insert( - COALESCE(paths, '[]'), '$[#]', json(?) - ) WHERE id = ?""", - (new_entry, message_id), - ) - await db.conn.commit() - - # Read back the full list for the return value - cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,)) - row = await cursor.fetchone() - if not row or not row["paths"]: - return [] - - try: - all_paths = json.loads(row["paths"]) - except json.JSONDecodeError: - return [] - - return [MessagePath(**p) for p in all_paths] - - @staticmethod - async def claim_prefix_messages(full_key: str) -> int: - """Promote prefix-stored messages to the full conversation key. - - When a full key becomes known for a contact, any messages stored with - only a prefix as conversation_key are updated to use the full key. - """ - lower_key = full_key.lower() - cursor = await db.conn.execute( - """UPDATE messages SET conversation_key = ? - WHERE type = 'PRIV' AND length(conversation_key) < 64 - AND ? LIKE conversation_key || '%' - AND ( - SELECT COUNT(*) FROM contacts - WHERE public_key LIKE messages.conversation_key || '%' - ) = 1""", - (lower_key, lower_key), - ) - await db.conn.commit() - return cursor.rowcount - - @staticmethod - async def get_all( - limit: int = 100, - offset: int = 0, - msg_type: str | None = None, - conversation_key: str | None = None, - before: int | None = None, - before_id: int | None = None, - ) -> list[Message]: - query = "SELECT * FROM messages WHERE 1=1" - params: list[Any] = [] - - if msg_type: - query += " AND type = ?" - params.append(msg_type) - if conversation_key: - normalized_key = conversation_key - # Prefer exact matching for full keys. - if len(conversation_key) == 64: - normalized_key = conversation_key.lower() - query += " AND conversation_key = ?" - params.append(normalized_key) - elif len(conversation_key) == 32: - normalized_key = conversation_key.upper() - query += " AND conversation_key = ?" - params.append(normalized_key) - else: - # Prefix match is only for legacy/partial key callers. - query += " AND conversation_key LIKE ?" - params.append(f"{conversation_key}%") - - if before is not None and before_id is not None: - query += " AND (received_at < ? OR (received_at = ? AND id < ?))" - params.extend([before, before, before_id]) - - query += " ORDER BY received_at DESC, id DESC LIMIT ?" - params.append(limit) - if before is None or before_id is None: - query += " OFFSET ?" - params.append(offset) - - cursor = await db.conn.execute(query, params) - rows = await cursor.fetchall() - return [ - Message( - id=row["id"], - type=row["type"], - conversation_key=row["conversation_key"], - text=row["text"], - sender_timestamp=row["sender_timestamp"], - received_at=row["received_at"], - paths=MessageRepository._parse_paths(row["paths"]), - txt_type=row["txt_type"], - signature=row["signature"], - outgoing=bool(row["outgoing"]), - acked=row["acked"], - ) - for row in rows - ] - - @staticmethod - async def increment_ack_count(message_id: int) -> int: - """Increment ack count and return the new value.""" - await db.conn.execute("UPDATE messages SET acked = acked + 1 WHERE id = ?", (message_id,)) - await db.conn.commit() - cursor = await db.conn.execute("SELECT acked FROM messages WHERE id = ?", (message_id,)) - row = await cursor.fetchone() - return row["acked"] if row else 1 - - @staticmethod - async def get_ack_and_paths(message_id: int) -> tuple[int, list[MessagePath] | None]: - """Get the current ack count and paths for a message.""" - cursor = await db.conn.execute( - "SELECT acked, paths FROM messages WHERE id = ?", (message_id,) - ) - row = await cursor.fetchone() - if not row: - return 0, None - return row["acked"], MessageRepository._parse_paths(row["paths"]) - - @staticmethod - async def get_by_id(message_id: int) -> "Message | None": - """Look up a message by its ID.""" - cursor = await db.conn.execute( - """ - SELECT id, type, conversation_key, text, sender_timestamp, received_at, - paths, txt_type, signature, outgoing, acked - FROM messages - WHERE id = ? - """, - (message_id,), - ) - row = await cursor.fetchone() - if not row: - return None - - return Message( - id=row["id"], - type=row["type"], - conversation_key=row["conversation_key"], - text=row["text"], - sender_timestamp=row["sender_timestamp"], - received_at=row["received_at"], - paths=MessageRepository._parse_paths(row["paths"]), - txt_type=row["txt_type"], - signature=row["signature"], - outgoing=bool(row["outgoing"]), - acked=row["acked"], - ) - - @staticmethod - async def get_by_content( - msg_type: str, - conversation_key: str, - text: str, - sender_timestamp: int | None, - ) -> "Message | None": - """Look up a message by its unique content fields.""" - cursor = await db.conn.execute( - """ - SELECT id, type, conversation_key, text, sender_timestamp, received_at, - paths, txt_type, signature, outgoing, acked - FROM messages - WHERE type = ? AND conversation_key = ? AND text = ? - AND (sender_timestamp = ? OR (sender_timestamp IS NULL AND ? IS NULL)) - """, - (msg_type, conversation_key, text, sender_timestamp, sender_timestamp), - ) - row = await cursor.fetchone() - if not row: - return None - - paths = None - if row["paths"]: - try: - paths_data = json.loads(row["paths"]) - paths = [ - MessagePath(path=p["path"], received_at=p["received_at"]) for p in paths_data - ] - except (json.JSONDecodeError, KeyError): - pass - - return Message( - id=row["id"], - type=row["type"], - conversation_key=row["conversation_key"], - text=row["text"], - sender_timestamp=row["sender_timestamp"], - received_at=row["received_at"], - paths=paths, - txt_type=row["txt_type"], - signature=row["signature"], - outgoing=bool(row["outgoing"]), - acked=row["acked"], - ) - - @staticmethod - async def get_unread_counts(name: str | None = None) -> dict: - """Get unread message counts, mention flags, and last message times for all conversations. - - Args: - name: User's display name for @[name] mention detection. If None, mentions are skipped. - - Returns: - Dict with 'counts', 'mentions', and 'last_message_times' keys. - """ - counts: dict[str, int] = {} - mention_flags: dict[str, bool] = {} - last_message_times: dict[str, int] = {} - - mention_token = f"@[{name}]" if name else None - - # Channel unreads - cursor = await db.conn.execute( - """ - SELECT m.conversation_key, - COUNT(*) as unread_count, - SUM(CASE - WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 - ELSE 0 - END) > 0 as has_mention - FROM messages m - JOIN channels c ON m.conversation_key = c.key - WHERE m.type = 'CHAN' AND m.outgoing = 0 - AND m.received_at > COALESCE(c.last_read_at, 0) - GROUP BY m.conversation_key - """, - (mention_token or "", mention_token or ""), - ) - rows = await cursor.fetchall() - for row in rows: - state_key = f"channel-{row['conversation_key']}" - counts[state_key] = row["unread_count"] - if mention_token and row["has_mention"]: - mention_flags[state_key] = True - - # Contact unreads - cursor = await db.conn.execute( - """ - SELECT m.conversation_key, - COUNT(*) as unread_count, - SUM(CASE - WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 - ELSE 0 - END) > 0 as has_mention - FROM messages m - JOIN contacts ct ON m.conversation_key = ct.public_key - WHERE m.type = 'PRIV' AND m.outgoing = 0 - AND m.received_at > COALESCE(ct.last_read_at, 0) - GROUP BY m.conversation_key - """, - (mention_token or "", mention_token or ""), - ) - rows = await cursor.fetchall() - for row in rows: - state_key = f"contact-{row['conversation_key']}" - counts[state_key] = row["unread_count"] - if mention_token and row["has_mention"]: - mention_flags[state_key] = True - - # Last message times for all conversations (including read ones) - cursor = await db.conn.execute( - """ - SELECT type, conversation_key, MAX(received_at) as last_message_time - FROM messages - GROUP BY type, conversation_key - """ - ) - rows = await cursor.fetchall() - for row in rows: - prefix = "channel" if row["type"] == "CHAN" else "contact" - state_key = f"{prefix}-{row['conversation_key']}" - last_message_times[state_key] = row["last_message_time"] - - return { - "counts": counts, - "mentions": mention_flags, - "last_message_times": last_message_times, - } - - @staticmethod - async def count_dm_messages(contact_key: str) -> int: - """Count total DM messages for a contact.""" - cursor = await db.conn.execute( - "SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?", - (contact_key.lower(),), - ) - row = await cursor.fetchone() - return row["cnt"] if row else 0 - - @staticmethod - async def count_channel_messages_by_sender(sender_key: str) -> int: - """Count channel messages sent by a specific contact.""" - cursor = await db.conn.execute( - "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?", - (sender_key.lower(),), - ) - row = await cursor.fetchone() - return row["cnt"] if row else 0 - - @staticmethod - async def get_most_active_rooms(sender_key: str, limit: int = 5) -> list[tuple[str, str, int]]: - """Get channels where a contact has sent the most messages. - - Returns list of (channel_key, channel_name, message_count) tuples. - """ - cursor = await db.conn.execute( - """ - SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) as channel_name, - COUNT(*) as cnt - FROM messages m - LEFT JOIN channels c ON m.conversation_key = c.key - WHERE m.type = 'CHAN' AND m.sender_key = ? - GROUP BY m.conversation_key - ORDER BY cnt DESC - LIMIT ? - """, - (sender_key.lower(), limit), - ) - rows = await cursor.fetchall() - return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows] - - -class RawPacketRepository: - @staticmethod - async def create(data: bytes, timestamp: int | None = None) -> tuple[int, bool]: - """ - Create a raw packet with payload-based deduplication. - - Returns (packet_id, is_new) tuple: - - is_new=True: New packet stored, packet_id is the new row ID - - is_new=False: Duplicate payload detected, packet_id is the existing row ID - - Deduplication is based on the SHA-256 hash of the packet payload - (excluding routing/path information). - """ - ts = timestamp if timestamp is not None else int(time.time()) - - # Compute payload hash for deduplication - payload = extract_payload(data) - if payload: - payload_hash = sha256(payload).digest() - else: - # For malformed packets, hash the full data - payload_hash = sha256(data).digest() - - # Check if this payload already exists - cursor = await db.conn.execute( - "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,) - ) - existing = await cursor.fetchone() - - if existing: - # Duplicate - return existing packet ID - logger.debug( - "Duplicate payload detected (hash=%s..., existing_id=%d)", - payload_hash.hex()[:12], - existing["id"], - ) - return (existing["id"], False) - - # New packet - insert with hash - try: - cursor = await db.conn.execute( - "INSERT INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)", - (ts, data, payload_hash), - ) - await db.conn.commit() - assert cursor.lastrowid is not None # INSERT always returns a row ID - return (cursor.lastrowid, True) - except sqlite3.IntegrityError: - # Race condition: another insert with same payload_hash happened between - # our SELECT and INSERT. This is expected for duplicate packets arriving - # close together. Query again to get the existing ID. - logger.debug( - "Duplicate packet detected via race condition (payload_hash=%s), dropping", - payload_hash.hex()[:16], - ) - cursor = await db.conn.execute( - "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,) - ) - existing = await cursor.fetchone() - if existing: - return (existing["id"], False) - # This shouldn't happen, but if it does, re-raise - raise - - @staticmethod - async def get_undecrypted_count() -> int: - """Get count of undecrypted packets (those without a linked message).""" - cursor = await db.conn.execute( - "SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL" - ) - row = await cursor.fetchone() - return row["count"] if row else 0 - - @staticmethod - async def get_oldest_undecrypted() -> int | None: - """Get timestamp of oldest undecrypted packet, or None if none exist.""" - cursor = await db.conn.execute( - "SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL" - ) - row = await cursor.fetchone() - return row["oldest"] if row and row["oldest"] is not None else None - - @staticmethod - async def get_all_undecrypted() -> list[tuple[int, bytes, int]]: - """Get all undecrypted packets as (id, data, timestamp) tuples.""" - cursor = await db.conn.execute( - "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC" - ) - rows = await cursor.fetchall() - return [(row["id"], bytes(row["data"]), row["timestamp"]) for row in rows] - - @staticmethod - async def mark_decrypted(packet_id: int, message_id: int) -> None: - """Link a raw packet to its decrypted message.""" - await db.conn.execute( - "UPDATE raw_packets SET message_id = ? WHERE id = ?", - (message_id, packet_id), - ) - await db.conn.commit() - - @staticmethod - async def prune_old_undecrypted(max_age_days: int) -> int: - """Delete undecrypted packets older than max_age_days. Returns count deleted.""" - cutoff = int(time.time()) - (max_age_days * 86400) - cursor = await db.conn.execute( - "DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?", - (cutoff,), - ) - await db.conn.commit() - return cursor.rowcount - - @staticmethod - async def purge_linked_to_messages() -> int: - """Delete raw packets that are already linked to a stored message.""" - cursor = await db.conn.execute("DELETE FROM raw_packets WHERE message_id IS NOT NULL") - await db.conn.commit() - return cursor.rowcount - - @staticmethod - async def get_undecrypted_text_messages() -> list[tuple[int, bytes, int]]: - """Get all undecrypted TEXT_MESSAGE packets as (id, data, timestamp) tuples. - - Filters raw packets to only include those with PayloadType.TEXT_MESSAGE (0x02). - These are direct messages that can be decrypted with contact ECDH keys. - """ - cursor = await db.conn.execute( - "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC" - ) - rows = await cursor.fetchall() - - # Filter for TEXT_MESSAGE packets - result = [] - for row in rows: - data = bytes(row["data"]) - payload_type = get_packet_payload_type(data) - if payload_type == PayloadType.TEXT_MESSAGE: - result.append((row["id"], data, row["timestamp"])) - - return result - - -class AppSettingsRepository: - """Repository for app_settings table (single-row pattern).""" - - @staticmethod - async def get() -> AppSettings: - """Get the current app settings. - - Always returns settings - creates default row if needed (migration handles initial row). - """ - cursor = await db.conn.execute( - """ - SELECT max_radio_contacts, favorites, auto_decrypt_dm_on_advert, - sidebar_sort_order, last_message_times, preferences_migrated, - advert_interval, last_advert_time, bots - FROM app_settings WHERE id = 1 - """ - ) - row = await cursor.fetchone() - - if not row: - # Should not happen after migration, but handle gracefully - return AppSettings() - - # Parse favorites JSON - favorites = [] - if row["favorites"]: - try: - favorites_data = json.loads(row["favorites"]) - favorites = [Favorite(**f) for f in favorites_data] - except (json.JSONDecodeError, TypeError, KeyError) as e: - logger.warning( - "Failed to parse favorites JSON, using empty list: %s (data=%r)", - e, - row["favorites"][:100] if row["favorites"] else None, - ) - favorites = [] - - # Parse last_message_times JSON - last_message_times: dict[str, int] = {} - if row["last_message_times"]: - try: - last_message_times = json.loads(row["last_message_times"]) - except (json.JSONDecodeError, TypeError) as e: - logger.warning( - "Failed to parse last_message_times JSON, using empty dict: %s", - e, - ) - last_message_times = {} - - # Parse bots JSON - bots: list[BotConfig] = [] - if row["bots"]: - try: - bots_data = json.loads(row["bots"]) - bots = [BotConfig(**b) for b in bots_data] - except (json.JSONDecodeError, TypeError, KeyError) as e: - logger.warning( - "Failed to parse bots JSON, using empty list: %s (data=%r)", - e, - row["bots"][:100] if row["bots"] else None, - ) - bots = [] - - # Validate sidebar_sort_order (fallback to "recent" if invalid) - sort_order = row["sidebar_sort_order"] - if sort_order not in ("recent", "alpha"): - sort_order = "recent" - - return AppSettings( - max_radio_contacts=row["max_radio_contacts"], - favorites=favorites, - auto_decrypt_dm_on_advert=bool(row["auto_decrypt_dm_on_advert"]), - sidebar_sort_order=sort_order, - last_message_times=last_message_times, - preferences_migrated=bool(row["preferences_migrated"]), - advert_interval=row["advert_interval"] or 0, - last_advert_time=row["last_advert_time"] or 0, - bots=bots, - ) - - @staticmethod - async def update( - max_radio_contacts: int | None = None, - favorites: list[Favorite] | None = None, - auto_decrypt_dm_on_advert: bool | None = None, - sidebar_sort_order: str | None = None, - last_message_times: dict[str, int] | None = None, - preferences_migrated: bool | None = None, - advert_interval: int | None = None, - last_advert_time: int | None = None, - bots: list[BotConfig] | None = None, - ) -> AppSettings: - """Update app settings. Only provided fields are updated.""" - updates = [] - params: list[Any] = [] - - if max_radio_contacts is not None: - updates.append("max_radio_contacts = ?") - params.append(max_radio_contacts) - - if favorites is not None: - updates.append("favorites = ?") - favorites_json = json.dumps([f.model_dump() for f in favorites]) - params.append(favorites_json) - - if auto_decrypt_dm_on_advert is not None: - updates.append("auto_decrypt_dm_on_advert = ?") - params.append(1 if auto_decrypt_dm_on_advert else 0) - - if sidebar_sort_order is not None: - updates.append("sidebar_sort_order = ?") - params.append(sidebar_sort_order) - - if last_message_times is not None: - updates.append("last_message_times = ?") - params.append(json.dumps(last_message_times)) - - if preferences_migrated is not None: - updates.append("preferences_migrated = ?") - params.append(1 if preferences_migrated else 0) - - if advert_interval is not None: - updates.append("advert_interval = ?") - params.append(advert_interval) - - if last_advert_time is not None: - updates.append("last_advert_time = ?") - params.append(last_advert_time) - - if bots is not None: - updates.append("bots = ?") - bots_json = json.dumps([b.model_dump() for b in bots]) - params.append(bots_json) - - if updates: - query = f"UPDATE app_settings SET {', '.join(updates)} WHERE id = 1" - await db.conn.execute(query, params) - await db.conn.commit() - - return await AppSettingsRepository.get() - - @staticmethod - async def add_favorite(fav_type: Literal["channel", "contact"], fav_id: str) -> AppSettings: - """Add a favorite, avoiding duplicates.""" - settings = await AppSettingsRepository.get() - - # Check if already favorited - if any(f.type == fav_type and f.id == fav_id for f in settings.favorites): - return settings - - new_favorites = settings.favorites + [Favorite(type=fav_type, id=fav_id)] - return await AppSettingsRepository.update(favorites=new_favorites) - - @staticmethod - async def remove_favorite(fav_type: Literal["channel", "contact"], fav_id: str) -> AppSettings: - """Remove a favorite.""" - settings = await AppSettingsRepository.get() - new_favorites = [ - f for f in settings.favorites if not (f.type == fav_type and f.id == fav_id) - ] - return await AppSettingsRepository.update(favorites=new_favorites) - - @staticmethod - async def migrate_preferences_from_frontend( - favorites: list[dict], - sort_order: str, - last_message_times: dict[str, int], - ) -> tuple[AppSettings, bool]: - """Migrate all preferences from frontend localStorage. - - This is a one-time migration. If already migrated, returns current settings - without overwriting. Returns (settings, did_migrate) tuple. - """ - settings = await AppSettingsRepository.get() - - if settings.preferences_migrated: - # Already migrated, don't overwrite - return settings, False - - # Convert frontend favorites format to Favorite objects - new_favorites = [] - for f in favorites: - if f.get("type") in ("channel", "contact") and f.get("id"): - new_favorites.append(Favorite(type=f["type"], id=f["id"])) - - # Update with migrated preferences and mark as migrated - settings = await AppSettingsRepository.update( - favorites=new_favorites, - sidebar_sort_order=sort_order if sort_order in ("recent", "alpha") else "recent", - last_message_times=last_message_times, - preferences_migrated=True, - ) - - return settings, True - - -class StatisticsRepository: - @staticmethod - async def _activity_counts(*, contact_type: int, exclude: bool = False) -> dict[str, int]: - """Get time-windowed counts for contacts/repeaters heard.""" - now = int(time.time()) - op = "!=" if exclude else "=" - cursor = await db.conn.execute( - f""" - SELECT - SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour, - SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours, - SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week - FROM contacts - WHERE type {op} ? AND last_seen IS NOT NULL - """, - (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type), - ) - row = await cursor.fetchone() - assert row is not None # Aggregate query always returns a row - return { - "last_hour": row["last_hour"] or 0, - "last_24_hours": row["last_24_hours"] or 0, - "last_week": row["last_week"] or 0, - } - - @staticmethod - async def get_all() -> dict: - """Aggregate all statistics from existing tables.""" - now = int(time.time()) - - # Top 5 busiest channels in last 24h - cursor = await db.conn.execute( - """ - SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, - COUNT(*) AS message_count - FROM messages m - LEFT JOIN channels c ON m.conversation_key = c.key - WHERE m.type = 'CHAN' AND m.received_at >= ? - GROUP BY m.conversation_key - ORDER BY COUNT(*) DESC - LIMIT 5 - """, - (now - SECONDS_24H,), - ) - rows = await cursor.fetchall() - busiest_channels_24h = [ - { - "channel_key": row["conversation_key"], - "channel_name": row["channel_name"], - "message_count": row["message_count"], - } - for row in rows - ] - - # Entity counts - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2") - row = await cursor.fetchone() - assert row is not None - contact_count: int = row["cnt"] - - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2") - row = await cursor.fetchone() - assert row is not None - repeater_count: int = row["cnt"] - - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM channels") - row = await cursor.fetchone() - assert row is not None - channel_count: int = row["cnt"] - - # Packet split - cursor = await db.conn.execute( - """ - SELECT COUNT(*) AS total, - SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted - FROM raw_packets - """ - ) - pkt_row = await cursor.fetchone() - assert pkt_row is not None - total_packets = pkt_row["total"] or 0 - decrypted_packets = pkt_row["decrypted"] or 0 - undecrypted_packets = total_packets - decrypted_packets - - # Message type counts - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'PRIV'") - row = await cursor.fetchone() - assert row is not None - total_dms: int = row["cnt"] - - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'CHAN'") - row = await cursor.fetchone() - assert row is not None - total_channel_messages: int = row["cnt"] - - # Outgoing count - cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE outgoing = 1") - row = await cursor.fetchone() - assert row is not None - total_outgoing: int = row["cnt"] - - # Activity windows - contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True) - repeaters_heard = await StatisticsRepository._activity_counts(contact_type=2) - - return { - "busiest_channels_24h": busiest_channels_24h, - "contact_count": contact_count, - "repeater_count": repeater_count, - "channel_count": channel_count, - "total_packets": total_packets, - "decrypted_packets": decrypted_packets, - "undecrypted_packets": undecrypted_packets, - "total_dms": total_dms, - "total_channel_messages": total_channel_messages, - "total_outgoing": total_outgoing, - "contacts_heard": contacts_heard, - "repeaters_heard": repeaters_heard, - } diff --git a/app/repository/__init__.py b/app/repository/__init__.py new file mode 100644 index 0000000..0058956 --- /dev/null +++ b/app/repository/__init__.py @@ -0,0 +1,22 @@ +from app.repository.channels import ChannelRepository +from app.repository.contacts import ( + AmbiguousPublicKeyPrefixError, + ContactAdvertPathRepository, + ContactNameHistoryRepository, + ContactRepository, +) +from app.repository.messages import MessageRepository +from app.repository.raw_packets import RawPacketRepository +from app.repository.settings import AppSettingsRepository, StatisticsRepository + +__all__ = [ + "AmbiguousPublicKeyPrefixError", + "AppSettingsRepository", + "ChannelRepository", + "ContactAdvertPathRepository", + "ContactNameHistoryRepository", + "ContactRepository", + "MessageRepository", + "RawPacketRepository", + "StatisticsRepository", +] diff --git a/app/repository/channels.py b/app/repository/channels.py new file mode 100644 index 0000000..c83bac5 --- /dev/null +++ b/app/repository/channels.py @@ -0,0 +1,86 @@ +import time + +from app.database import db +from app.models import Channel + + +class ChannelRepository: + @staticmethod + async def upsert(key: str, name: str, is_hashtag: bool = False, on_radio: bool = False) -> None: + """Upsert a channel. Key is 32-char hex string.""" + await db.conn.execute( + """ + INSERT INTO channels (key, name, is_hashtag, on_radio) + VALUES (?, ?, ?, ?) + ON CONFLICT(key) DO UPDATE SET + name = excluded.name, + is_hashtag = excluded.is_hashtag, + on_radio = excluded.on_radio + """, + (key.upper(), name, is_hashtag, on_radio), + ) + await db.conn.commit() + + @staticmethod + async def get_by_key(key: str) -> Channel | None: + """Get a channel by its key (32-char hex string).""" + cursor = await db.conn.execute( + "SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels WHERE key = ?", + (key.upper(),), + ) + row = await cursor.fetchone() + if row: + return Channel( + key=row["key"], + name=row["name"], + is_hashtag=bool(row["is_hashtag"]), + on_radio=bool(row["on_radio"]), + last_read_at=row["last_read_at"], + ) + return None + + @staticmethod + async def get_all() -> list[Channel]: + cursor = await db.conn.execute( + "SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels ORDER BY name" + ) + rows = await cursor.fetchall() + return [ + Channel( + key=row["key"], + name=row["name"], + is_hashtag=bool(row["is_hashtag"]), + on_radio=bool(row["on_radio"]), + last_read_at=row["last_read_at"], + ) + for row in rows + ] + + @staticmethod + async def delete(key: str) -> None: + """Delete a channel by key.""" + await db.conn.execute( + "DELETE FROM channels WHERE key = ?", + (key.upper(),), + ) + await db.conn.commit() + + @staticmethod + async def update_last_read_at(key: str, timestamp: int | None = None) -> bool: + """Update the last_read_at timestamp for a channel. + + Returns True if a row was updated, False if channel not found. + """ + ts = timestamp if timestamp is not None else int(time.time()) + cursor = await db.conn.execute( + "UPDATE channels SET last_read_at = ? WHERE key = ?", + (ts, key.upper()), + ) + await db.conn.commit() + return cursor.rowcount > 0 + + @staticmethod + async def mark_all_read(timestamp: int) -> None: + """Mark all channels as read at the given timestamp.""" + await db.conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,)) + await db.conn.commit() diff --git a/app/repository/contacts.py b/app/repository/contacts.py new file mode 100644 index 0000000..3af13c7 --- /dev/null +++ b/app/repository/contacts.py @@ -0,0 +1,412 @@ +import time +from typing import Any + +from app.database import db +from app.models import ( + Contact, + ContactAdvertPath, + ContactAdvertPathSummary, + ContactNameHistory, +) + + +class AmbiguousPublicKeyPrefixError(ValueError): + """Raised when a public key prefix matches multiple contacts.""" + + def __init__(self, prefix: str, matches: list[str]): + self.prefix = prefix.lower() + self.matches = matches + super().__init__(f"Ambiguous public key prefix '{self.prefix}'") + + +class ContactRepository: + @staticmethod + async def upsert(contact: dict[str, Any]) -> None: + await db.conn.execute( + """ + INSERT INTO contacts (public_key, name, type, flags, last_path, last_path_len, + last_advert, lat, lon, last_seen, on_radio, last_contacted, + first_seen) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(public_key) DO UPDATE SET + name = COALESCE(excluded.name, contacts.name), + type = CASE WHEN excluded.type = 0 THEN contacts.type ELSE excluded.type END, + flags = excluded.flags, + last_path = COALESCE(excluded.last_path, contacts.last_path), + last_path_len = excluded.last_path_len, + last_advert = COALESCE(excluded.last_advert, contacts.last_advert), + lat = COALESCE(excluded.lat, contacts.lat), + lon = COALESCE(excluded.lon, contacts.lon), + last_seen = excluded.last_seen, + on_radio = COALESCE(excluded.on_radio, contacts.on_radio), + last_contacted = COALESCE(excluded.last_contacted, contacts.last_contacted), + first_seen = COALESCE(contacts.first_seen, excluded.first_seen) + """, + ( + contact.get("public_key", "").lower(), + contact.get("name"), + contact.get("type", 0), + contact.get("flags", 0), + contact.get("last_path"), + contact.get("last_path_len", -1), + contact.get("last_advert"), + contact.get("lat"), + contact.get("lon"), + contact.get("last_seen", int(time.time())), + contact.get("on_radio"), + contact.get("last_contacted"), + contact.get("first_seen"), + ), + ) + await db.conn.commit() + + @staticmethod + def _row_to_contact(row) -> Contact: + """Convert a database row to a Contact model.""" + return Contact( + public_key=row["public_key"], + name=row["name"], + type=row["type"], + flags=row["flags"], + last_path=row["last_path"], + last_path_len=row["last_path_len"], + last_advert=row["last_advert"], + lat=row["lat"], + lon=row["lon"], + last_seen=row["last_seen"], + on_radio=bool(row["on_radio"]), + last_contacted=row["last_contacted"], + last_read_at=row["last_read_at"], + first_seen=row["first_seen"], + ) + + @staticmethod + async def get_by_key(public_key: str) -> Contact | None: + cursor = await db.conn.execute( + "SELECT * FROM contacts WHERE public_key = ?", (public_key.lower(),) + ) + row = await cursor.fetchone() + return ContactRepository._row_to_contact(row) if row else None + + @staticmethod + async def get_by_key_prefix(prefix: str) -> Contact | None: + """Get a contact by key prefix only if it resolves uniquely. + + Returns None when no contacts match OR when multiple contacts match + the prefix (to avoid silently selecting the wrong contact). + """ + normalized_prefix = prefix.lower() + cursor = await db.conn.execute( + "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2", + (f"{normalized_prefix}%",), + ) + rows = list(await cursor.fetchall()) + if len(rows) != 1: + return None + return ContactRepository._row_to_contact(rows[0]) + + @staticmethod + async def _get_prefix_matches(prefix: str, limit: int = 2) -> list[Contact]: + """Get contacts matching a key prefix, up to limit.""" + cursor = await db.conn.execute( + "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT ?", + (f"{prefix.lower()}%", limit), + ) + rows = list(await cursor.fetchall()) + return [ContactRepository._row_to_contact(row) for row in rows] + + @staticmethod + async def get_by_key_or_prefix(key_or_prefix: str) -> Contact | None: + """Get a contact by exact key match, falling back to prefix match. + + Useful when the input might be a full 64-char public key or a shorter prefix. + """ + contact = await ContactRepository.get_by_key(key_or_prefix) + if contact: + return contact + + matches = await ContactRepository._get_prefix_matches(key_or_prefix, limit=2) + if len(matches) == 1: + return matches[0] + if len(matches) > 1: + raise AmbiguousPublicKeyPrefixError( + key_or_prefix, + [m.public_key for m in matches], + ) + return None + + @staticmethod + async def get_by_name(name: str) -> list[Contact]: + """Get all contacts with the given exact name.""" + cursor = await db.conn.execute("SELECT * FROM contacts WHERE name = ?", (name,)) + rows = await cursor.fetchall() + return [ContactRepository._row_to_contact(row) for row in rows] + + @staticmethod + async def resolve_prefixes(prefixes: list[str]) -> dict[str, Contact]: + """Resolve multiple key prefixes to contacts in a single query. + + Returns a dict mapping each prefix to its Contact, only for prefixes + that resolve uniquely (exactly one match). Ambiguous or unmatched + prefixes are omitted. + """ + if not prefixes: + return {} + normalized = [p.lower() for p in prefixes] + conditions = " OR ".join(["public_key LIKE ?"] * len(normalized)) + params = [f"{p}%" for p in normalized] + cursor = await db.conn.execute(f"SELECT * FROM contacts WHERE {conditions}", params) + rows = await cursor.fetchall() + # Group by which prefix each row matches + prefix_to_rows: dict[str, list] = {p: [] for p in normalized} + for row in rows: + pk = row["public_key"] + for p in normalized: + if pk.startswith(p): + prefix_to_rows[p].append(row) + # Only include uniquely-resolved prefixes + result: dict[str, Contact] = {} + for p in normalized: + if len(prefix_to_rows[p]) == 1: + result[p] = ContactRepository._row_to_contact(prefix_to_rows[p][0]) + return result + + @staticmethod + async def get_all(limit: int = 100, offset: int = 0) -> list[Contact]: + cursor = await db.conn.execute( + "SELECT * FROM contacts ORDER BY COALESCE(name, public_key) LIMIT ? OFFSET ?", + (limit, offset), + ) + rows = await cursor.fetchall() + return [ContactRepository._row_to_contact(row) for row in rows] + + @staticmethod + async def get_recent_non_repeaters(limit: int = 200) -> list[Contact]: + """Get the most recently active non-repeater contacts. + + Orders by most recent activity (last_contacted or last_advert), + excluding repeaters (type=2). + """ + cursor = await db.conn.execute( + """ + SELECT * FROM contacts + WHERE type != 2 + ORDER BY COALESCE(last_contacted, 0) DESC, COALESCE(last_advert, 0) DESC + LIMIT ? + """, + (limit,), + ) + rows = await cursor.fetchall() + return [ContactRepository._row_to_contact(row) for row in rows] + + @staticmethod + async def update_path(public_key: str, path: str, path_len: int) -> None: + await db.conn.execute( + "UPDATE contacts SET last_path = ?, last_path_len = ?, last_seen = ? WHERE public_key = ?", + (path, path_len, int(time.time()), public_key.lower()), + ) + await db.conn.commit() + + @staticmethod + async def set_on_radio(public_key: str, on_radio: bool) -> None: + await db.conn.execute( + "UPDATE contacts SET on_radio = ? WHERE public_key = ?", + (on_radio, public_key.lower()), + ) + await db.conn.commit() + + @staticmethod + async def delete(public_key: str) -> None: + normalized = public_key.lower() + await db.conn.execute( + "DELETE FROM contact_name_history WHERE public_key = ?", (normalized,) + ) + await db.conn.execute( + "DELETE FROM contact_advert_paths WHERE public_key = ?", (normalized,) + ) + await db.conn.execute("DELETE FROM contacts WHERE public_key = ?", (normalized,)) + await db.conn.commit() + + @staticmethod + async def update_last_contacted(public_key: str, timestamp: int | None = None) -> None: + """Update the last_contacted timestamp for a contact.""" + ts = timestamp if timestamp is not None else int(time.time()) + await db.conn.execute( + "UPDATE contacts SET last_contacted = ?, last_seen = ? WHERE public_key = ?", + (ts, ts, public_key.lower()), + ) + await db.conn.commit() + + @staticmethod + async def update_last_read_at(public_key: str, timestamp: int | None = None) -> bool: + """Update the last_read_at timestamp for a contact. + + Returns True if a row was updated, False if contact not found. + """ + ts = timestamp if timestamp is not None else int(time.time()) + cursor = await db.conn.execute( + "UPDATE contacts SET last_read_at = ? WHERE public_key = ?", + (ts, public_key.lower()), + ) + await db.conn.commit() + return cursor.rowcount > 0 + + @staticmethod + async def mark_all_read(timestamp: int) -> None: + """Mark all contacts as read at the given timestamp.""" + await db.conn.execute("UPDATE contacts SET last_read_at = ?", (timestamp,)) + await db.conn.commit() + + @staticmethod + async def get_by_pubkey_first_byte(hex_byte: str) -> list[Contact]: + """Get contacts whose public key starts with the given hex byte (2 chars).""" + cursor = await db.conn.execute( + "SELECT * FROM contacts WHERE substr(public_key, 1, 2) = ?", + (hex_byte.lower(),), + ) + rows = await cursor.fetchall() + return [ContactRepository._row_to_contact(row) for row in rows] + + +class ContactAdvertPathRepository: + """Repository for recent unique advertisement paths per contact.""" + + @staticmethod + def _row_to_path(row) -> ContactAdvertPath: + path = row["path_hex"] or "" + next_hop = path[:2].lower() if len(path) >= 2 else None + return ContactAdvertPath( + path=path, + path_len=row["path_len"], + next_hop=next_hop, + first_seen=row["first_seen"], + last_seen=row["last_seen"], + heard_count=row["heard_count"], + ) + + @staticmethod + async def record_observation( + public_key: str, + path_hex: str, + timestamp: int, + max_paths: int = 10, + ) -> None: + """ + Upsert a unique advert path observation for a contact and prune to N most recent. + """ + if max_paths < 1: + max_paths = 1 + + normalized_key = public_key.lower() + normalized_path = path_hex.lower() + path_len = len(normalized_path) // 2 + + await db.conn.execute( + """ + INSERT INTO contact_advert_paths + (public_key, path_hex, path_len, first_seen, last_seen, heard_count) + VALUES (?, ?, ?, ?, ?, 1) + ON CONFLICT(public_key, path_hex) DO UPDATE SET + last_seen = MAX(contact_advert_paths.last_seen, excluded.last_seen), + path_len = excluded.path_len, + heard_count = contact_advert_paths.heard_count + 1 + """, + (normalized_key, normalized_path, path_len, timestamp, timestamp), + ) + + # Keep only the N most recent unique paths per contact. + await db.conn.execute( + """ + DELETE FROM contact_advert_paths + WHERE public_key = ? + AND path_hex NOT IN ( + SELECT path_hex + FROM contact_advert_paths + WHERE public_key = ? + ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC + LIMIT ? + ) + """, + (normalized_key, normalized_key, max_paths), + ) + await db.conn.commit() + + @staticmethod + async def get_recent_for_contact(public_key: str, limit: int = 10) -> list[ContactAdvertPath]: + cursor = await db.conn.execute( + """ + SELECT path_hex, path_len, first_seen, last_seen, heard_count + FROM contact_advert_paths + WHERE public_key = ? + ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC + LIMIT ? + """, + (public_key.lower(), limit), + ) + rows = await cursor.fetchall() + return [ContactAdvertPathRepository._row_to_path(row) for row in rows] + + @staticmethod + async def get_recent_for_all_contacts( + limit_per_contact: int = 10, + ) -> list[ContactAdvertPathSummary]: + cursor = await db.conn.execute( + """ + SELECT public_key, path_hex, path_len, first_seen, last_seen, heard_count + FROM contact_advert_paths + ORDER BY public_key ASC, last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC + """ + ) + rows = await cursor.fetchall() + + grouped: dict[str, list[ContactAdvertPath]] = {} + for row in rows: + key = row["public_key"] + paths = grouped.get(key) + if paths is None: + paths = [] + grouped[key] = paths + if len(paths) >= limit_per_contact: + continue + paths.append(ContactAdvertPathRepository._row_to_path(row)) + + return [ + ContactAdvertPathSummary(public_key=key, paths=paths) for key, paths in grouped.items() + ] + + +class ContactNameHistoryRepository: + """Repository for contact name change history.""" + + @staticmethod + async def record_name(public_key: str, name: str, timestamp: int) -> None: + """Record a name observation. Upserts: updates last_seen if name already known.""" + await db.conn.execute( + """ + INSERT INTO contact_name_history (public_key, name, first_seen, last_seen) + VALUES (?, ?, ?, ?) + ON CONFLICT(public_key, name) DO UPDATE SET + last_seen = MAX(contact_name_history.last_seen, excluded.last_seen) + """, + (public_key.lower(), name, timestamp, timestamp), + ) + await db.conn.commit() + + @staticmethod + async def get_history(public_key: str) -> list[ContactNameHistory]: + cursor = await db.conn.execute( + """ + SELECT name, first_seen, last_seen + FROM contact_name_history + WHERE public_key = ? + ORDER BY last_seen DESC + """, + (public_key.lower(),), + ) + rows = await cursor.fetchall() + return [ + ContactNameHistory( + name=row["name"], first_seen=row["first_seen"], last_seen=row["last_seen"] + ) + for row in rows + ] diff --git a/app/repository/messages.py b/app/repository/messages.py new file mode 100644 index 0000000..279bd17 --- /dev/null +++ b/app/repository/messages.py @@ -0,0 +1,411 @@ +import json +import time +from typing import Any + +from app.database import db +from app.models import Message, MessagePath + + +class MessageRepository: + @staticmethod + def _parse_paths(paths_json: str | None) -> list[MessagePath] | None: + """Parse paths JSON string to list of MessagePath objects.""" + if not paths_json: + return None + try: + paths_data = json.loads(paths_json) + return [MessagePath(**p) for p in paths_data] + except (json.JSONDecodeError, TypeError, KeyError): + return None + + @staticmethod + async def create( + msg_type: str, + text: str, + received_at: int, + conversation_key: str, + sender_timestamp: int | None = None, + path: str | None = None, + txt_type: int = 0, + signature: str | None = None, + outgoing: bool = False, + sender_name: str | None = None, + sender_key: str | None = None, + ) -> int | None: + """Create a message, returning the ID or None if duplicate. + + Uses INSERT OR IGNORE to handle the UNIQUE constraint on + (type, conversation_key, text, sender_timestamp). This prevents + duplicate messages when the same message arrives via multiple RF paths. + + The path parameter is converted to the paths JSON array format. + """ + # Convert single path to paths array format + paths_json = None + if path is not None: + paths_json = json.dumps([{"path": path, "received_at": received_at}]) + + cursor = await db.conn.execute( + """ + INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp, + received_at, paths, txt_type, signature, outgoing, + sender_name, sender_key) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + msg_type, + conversation_key, + text, + sender_timestamp, + received_at, + paths_json, + txt_type, + signature, + outgoing, + sender_name, + sender_key, + ), + ) + await db.conn.commit() + # rowcount is 0 if INSERT was ignored due to UNIQUE constraint violation + if cursor.rowcount == 0: + return None + return cursor.lastrowid + + @staticmethod + async def add_path( + message_id: int, path: str, received_at: int | None = None + ) -> list[MessagePath]: + """Add a new path to an existing message. + + This is used when a repeat/echo of a message arrives via a different route. + Returns the updated list of paths. + """ + ts = received_at if received_at is not None else int(time.time()) + + # Atomic append: use json_insert to avoid read-modify-write race when + # multiple duplicate packets arrive concurrently for the same message. + new_entry = json.dumps({"path": path, "received_at": ts}) + await db.conn.execute( + """UPDATE messages SET paths = json_insert( + COALESCE(paths, '[]'), '$[#]', json(?) + ) WHERE id = ?""", + (new_entry, message_id), + ) + await db.conn.commit() + + # Read back the full list for the return value + cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,)) + row = await cursor.fetchone() + if not row or not row["paths"]: + return [] + + try: + all_paths = json.loads(row["paths"]) + except json.JSONDecodeError: + return [] + + return [MessagePath(**p) for p in all_paths] + + @staticmethod + async def claim_prefix_messages(full_key: str) -> int: + """Promote prefix-stored messages to the full conversation key. + + When a full key becomes known for a contact, any messages stored with + only a prefix as conversation_key are updated to use the full key. + """ + lower_key = full_key.lower() + cursor = await db.conn.execute( + """UPDATE messages SET conversation_key = ? + WHERE type = 'PRIV' AND length(conversation_key) < 64 + AND ? LIKE conversation_key || '%' + AND ( + SELECT COUNT(*) FROM contacts + WHERE public_key LIKE messages.conversation_key || '%' + ) = 1""", + (lower_key, lower_key), + ) + await db.conn.commit() + return cursor.rowcount + + @staticmethod + async def get_all( + limit: int = 100, + offset: int = 0, + msg_type: str | None = None, + conversation_key: str | None = None, + before: int | None = None, + before_id: int | None = None, + ) -> list[Message]: + query = "SELECT * FROM messages WHERE 1=1" + params: list[Any] = [] + + if msg_type: + query += " AND type = ?" + params.append(msg_type) + if conversation_key: + normalized_key = conversation_key + # Prefer exact matching for full keys. + if len(conversation_key) == 64: + normalized_key = conversation_key.lower() + query += " AND conversation_key = ?" + params.append(normalized_key) + elif len(conversation_key) == 32: + normalized_key = conversation_key.upper() + query += " AND conversation_key = ?" + params.append(normalized_key) + else: + # Prefix match is only for legacy/partial key callers. + query += " AND conversation_key LIKE ?" + params.append(f"{conversation_key}%") + + if before is not None and before_id is not None: + query += " AND (received_at < ? OR (received_at = ? AND id < ?))" + params.extend([before, before, before_id]) + + query += " ORDER BY received_at DESC, id DESC LIMIT ?" + params.append(limit) + if before is None or before_id is None: + query += " OFFSET ?" + params.append(offset) + + cursor = await db.conn.execute(query, params) + rows = await cursor.fetchall() + return [ + Message( + id=row["id"], + type=row["type"], + conversation_key=row["conversation_key"], + text=row["text"], + sender_timestamp=row["sender_timestamp"], + received_at=row["received_at"], + paths=MessageRepository._parse_paths(row["paths"]), + txt_type=row["txt_type"], + signature=row["signature"], + outgoing=bool(row["outgoing"]), + acked=row["acked"], + ) + for row in rows + ] + + @staticmethod + async def increment_ack_count(message_id: int) -> int: + """Increment ack count and return the new value.""" + await db.conn.execute("UPDATE messages SET acked = acked + 1 WHERE id = ?", (message_id,)) + await db.conn.commit() + cursor = await db.conn.execute("SELECT acked FROM messages WHERE id = ?", (message_id,)) + row = await cursor.fetchone() + return row["acked"] if row else 1 + + @staticmethod + async def get_ack_and_paths(message_id: int) -> tuple[int, list[MessagePath] | None]: + """Get the current ack count and paths for a message.""" + cursor = await db.conn.execute( + "SELECT acked, paths FROM messages WHERE id = ?", (message_id,) + ) + row = await cursor.fetchone() + if not row: + return 0, None + return row["acked"], MessageRepository._parse_paths(row["paths"]) + + @staticmethod + async def get_by_id(message_id: int) -> "Message | None": + """Look up a message by its ID.""" + cursor = await db.conn.execute( + """ + SELECT id, type, conversation_key, text, sender_timestamp, received_at, + paths, txt_type, signature, outgoing, acked + FROM messages + WHERE id = ? + """, + (message_id,), + ) + row = await cursor.fetchone() + if not row: + return None + + return Message( + id=row["id"], + type=row["type"], + conversation_key=row["conversation_key"], + text=row["text"], + sender_timestamp=row["sender_timestamp"], + received_at=row["received_at"], + paths=MessageRepository._parse_paths(row["paths"]), + txt_type=row["txt_type"], + signature=row["signature"], + outgoing=bool(row["outgoing"]), + acked=row["acked"], + ) + + @staticmethod + async def get_by_content( + msg_type: str, + conversation_key: str, + text: str, + sender_timestamp: int | None, + ) -> "Message | None": + """Look up a message by its unique content fields.""" + cursor = await db.conn.execute( + """ + SELECT id, type, conversation_key, text, sender_timestamp, received_at, + paths, txt_type, signature, outgoing, acked + FROM messages + WHERE type = ? AND conversation_key = ? AND text = ? + AND (sender_timestamp = ? OR (sender_timestamp IS NULL AND ? IS NULL)) + """, + (msg_type, conversation_key, text, sender_timestamp, sender_timestamp), + ) + row = await cursor.fetchone() + if not row: + return None + + paths = None + if row["paths"]: + try: + paths_data = json.loads(row["paths"]) + paths = [ + MessagePath(path=p["path"], received_at=p["received_at"]) for p in paths_data + ] + except (json.JSONDecodeError, KeyError): + pass + + return Message( + id=row["id"], + type=row["type"], + conversation_key=row["conversation_key"], + text=row["text"], + sender_timestamp=row["sender_timestamp"], + received_at=row["received_at"], + paths=paths, + txt_type=row["txt_type"], + signature=row["signature"], + outgoing=bool(row["outgoing"]), + acked=row["acked"], + ) + + @staticmethod + async def get_unread_counts(name: str | None = None) -> dict: + """Get unread message counts, mention flags, and last message times for all conversations. + + Args: + name: User's display name for @[name] mention detection. If None, mentions are skipped. + + Returns: + Dict with 'counts', 'mentions', and 'last_message_times' keys. + """ + counts: dict[str, int] = {} + mention_flags: dict[str, bool] = {} + last_message_times: dict[str, int] = {} + + mention_token = f"@[{name}]" if name else None + + # Channel unreads + cursor = await db.conn.execute( + """ + SELECT m.conversation_key, + COUNT(*) as unread_count, + SUM(CASE + WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 + ELSE 0 + END) > 0 as has_mention + FROM messages m + JOIN channels c ON m.conversation_key = c.key + WHERE m.type = 'CHAN' AND m.outgoing = 0 + AND m.received_at > COALESCE(c.last_read_at, 0) + GROUP BY m.conversation_key + """, + (mention_token or "", mention_token or ""), + ) + rows = await cursor.fetchall() + for row in rows: + state_key = f"channel-{row['conversation_key']}" + counts[state_key] = row["unread_count"] + if mention_token and row["has_mention"]: + mention_flags[state_key] = True + + # Contact unreads + cursor = await db.conn.execute( + """ + SELECT m.conversation_key, + COUNT(*) as unread_count, + SUM(CASE + WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1 + ELSE 0 + END) > 0 as has_mention + FROM messages m + JOIN contacts ct ON m.conversation_key = ct.public_key + WHERE m.type = 'PRIV' AND m.outgoing = 0 + AND m.received_at > COALESCE(ct.last_read_at, 0) + GROUP BY m.conversation_key + """, + (mention_token or "", mention_token or ""), + ) + rows = await cursor.fetchall() + for row in rows: + state_key = f"contact-{row['conversation_key']}" + counts[state_key] = row["unread_count"] + if mention_token and row["has_mention"]: + mention_flags[state_key] = True + + # Last message times for all conversations (including read ones) + cursor = await db.conn.execute( + """ + SELECT type, conversation_key, MAX(received_at) as last_message_time + FROM messages + GROUP BY type, conversation_key + """ + ) + rows = await cursor.fetchall() + for row in rows: + prefix = "channel" if row["type"] == "CHAN" else "contact" + state_key = f"{prefix}-{row['conversation_key']}" + last_message_times[state_key] = row["last_message_time"] + + return { + "counts": counts, + "mentions": mention_flags, + "last_message_times": last_message_times, + } + + @staticmethod + async def count_dm_messages(contact_key: str) -> int: + """Count total DM messages for a contact.""" + cursor = await db.conn.execute( + "SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?", + (contact_key.lower(),), + ) + row = await cursor.fetchone() + return row["cnt"] if row else 0 + + @staticmethod + async def count_channel_messages_by_sender(sender_key: str) -> int: + """Count channel messages sent by a specific contact.""" + cursor = await db.conn.execute( + "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?", + (sender_key.lower(),), + ) + row = await cursor.fetchone() + return row["cnt"] if row else 0 + + @staticmethod + async def get_most_active_rooms(sender_key: str, limit: int = 5) -> list[tuple[str, str, int]]: + """Get channels where a contact has sent the most messages. + + Returns list of (channel_key, channel_name, message_count) tuples. + """ + cursor = await db.conn.execute( + """ + SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, + COUNT(*) AS cnt + FROM messages m + LEFT JOIN channels c ON m.conversation_key = c.key + WHERE m.type = 'CHAN' AND m.sender_key = ? + GROUP BY m.conversation_key + ORDER BY cnt DESC + LIMIT ? + """, + (sender_key.lower(), limit), + ) + rows = await cursor.fetchall() + return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows] diff --git a/app/repository/raw_packets.py b/app/repository/raw_packets.py new file mode 100644 index 0000000..f661bc3 --- /dev/null +++ b/app/repository/raw_packets.py @@ -0,0 +1,150 @@ +import logging +import sqlite3 +import time +from hashlib import sha256 + +from app.database import db +from app.decoder import PayloadType, extract_payload, get_packet_payload_type + +logger = logging.getLogger(__name__) + + +class RawPacketRepository: + @staticmethod + async def create(data: bytes, timestamp: int | None = None) -> tuple[int, bool]: + """ + Create a raw packet with payload-based deduplication. + + Returns (packet_id, is_new) tuple: + - is_new=True: New packet stored, packet_id is the new row ID + - is_new=False: Duplicate payload detected, packet_id is the existing row ID + + Deduplication is based on the SHA-256 hash of the packet payload + (excluding routing/path information). + """ + ts = timestamp if timestamp is not None else int(time.time()) + + # Compute payload hash for deduplication + payload = extract_payload(data) + if payload: + payload_hash = sha256(payload).digest() + else: + # For malformed packets, hash the full data + payload_hash = sha256(data).digest() + + # Check if this payload already exists + cursor = await db.conn.execute( + "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,) + ) + existing = await cursor.fetchone() + + if existing: + # Duplicate - return existing packet ID + logger.debug( + "Duplicate payload detected (hash=%s..., existing_id=%d)", + payload_hash.hex()[:12], + existing["id"], + ) + return (existing["id"], False) + + # New packet - insert with hash + try: + cursor = await db.conn.execute( + "INSERT INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)", + (ts, data, payload_hash), + ) + await db.conn.commit() + assert cursor.lastrowid is not None # INSERT always returns a row ID + return (cursor.lastrowid, True) + except sqlite3.IntegrityError: + # Race condition: another insert with same payload_hash happened between + # our SELECT and INSERT. This is expected for duplicate packets arriving + # close together. Query again to get the existing ID. + logger.debug( + "Duplicate packet detected via race condition (payload_hash=%s), dropping", + payload_hash.hex()[:16], + ) + cursor = await db.conn.execute( + "SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,) + ) + existing = await cursor.fetchone() + if existing: + return (existing["id"], False) + # This shouldn't happen, but if it does, re-raise + raise + + @staticmethod + async def get_undecrypted_count() -> int: + """Get count of undecrypted packets (those without a linked message).""" + cursor = await db.conn.execute( + "SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL" + ) + row = await cursor.fetchone() + return row["count"] if row else 0 + + @staticmethod + async def get_oldest_undecrypted() -> int | None: + """Get timestamp of oldest undecrypted packet, or None if none exist.""" + cursor = await db.conn.execute( + "SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL" + ) + row = await cursor.fetchone() + return row["oldest"] if row and row["oldest"] is not None else None + + @staticmethod + async def get_all_undecrypted() -> list[tuple[int, bytes, int]]: + """Get all undecrypted packets as (id, data, timestamp) tuples.""" + cursor = await db.conn.execute( + "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC" + ) + rows = await cursor.fetchall() + return [(row["id"], bytes(row["data"]), row["timestamp"]) for row in rows] + + @staticmethod + async def mark_decrypted(packet_id: int, message_id: int) -> None: + """Link a raw packet to its decrypted message.""" + await db.conn.execute( + "UPDATE raw_packets SET message_id = ? WHERE id = ?", + (message_id, packet_id), + ) + await db.conn.commit() + + @staticmethod + async def prune_old_undecrypted(max_age_days: int) -> int: + """Delete undecrypted packets older than max_age_days. Returns count deleted.""" + cutoff = int(time.time()) - (max_age_days * 86400) + cursor = await db.conn.execute( + "DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?", + (cutoff,), + ) + await db.conn.commit() + return cursor.rowcount + + @staticmethod + async def purge_linked_to_messages() -> int: + """Delete raw packets that are already linked to a stored message.""" + cursor = await db.conn.execute("DELETE FROM raw_packets WHERE message_id IS NOT NULL") + await db.conn.commit() + return cursor.rowcount + + @staticmethod + async def get_undecrypted_text_messages() -> list[tuple[int, bytes, int]]: + """Get all undecrypted TEXT_MESSAGE packets as (id, data, timestamp) tuples. + + Filters raw packets to only include those with PayloadType.TEXT_MESSAGE (0x02). + These are direct messages that can be decrypted with contact ECDH keys. + """ + cursor = await db.conn.execute( + "SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC" + ) + rows = await cursor.fetchall() + + # Filter for TEXT_MESSAGE packets + result = [] + for row in rows: + data = bytes(row["data"]) + payload_type = get_packet_payload_type(data) + if payload_type == PayloadType.TEXT_MESSAGE: + result.append((row["id"], data, row["timestamp"])) + + return result diff --git a/app/repository/settings.py b/app/repository/settings.py new file mode 100644 index 0000000..b888047 --- /dev/null +++ b/app/repository/settings.py @@ -0,0 +1,330 @@ +import json +import logging +import time +from typing import Any, Literal + +from app.database import db +from app.models import AppSettings, BotConfig, Favorite + +logger = logging.getLogger(__name__) + +SECONDS_1H = 3600 +SECONDS_24H = 86400 +SECONDS_7D = 604800 + + +class AppSettingsRepository: + """Repository for app_settings table (single-row pattern).""" + + @staticmethod + async def get() -> AppSettings: + """Get the current app settings. + + Always returns settings - creates default row if needed (migration handles initial row). + """ + cursor = await db.conn.execute( + """ + SELECT max_radio_contacts, favorites, auto_decrypt_dm_on_advert, + sidebar_sort_order, last_message_times, preferences_migrated, + advert_interval, last_advert_time, bots + FROM app_settings WHERE id = 1 + """ + ) + row = await cursor.fetchone() + + if not row: + # Should not happen after migration, but handle gracefully + return AppSettings() + + # Parse favorites JSON + favorites = [] + if row["favorites"]: + try: + favorites_data = json.loads(row["favorites"]) + favorites = [Favorite(**f) for f in favorites_data] + except (json.JSONDecodeError, TypeError, KeyError) as e: + logger.warning( + "Failed to parse favorites JSON, using empty list: %s (data=%r)", + e, + row["favorites"][:100] if row["favorites"] else None, + ) + favorites = [] + + # Parse last_message_times JSON + last_message_times: dict[str, int] = {} + if row["last_message_times"]: + try: + last_message_times = json.loads(row["last_message_times"]) + except (json.JSONDecodeError, TypeError) as e: + logger.warning( + "Failed to parse last_message_times JSON, using empty dict: %s", + e, + ) + last_message_times = {} + + # Parse bots JSON + bots: list[BotConfig] = [] + if row["bots"]: + try: + bots_data = json.loads(row["bots"]) + bots = [BotConfig(**b) for b in bots_data] + except (json.JSONDecodeError, TypeError, KeyError) as e: + logger.warning( + "Failed to parse bots JSON, using empty list: %s (data=%r)", + e, + row["bots"][:100] if row["bots"] else None, + ) + bots = [] + + # Validate sidebar_sort_order (fallback to "recent" if invalid) + sort_order = row["sidebar_sort_order"] + if sort_order not in ("recent", "alpha"): + sort_order = "recent" + + return AppSettings( + max_radio_contacts=row["max_radio_contacts"], + favorites=favorites, + auto_decrypt_dm_on_advert=bool(row["auto_decrypt_dm_on_advert"]), + sidebar_sort_order=sort_order, + last_message_times=last_message_times, + preferences_migrated=bool(row["preferences_migrated"]), + advert_interval=row["advert_interval"] or 0, + last_advert_time=row["last_advert_time"] or 0, + bots=bots, + ) + + @staticmethod + async def update( + max_radio_contacts: int | None = None, + favorites: list[Favorite] | None = None, + auto_decrypt_dm_on_advert: bool | None = None, + sidebar_sort_order: str | None = None, + last_message_times: dict[str, int] | None = None, + preferences_migrated: bool | None = None, + advert_interval: int | None = None, + last_advert_time: int | None = None, + bots: list[BotConfig] | None = None, + ) -> AppSettings: + """Update app settings. Only provided fields are updated.""" + updates = [] + params: list[Any] = [] + + if max_radio_contacts is not None: + updates.append("max_radio_contacts = ?") + params.append(max_radio_contacts) + + if favorites is not None: + updates.append("favorites = ?") + favorites_json = json.dumps([f.model_dump() for f in favorites]) + params.append(favorites_json) + + if auto_decrypt_dm_on_advert is not None: + updates.append("auto_decrypt_dm_on_advert = ?") + params.append(1 if auto_decrypt_dm_on_advert else 0) + + if sidebar_sort_order is not None: + updates.append("sidebar_sort_order = ?") + params.append(sidebar_sort_order) + + if last_message_times is not None: + updates.append("last_message_times = ?") + params.append(json.dumps(last_message_times)) + + if preferences_migrated is not None: + updates.append("preferences_migrated = ?") + params.append(1 if preferences_migrated else 0) + + if advert_interval is not None: + updates.append("advert_interval = ?") + params.append(advert_interval) + + if last_advert_time is not None: + updates.append("last_advert_time = ?") + params.append(last_advert_time) + + if bots is not None: + updates.append("bots = ?") + bots_json = json.dumps([b.model_dump() for b in bots]) + params.append(bots_json) + + if updates: + query = f"UPDATE app_settings SET {', '.join(updates)} WHERE id = 1" + await db.conn.execute(query, params) + await db.conn.commit() + + return await AppSettingsRepository.get() + + @staticmethod + async def add_favorite(fav_type: Literal["channel", "contact"], fav_id: str) -> AppSettings: + """Add a favorite, avoiding duplicates.""" + settings = await AppSettingsRepository.get() + + # Check if already favorited + if any(f.type == fav_type and f.id == fav_id for f in settings.favorites): + return settings + + new_favorites = settings.favorites + [Favorite(type=fav_type, id=fav_id)] + return await AppSettingsRepository.update(favorites=new_favorites) + + @staticmethod + async def remove_favorite(fav_type: Literal["channel", "contact"], fav_id: str) -> AppSettings: + """Remove a favorite.""" + settings = await AppSettingsRepository.get() + new_favorites = [ + f for f in settings.favorites if not (f.type == fav_type and f.id == fav_id) + ] + return await AppSettingsRepository.update(favorites=new_favorites) + + @staticmethod + async def migrate_preferences_from_frontend( + favorites: list[dict], + sort_order: str, + last_message_times: dict[str, int], + ) -> tuple[AppSettings, bool]: + """Migrate all preferences from frontend localStorage. + + This is a one-time migration. If already migrated, returns current settings + without overwriting. Returns (settings, did_migrate) tuple. + """ + settings = await AppSettingsRepository.get() + + if settings.preferences_migrated: + # Already migrated, don't overwrite + return settings, False + + # Convert frontend favorites format to Favorite objects + new_favorites = [] + for f in favorites: + if f.get("type") in ("channel", "contact") and f.get("id"): + new_favorites.append(Favorite(type=f["type"], id=f["id"])) + + # Update with migrated preferences and mark as migrated + settings = await AppSettingsRepository.update( + favorites=new_favorites, + sidebar_sort_order=sort_order if sort_order in ("recent", "alpha") else "recent", + last_message_times=last_message_times, + preferences_migrated=True, + ) + + return settings, True + + +class StatisticsRepository: + @staticmethod + async def _activity_counts(*, contact_type: int, exclude: bool = False) -> dict[str, int]: + """Get time-windowed counts for contacts/repeaters heard.""" + now = int(time.time()) + op = "!=" if exclude else "=" + cursor = await db.conn.execute( + f""" + SELECT + SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour, + SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours, + SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week + FROM contacts + WHERE type {op} ? AND last_seen IS NOT NULL + """, + (now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type), + ) + row = await cursor.fetchone() + assert row is not None # Aggregate query always returns a row + return { + "last_hour": row["last_hour"] or 0, + "last_24_hours": row["last_24_hours"] or 0, + "last_week": row["last_week"] or 0, + } + + @staticmethod + async def get_all() -> dict: + """Aggregate all statistics from existing tables.""" + now = int(time.time()) + + # Top 5 busiest channels in last 24h + cursor = await db.conn.execute( + """ + SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name, + COUNT(*) AS message_count + FROM messages m + LEFT JOIN channels c ON m.conversation_key = c.key + WHERE m.type = 'CHAN' AND m.received_at >= ? + GROUP BY m.conversation_key + ORDER BY COUNT(*) DESC + LIMIT 5 + """, + (now - SECONDS_24H,), + ) + rows = await cursor.fetchall() + busiest_channels_24h = [ + { + "channel_key": row["conversation_key"], + "channel_name": row["channel_name"], + "message_count": row["message_count"], + } + for row in rows + ] + + # Entity counts + cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2") + row = await cursor.fetchone() + assert row is not None + contact_count: int = row["cnt"] + + cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2") + row = await cursor.fetchone() + assert row is not None + repeater_count: int = row["cnt"] + + cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM channels") + row = await cursor.fetchone() + assert row is not None + channel_count: int = row["cnt"] + + # Packet split + cursor = await db.conn.execute( + """ + SELECT COUNT(*) AS total, + SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted + FROM raw_packets + """ + ) + pkt_row = await cursor.fetchone() + assert pkt_row is not None + total_packets = pkt_row["total"] or 0 + decrypted_packets = pkt_row["decrypted"] or 0 + undecrypted_packets = total_packets - decrypted_packets + + # Message type counts + cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'PRIV'") + row = await cursor.fetchone() + assert row is not None + total_dms: int = row["cnt"] + + cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'CHAN'") + row = await cursor.fetchone() + assert row is not None + total_channel_messages: int = row["cnt"] + + # Outgoing count + cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE outgoing = 1") + row = await cursor.fetchone() + assert row is not None + total_outgoing: int = row["cnt"] + + # Activity windows + contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True) + repeaters_heard = await StatisticsRepository._activity_counts(contact_type=2) + + return { + "busiest_channels_24h": busiest_channels_24h, + "contact_count": contact_count, + "repeater_count": repeater_count, + "channel_count": channel_count, + "total_packets": total_packets, + "decrypted_packets": decrypted_packets, + "undecrypted_packets": undecrypted_packets, + "total_dms": total_dms, + "total_channel_messages": total_channel_messages, + "total_outgoing": total_outgoing, + "contacts_heard": contacts_heard, + "repeaters_heard": repeaters_heard, + } diff --git a/app/routers/contacts.py b/app/routers/contacts.py index ccfc55d..cd54890 100644 --- a/app/routers/contacts.py +++ b/app/routers/contacts.py @@ -1,36 +1,18 @@ -import asyncio import logging import random -import time -from typing import TYPE_CHECKING from fastapi import APIRouter, BackgroundTasks, HTTPException, Query from meshcore import EventType from app.dependencies import require_connected from app.models import ( - CONTACT_TYPE_REPEATER, - AclEntry, - CommandRequest, - CommandResponse, Contact, ContactActiveRoom, ContactAdvertPath, ContactAdvertPathSummary, ContactDetail, CreateContactRequest, - LppSensor, NearestRepeater, - NeighborInfo, - RepeaterAclResponse, - RepeaterAdvertIntervalsResponse, - RepeaterLoginRequest, - RepeaterLoginResponse, - RepeaterLppTelemetryResponse, - RepeaterNeighborsResponse, - RepeaterOwnerInfoResponse, - RepeaterRadioSettingsResponse, - RepeaterStatusResponse, TraceResponse, ) from app.packet_processor import start_historical_dm_decryption @@ -43,111 +25,10 @@ from app.repository import ( MessageRepository, ) -if TYPE_CHECKING: - from meshcore.events import Event - logger = logging.getLogger(__name__) -# ACL permission level names -ACL_PERMISSION_NAMES = { - 0: "Guest", - 1: "Read-only", - 2: "Read-write", - 3: "Admin", -} router = APIRouter(prefix="/contacts", tags=["contacts"]) -# Delay between repeater radio operations to allow key exchange and path establishment -REPEATER_OP_DELAY_SECONDS = 2.0 - - -def _monotonic() -> float: - """Wrapper around time.monotonic() for testability. - - Patching time.monotonic directly breaks the asyncio event loop which also - uses it. This indirection allows tests to control the clock safely. - """ - return time.monotonic() - - -def _extract_response_text(event) -> str: - """Extract text from a CLI response event, stripping the firmware '> ' prefix.""" - text = event.payload.get("text", str(event.payload)) - if text.startswith("> "): - text = text[2:] - return text - - -async def _fetch_repeater_response( - mc, - target_pubkey_prefix: str, - timeout: float = 20.0, -) -> "Event | None": - """Fetch a CLI response from a specific repeater via a validated get_msg() loop. - - Calls get_msg() repeatedly until a matching CLI response (txt_type=1) from the - target repeater arrives or the wall-clock deadline expires. Unrelated messages - are safe to skip — meshcore's event dispatcher already delivers them to the - normal subscription handlers (on_contact_message, etc.) when get_msg() returns. - - Args: - mc: MeshCore instance - target_pubkey_prefix: 12-char hex prefix of the repeater's public key - timeout: Wall-clock seconds before giving up - - Returns: - The matching Event, or None if no response arrived before the deadline. - """ - deadline = _monotonic() + timeout - - while _monotonic() < deadline: - try: - result = await mc.commands.get_msg(timeout=2.0) - except asyncio.TimeoutError: - continue - except Exception as e: - logger.debug("get_msg() exception: %s", e) - await asyncio.sleep(1.0) - continue - - if result.type == EventType.NO_MORE_MSGS: - # No messages queued yet — wait and retry - await asyncio.sleep(1.0) - continue - - if result.type == EventType.ERROR: - logger.debug("get_msg() error: %s", result.payload) - await asyncio.sleep(1.0) - continue - - if result.type == EventType.CONTACT_MSG_RECV: - msg_prefix = result.payload.get("pubkey_prefix", "") - txt_type = result.payload.get("txt_type", 0) - if msg_prefix == target_pubkey_prefix and txt_type == 1: - return result - # Not our target — already dispatched to subscribers by meshcore, - # so just continue draining the queue. - logger.debug( - "Skipping non-target message (from=%s, txt_type=%d) while waiting for %s", - msg_prefix, - txt_type, - target_pubkey_prefix, - ) - continue - - if result.type == EventType.CHANNEL_MSG_RECV: - # Already dispatched to subscribers by meshcore; skip. - logger.debug( - "Skipping channel message (channel_idx=%s) during repeater fetch", - result.payload.get("channel_idx"), - ) - continue - - logger.debug("Unexpected event type %s during repeater fetch, skipping", result.type) - - logger.warning("No CLI response from repeater %s within %.1fs", target_pubkey_prefix, timeout) - return None - def _ambiguous_contact_detail(err: AmbiguousPublicKeyPrefixError) -> str: sample = ", ".join(key[:12] for key in err.matches[:2]) @@ -169,42 +50,6 @@ async def _resolve_contact_or_404( return contact -async def prepare_repeater_connection(mc, contact: Contact, password: str) -> None: - """Prepare connection to a repeater by adding to radio and logging in. - - Args: - mc: MeshCore instance - contact: The repeater contact - password: Password for login (empty string for no password) - - Raises: - HTTPException: If login fails - """ - # Add contact to radio with path from DB (non-fatal — contact may already be loaded) - logger.info("Adding repeater %s to radio", contact.public_key[:12]) - await _ensure_on_radio(mc, contact) - - # Send login with password - logger.info("Sending login to repeater %s", contact.public_key[:12]) - login_result = await mc.commands.send_login(contact.public_key, password) - - if login_result.type == EventType.ERROR: - raise HTTPException(status_code=401, detail=f"Login failed: {login_result.payload}") - - # Wait for key exchange to complete before sending requests - logger.debug("Waiting %.1fs for key exchange to complete", REPEATER_OP_DELAY_SECONDS) - await asyncio.sleep(REPEATER_OP_DELAY_SECONDS) - - -def _require_repeater(contact: Contact) -> None: - """Raise 400 if contact is not a repeater.""" - if contact.type != CONTACT_TYPE_REPEATER: - raise HTTPException( - status_code=400, - detail=f"Contact is not a repeater (type={contact.type}, expected {CONTACT_TYPE_REPEATER})", - ) - - async def _ensure_on_radio(mc, contact: Contact) -> None: """Add a contact to the radio for routing, raising 500 on failure.""" add_result = await mc.commands.add_contact(contact.to_radio_dict()) @@ -214,272 +59,6 @@ async def _ensure_on_radio(mc, contact: Contact) -> None: ) -# --------------------------------------------------------------------------- -# Granular repeater endpoints — one attempt, no server-side retries. -# Frontend manages retry logic for better UX control. -# --------------------------------------------------------------------------- - - -@router.post("/{public_key}/repeater/login", response_model=RepeaterLoginResponse) -async def repeater_login(public_key: str, request: RepeaterLoginRequest) -> RepeaterLoginResponse: - """Log in to a repeater. Adds contact to radio, sends login, waits for key exchange.""" - require_connected() - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - async with radio_manager.radio_operation( - "repeater_login", - pause_polling=True, - suspend_auto_fetch=True, - ) as mc: - await prepare_repeater_connection(mc, contact, request.password) - - return RepeaterLoginResponse(status="ok") - - -@router.post("/{public_key}/repeater/status", response_model=RepeaterStatusResponse) -async def repeater_status(public_key: str) -> RepeaterStatusResponse: - """Fetch status telemetry from a repeater (single attempt, 10s timeout).""" - require_connected() - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - async with radio_manager.radio_operation( - "repeater_status", pause_polling=True, suspend_auto_fetch=True - ) as mc: - # Ensure contact is on radio for routing - await _ensure_on_radio(mc, contact) - - status = await mc.commands.req_status_sync(contact.public_key, timeout=10, min_timeout=5) - - if status is None: - raise HTTPException(status_code=504, detail="No status response from repeater") - - return RepeaterStatusResponse( - battery_volts=status.get("bat", 0) / 1000.0, - tx_queue_len=status.get("tx_queue_len", 0), - noise_floor_dbm=status.get("noise_floor", 0), - last_rssi_dbm=status.get("last_rssi", 0), - last_snr_db=status.get("last_snr", 0.0), - packets_received=status.get("nb_recv", 0), - packets_sent=status.get("nb_sent", 0), - airtime_seconds=status.get("airtime", 0), - rx_airtime_seconds=status.get("rx_airtime", 0), - uptime_seconds=status.get("uptime", 0), - sent_flood=status.get("sent_flood", 0), - sent_direct=status.get("sent_direct", 0), - recv_flood=status.get("recv_flood", 0), - recv_direct=status.get("recv_direct", 0), - flood_dups=status.get("flood_dups", 0), - direct_dups=status.get("direct_dups", 0), - full_events=status.get("full_evts", 0), - ) - - -@router.post("/{public_key}/repeater/lpp-telemetry", response_model=RepeaterLppTelemetryResponse) -async def repeater_lpp_telemetry(public_key: str) -> RepeaterLppTelemetryResponse: - """Fetch CayenneLPP sensor telemetry from a repeater (single attempt, 10s timeout).""" - require_connected() - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - async with radio_manager.radio_operation( - "repeater_lpp_telemetry", pause_polling=True, suspend_auto_fetch=True - ) as mc: - await _ensure_on_radio(mc, contact) - - telemetry = await mc.commands.req_telemetry_sync( - contact.public_key, timeout=10, min_timeout=5 - ) - - if telemetry is None: - raise HTTPException(status_code=504, detail="No telemetry response from repeater") - - sensors: list[LppSensor] = [] - for entry in telemetry: - channel = entry.get("channel", 0) - type_name = str(entry.get("type", "unknown")) - value = entry.get("value", 0) - sensors.append(LppSensor(channel=channel, type_name=type_name, value=value)) - - return RepeaterLppTelemetryResponse(sensors=sensors) - - -@router.post("/{public_key}/repeater/neighbors", response_model=RepeaterNeighborsResponse) -async def repeater_neighbors(public_key: str) -> RepeaterNeighborsResponse: - """Fetch neighbors from a repeater (single attempt, 10s timeout).""" - require_connected() - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - async with radio_manager.radio_operation( - "repeater_neighbors", pause_polling=True, suspend_auto_fetch=True - ) as mc: - # Ensure contact is on radio for routing - await _ensure_on_radio(mc, contact) - - neighbors_data = await mc.commands.fetch_all_neighbours( - contact.public_key, timeout=10, min_timeout=5 - ) - - neighbors: list[NeighborInfo] = [] - if neighbors_data and "neighbours" in neighbors_data: - for n in neighbors_data["neighbours"]: - pubkey_prefix = n.get("pubkey", "") - resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix) - neighbors.append( - NeighborInfo( - pubkey_prefix=pubkey_prefix, - name=resolved_contact.name if resolved_contact else None, - snr=n.get("snr", 0.0), - last_heard_seconds=n.get("secs_ago", 0), - ) - ) - - return RepeaterNeighborsResponse(neighbors=neighbors) - - -@router.post("/{public_key}/repeater/acl", response_model=RepeaterAclResponse) -async def repeater_acl(public_key: str) -> RepeaterAclResponse: - """Fetch ACL from a repeater (single attempt, 10s timeout).""" - require_connected() - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - async with radio_manager.radio_operation( - "repeater_acl", pause_polling=True, suspend_auto_fetch=True - ) as mc: - # Ensure contact is on radio for routing - await _ensure_on_radio(mc, contact) - - acl_data = await mc.commands.req_acl_sync(contact.public_key, timeout=10, min_timeout=5) - - acl_entries: list[AclEntry] = [] - if acl_data and isinstance(acl_data, list): - for entry in acl_data: - pubkey_prefix = entry.get("key", "") - perm = entry.get("perm", 0) - resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix) - acl_entries.append( - AclEntry( - pubkey_prefix=pubkey_prefix, - name=resolved_contact.name if resolved_contact else None, - permission=perm, - permission_name=ACL_PERMISSION_NAMES.get(perm, f"Unknown({perm})"), - ) - ) - - return RepeaterAclResponse(acl=acl_entries) - - -async def _batch_cli_fetch( - contact: Contact, - operation_name: str, - commands: list[tuple[str, str]], -) -> dict[str, str | None]: - """Send a batch of CLI commands to a repeater and collect responses. - - Opens a radio operation with polling paused and auto-fetch suspended (since - we call get_msg() directly via _fetch_repeater_response), adds the contact - to the radio for routing, then sends each command sequentially with a 1-second - gap between them. - - Returns a dict mapping field names to response strings (or None on timeout). - """ - results: dict[str, str | None] = {field: None for _, field in commands} - - async with radio_manager.radio_operation( - operation_name, - pause_polling=True, - suspend_auto_fetch=True, - ) as mc: - await _ensure_on_radio(mc, contact) - await asyncio.sleep(1.0) - - for i, (cmd, field) in enumerate(commands): - if i > 0: - await asyncio.sleep(1.0) - - send_result = await mc.commands.send_cmd(contact.public_key, cmd) - if send_result.type == EventType.ERROR: - logger.debug("Command '%s' send error: %s", cmd, send_result.payload) - continue - - response_event = await _fetch_repeater_response( - mc, contact.public_key[:12], timeout=10.0 - ) - if response_event is not None: - results[field] = _extract_response_text(response_event) - else: - logger.warning("No response for command '%s' (%s)", cmd, field) - - return results - - -@router.post("/{public_key}/repeater/radio-settings", response_model=RepeaterRadioSettingsResponse) -async def repeater_radio_settings(public_key: str) -> RepeaterRadioSettingsResponse: - """Fetch radio settings from a repeater via batch CLI commands.""" - require_connected() - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - results = await _batch_cli_fetch( - contact, - "repeater_radio_settings", - [ - ("ver", "firmware_version"), - ("get radio", "radio"), - ("get tx", "tx_power"), - ("get af", "airtime_factor"), - ("get repeat", "repeat_enabled"), - ("get flood.max", "flood_max"), - ("get name", "name"), - ("get lat", "lat"), - ("get lon", "lon"), - ("clock", "clock_utc"), - ], - ) - return RepeaterRadioSettingsResponse(**results) - - -@router.post( - "/{public_key}/repeater/advert-intervals", response_model=RepeaterAdvertIntervalsResponse -) -async def repeater_advert_intervals(public_key: str) -> RepeaterAdvertIntervalsResponse: - """Fetch advertisement intervals from a repeater via CLI commands.""" - require_connected() - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - results = await _batch_cli_fetch( - contact, - "repeater_advert_intervals", - [ - ("get advert.interval", "advert_interval"), - ("get flood.advert.interval", "flood_advert_interval"), - ], - ) - return RepeaterAdvertIntervalsResponse(**results) - - -@router.post("/{public_key}/repeater/owner-info", response_model=RepeaterOwnerInfoResponse) -async def repeater_owner_info(public_key: str) -> RepeaterOwnerInfoResponse: - """Fetch owner info and guest password from a repeater via CLI commands.""" - require_connected() - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - results = await _batch_cli_fetch( - contact, - "repeater_owner_info", - [ - ("get owner.info", "owner_info"), - ("get guest.password", "guest_password"), - ], - ) - return RepeaterOwnerInfoResponse(**results) - - @router.get("", response_model=list[Contact]) async def list_contacts( limit: int = Query(default=100, ge=1, le=1000), @@ -792,79 +371,6 @@ async def delete_contact(public_key: str) -> dict: return {"status": "ok"} -@router.post("/{public_key}/command", response_model=CommandResponse) -async def send_repeater_command(public_key: str, request: CommandRequest) -> CommandResponse: - """Send a CLI command to a repeater. - - The contact must be a repeater (type=2). The user must have already logged in - via the repeater/login endpoint. This endpoint ensures the contact is on the - radio before sending commands (the repeater remembers ACL permissions after login). - - Common commands: - - get name, set name - - get tx, set tx - - get radio, set radio - - tempradio - - setperm (0=guest, 1=read-only, 2=read-write, 3=admin) - - clock, clock sync - - reboot - - ver - """ - require_connected() - - # Get contact from database - contact = await _resolve_contact_or_404(public_key) - _require_repeater(contact) - - async with radio_manager.radio_operation( - "send_repeater_command", - pause_polling=True, - suspend_auto_fetch=True, - ) as mc: - # Add contact to radio with path from DB (non-fatal — contact may already be loaded) - logger.info("Adding repeater %s to radio", contact.public_key[:12]) - await _ensure_on_radio(mc, contact) - await asyncio.sleep(1.0) - - # Send the command - logger.info("Sending command to repeater %s: %s", contact.public_key[:12], request.command) - - send_result = await mc.commands.send_cmd(contact.public_key, request.command) - - if send_result.type == EventType.ERROR: - raise HTTPException( - status_code=500, detail=f"Failed to send command: {send_result.payload}" - ) - - # Wait for response using validated fetch loop - response_event = await _fetch_repeater_response(mc, contact.public_key[:12]) - - if response_event is None: - logger.warning( - "No response from repeater %s for command: %s", - contact.public_key[:12], - request.command, - ) - return CommandResponse( - command=request.command, - response="(no response - command may have been processed)", - ) - - # CONTACT_MSG_RECV payloads use sender_timestamp in meshcore. - response_text = _extract_response_text(response_event) - sender_timestamp = response_event.payload.get( - "sender_timestamp", - response_event.payload.get("timestamp"), - ) - logger.info("Received response from %s: %s", contact.public_key[:12], response_text) - - return CommandResponse( - command=request.command, - response=response_text, - sender_timestamp=sender_timestamp, - ) - - @router.post("/{public_key}/trace", response_model=TraceResponse) async def request_trace(public_key: str) -> TraceResponse: """Send a single-hop trace to a contact and wait for the result. diff --git a/app/routers/repeaters.py b/app/routers/repeaters.py new file mode 100644 index 0000000..bbb43ae --- /dev/null +++ b/app/routers/repeaters.py @@ -0,0 +1,510 @@ +import asyncio +import logging +import time +from typing import TYPE_CHECKING + +from fastapi import APIRouter, HTTPException +from meshcore import EventType + +from app.dependencies import require_connected +from app.models import ( + CONTACT_TYPE_REPEATER, + AclEntry, + CommandRequest, + CommandResponse, + Contact, + LppSensor, + NeighborInfo, + RepeaterAclResponse, + RepeaterAdvertIntervalsResponse, + RepeaterLoginRequest, + RepeaterLoginResponse, + RepeaterLppTelemetryResponse, + RepeaterNeighborsResponse, + RepeaterOwnerInfoResponse, + RepeaterRadioSettingsResponse, + RepeaterStatusResponse, +) +from app.radio import radio_manager +from app.repository import ContactRepository +from app.routers.contacts import _ensure_on_radio, _resolve_contact_or_404 + +if TYPE_CHECKING: + from meshcore.events import Event + +logger = logging.getLogger(__name__) + +# ACL permission level names +ACL_PERMISSION_NAMES = { + 0: "Guest", + 1: "Read-only", + 2: "Read-write", + 3: "Admin", +} +router = APIRouter(prefix="/contacts", tags=["repeaters"]) + +# Delay between repeater radio operations to allow key exchange and path establishment +REPEATER_OP_DELAY_SECONDS = 2.0 + + +def _monotonic() -> float: + """Wrapper around time.monotonic() for testability. + + Patching time.monotonic directly breaks the asyncio event loop which also + uses it. This indirection allows tests to control the clock safely. + """ + return time.monotonic() + + +def _extract_response_text(event) -> str: + """Extract text from a CLI response event, stripping the firmware '> ' prefix.""" + text = event.payload.get("text", str(event.payload)) + if text.startswith("> "): + text = text[2:] + return text + + +async def _fetch_repeater_response( + mc, + target_pubkey_prefix: str, + timeout: float = 20.0, +) -> "Event | None": + """Fetch a CLI response from a specific repeater via a validated get_msg() loop. + + Calls get_msg() repeatedly until a matching CLI response (txt_type=1) from the + target repeater arrives or the wall-clock deadline expires. Unrelated messages + are safe to skip — meshcore's event dispatcher already delivers them to the + normal subscription handlers (on_contact_message, etc.) when get_msg() returns. + + Args: + mc: MeshCore instance + target_pubkey_prefix: 12-char hex prefix of the repeater's public key + timeout: Wall-clock seconds before giving up + + Returns: + The matching Event, or None if no response arrived before the deadline. + """ + deadline = _monotonic() + timeout + + while _monotonic() < deadline: + try: + result = await mc.commands.get_msg(timeout=2.0) + except asyncio.TimeoutError: + continue + except Exception as e: + logger.debug("get_msg() exception: %s", e) + await asyncio.sleep(1.0) + continue + + if result.type == EventType.NO_MORE_MSGS: + # No messages queued yet — wait and retry + await asyncio.sleep(1.0) + continue + + if result.type == EventType.ERROR: + logger.debug("get_msg() error: %s", result.payload) + await asyncio.sleep(1.0) + continue + + if result.type == EventType.CONTACT_MSG_RECV: + msg_prefix = result.payload.get("pubkey_prefix", "") + txt_type = result.payload.get("txt_type", 0) + if msg_prefix == target_pubkey_prefix and txt_type == 1: + return result + # Not our target — already dispatched to subscribers by meshcore, + # so just continue draining the queue. + logger.debug( + "Skipping non-target message (from=%s, txt_type=%d) while waiting for %s", + msg_prefix, + txt_type, + target_pubkey_prefix, + ) + continue + + if result.type == EventType.CHANNEL_MSG_RECV: + # Already dispatched to subscribers by meshcore; skip. + logger.debug( + "Skipping channel message (channel_idx=%s) during repeater fetch", + result.payload.get("channel_idx"), + ) + continue + + logger.debug("Unexpected event type %s during repeater fetch, skipping", result.type) + + logger.warning("No CLI response from repeater %s within %.1fs", target_pubkey_prefix, timeout) + return None + + +async def prepare_repeater_connection(mc, contact: Contact, password: str) -> None: + """Prepare connection to a repeater by adding to radio and logging in. + + Args: + mc: MeshCore instance + contact: The repeater contact + password: Password for login (empty string for no password) + + Raises: + HTTPException: If login fails + """ + # Add contact to radio with path from DB (non-fatal — contact may already be loaded) + logger.info("Adding repeater %s to radio", contact.public_key[:12]) + await _ensure_on_radio(mc, contact) + + # Send login with password + logger.info("Sending login to repeater %s", contact.public_key[:12]) + login_result = await mc.commands.send_login(contact.public_key, password) + + if login_result.type == EventType.ERROR: + raise HTTPException(status_code=401, detail=f"Login failed: {login_result.payload}") + + # Wait for key exchange to complete before sending requests + logger.debug("Waiting %.1fs for key exchange to complete", REPEATER_OP_DELAY_SECONDS) + await asyncio.sleep(REPEATER_OP_DELAY_SECONDS) + + +def _require_repeater(contact: Contact) -> None: + """Raise 400 if contact is not a repeater.""" + if contact.type != CONTACT_TYPE_REPEATER: + raise HTTPException( + status_code=400, + detail=f"Contact is not a repeater (type={contact.type}, expected {CONTACT_TYPE_REPEATER})", + ) + + +# --------------------------------------------------------------------------- +# Granular repeater endpoints — one attempt, no server-side retries. +# Frontend manages retry logic for better UX control. +# --------------------------------------------------------------------------- + + +@router.post("/{public_key}/repeater/login", response_model=RepeaterLoginResponse) +async def repeater_login(public_key: str, request: RepeaterLoginRequest) -> RepeaterLoginResponse: + """Log in to a repeater. Adds contact to radio, sends login, waits for key exchange.""" + require_connected() + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + async with radio_manager.radio_operation( + "repeater_login", + pause_polling=True, + suspend_auto_fetch=True, + ) as mc: + await prepare_repeater_connection(mc, contact, request.password) + + return RepeaterLoginResponse(status="ok") + + +@router.post("/{public_key}/repeater/status", response_model=RepeaterStatusResponse) +async def repeater_status(public_key: str) -> RepeaterStatusResponse: + """Fetch status telemetry from a repeater (single attempt, 10s timeout).""" + require_connected() + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + async with radio_manager.radio_operation( + "repeater_status", pause_polling=True, suspend_auto_fetch=True + ) as mc: + # Ensure contact is on radio for routing + await _ensure_on_radio(mc, contact) + + status = await mc.commands.req_status_sync(contact.public_key, timeout=10, min_timeout=5) + + if status is None: + raise HTTPException(status_code=504, detail="No status response from repeater") + + return RepeaterStatusResponse( + battery_volts=status.get("bat", 0) / 1000.0, + tx_queue_len=status.get("tx_queue_len", 0), + noise_floor_dbm=status.get("noise_floor", 0), + last_rssi_dbm=status.get("last_rssi", 0), + last_snr_db=status.get("last_snr", 0.0), + packets_received=status.get("nb_recv", 0), + packets_sent=status.get("nb_sent", 0), + airtime_seconds=status.get("airtime", 0), + rx_airtime_seconds=status.get("rx_airtime", 0), + uptime_seconds=status.get("uptime", 0), + sent_flood=status.get("sent_flood", 0), + sent_direct=status.get("sent_direct", 0), + recv_flood=status.get("recv_flood", 0), + recv_direct=status.get("recv_direct", 0), + flood_dups=status.get("flood_dups", 0), + direct_dups=status.get("direct_dups", 0), + full_events=status.get("full_evts", 0), + ) + + +@router.post("/{public_key}/repeater/lpp-telemetry", response_model=RepeaterLppTelemetryResponse) +async def repeater_lpp_telemetry(public_key: str) -> RepeaterLppTelemetryResponse: + """Fetch CayenneLPP sensor telemetry from a repeater (single attempt, 10s timeout).""" + require_connected() + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + async with radio_manager.radio_operation( + "repeater_lpp_telemetry", pause_polling=True, suspend_auto_fetch=True + ) as mc: + await _ensure_on_radio(mc, contact) + + telemetry = await mc.commands.req_telemetry_sync( + contact.public_key, timeout=10, min_timeout=5 + ) + + if telemetry is None: + raise HTTPException(status_code=504, detail="No telemetry response from repeater") + + sensors: list[LppSensor] = [] + for entry in telemetry: + channel = entry.get("channel", 0) + type_name = str(entry.get("type", "unknown")) + value = entry.get("value", 0) + sensors.append(LppSensor(channel=channel, type_name=type_name, value=value)) + + return RepeaterLppTelemetryResponse(sensors=sensors) + + +@router.post("/{public_key}/repeater/neighbors", response_model=RepeaterNeighborsResponse) +async def repeater_neighbors(public_key: str) -> RepeaterNeighborsResponse: + """Fetch neighbors from a repeater (single attempt, 10s timeout).""" + require_connected() + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + async with radio_manager.radio_operation( + "repeater_neighbors", pause_polling=True, suspend_auto_fetch=True + ) as mc: + # Ensure contact is on radio for routing + await _ensure_on_radio(mc, contact) + + neighbors_data = await mc.commands.fetch_all_neighbours( + contact.public_key, timeout=10, min_timeout=5 + ) + + neighbors: list[NeighborInfo] = [] + if neighbors_data and "neighbours" in neighbors_data: + for n in neighbors_data["neighbours"]: + pubkey_prefix = n.get("pubkey", "") + resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix) + neighbors.append( + NeighborInfo( + pubkey_prefix=pubkey_prefix, + name=resolved_contact.name if resolved_contact else None, + snr=n.get("snr", 0.0), + last_heard_seconds=n.get("secs_ago", 0), + ) + ) + + return RepeaterNeighborsResponse(neighbors=neighbors) + + +@router.post("/{public_key}/repeater/acl", response_model=RepeaterAclResponse) +async def repeater_acl(public_key: str) -> RepeaterAclResponse: + """Fetch ACL from a repeater (single attempt, 10s timeout).""" + require_connected() + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + async with radio_manager.radio_operation( + "repeater_acl", pause_polling=True, suspend_auto_fetch=True + ) as mc: + # Ensure contact is on radio for routing + await _ensure_on_radio(mc, contact) + + acl_data = await mc.commands.req_acl_sync(contact.public_key, timeout=10, min_timeout=5) + + acl_entries: list[AclEntry] = [] + if acl_data and isinstance(acl_data, list): + for entry in acl_data: + pubkey_prefix = entry.get("key", "") + perm = entry.get("perm", 0) + resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix) + acl_entries.append( + AclEntry( + pubkey_prefix=pubkey_prefix, + name=resolved_contact.name if resolved_contact else None, + permission=perm, + permission_name=ACL_PERMISSION_NAMES.get(perm, f"Unknown({perm})"), + ) + ) + + return RepeaterAclResponse(acl=acl_entries) + + +async def _batch_cli_fetch( + contact: Contact, + operation_name: str, + commands: list[tuple[str, str]], +) -> dict[str, str | None]: + """Send a batch of CLI commands to a repeater and collect responses. + + Opens a radio operation with polling paused and auto-fetch suspended (since + we call get_msg() directly via _fetch_repeater_response), adds the contact + to the radio for routing, then sends each command sequentially with a 1-second + gap between them. + + Returns a dict mapping field names to response strings (or None on timeout). + """ + results: dict[str, str | None] = {field: None for _, field in commands} + + async with radio_manager.radio_operation( + operation_name, + pause_polling=True, + suspend_auto_fetch=True, + ) as mc: + await _ensure_on_radio(mc, contact) + await asyncio.sleep(1.0) + + for i, (cmd, field) in enumerate(commands): + if i > 0: + await asyncio.sleep(1.0) + + send_result = await mc.commands.send_cmd(contact.public_key, cmd) + if send_result.type == EventType.ERROR: + logger.debug("Command '%s' send error: %s", cmd, send_result.payload) + continue + + response_event = await _fetch_repeater_response( + mc, contact.public_key[:12], timeout=10.0 + ) + if response_event is not None: + results[field] = _extract_response_text(response_event) + else: + logger.warning("No response for command '%s' (%s)", cmd, field) + + return results + + +@router.post("/{public_key}/repeater/radio-settings", response_model=RepeaterRadioSettingsResponse) +async def repeater_radio_settings(public_key: str) -> RepeaterRadioSettingsResponse: + """Fetch radio settings from a repeater via batch CLI commands.""" + require_connected() + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + results = await _batch_cli_fetch( + contact, + "repeater_radio_settings", + [ + ("ver", "firmware_version"), + ("get radio", "radio"), + ("get tx", "tx_power"), + ("get af", "airtime_factor"), + ("get repeat", "repeat_enabled"), + ("get flood.max", "flood_max"), + ("get name", "name"), + ("get lat", "lat"), + ("get lon", "lon"), + ("clock", "clock_utc"), + ], + ) + return RepeaterRadioSettingsResponse(**results) + + +@router.post( + "/{public_key}/repeater/advert-intervals", response_model=RepeaterAdvertIntervalsResponse +) +async def repeater_advert_intervals(public_key: str) -> RepeaterAdvertIntervalsResponse: + """Fetch advertisement intervals from a repeater via CLI commands.""" + require_connected() + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + results = await _batch_cli_fetch( + contact, + "repeater_advert_intervals", + [ + ("get advert.interval", "advert_interval"), + ("get flood.advert.interval", "flood_advert_interval"), + ], + ) + return RepeaterAdvertIntervalsResponse(**results) + + +@router.post("/{public_key}/repeater/owner-info", response_model=RepeaterOwnerInfoResponse) +async def repeater_owner_info(public_key: str) -> RepeaterOwnerInfoResponse: + """Fetch owner info and guest password from a repeater via CLI commands.""" + require_connected() + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + results = await _batch_cli_fetch( + contact, + "repeater_owner_info", + [ + ("get owner.info", "owner_info"), + ("get guest.password", "guest_password"), + ], + ) + return RepeaterOwnerInfoResponse(**results) + + +@router.post("/{public_key}/command", response_model=CommandResponse) +async def send_repeater_command(public_key: str, request: CommandRequest) -> CommandResponse: + """Send a CLI command to a repeater. + + The contact must be a repeater (type=2). The user must have already logged in + via the repeater/login endpoint. This endpoint ensures the contact is on the + radio before sending commands (the repeater remembers ACL permissions after login). + + Common commands: + - get name, set name + - get tx, set tx + - get radio, set radio + - tempradio + - setperm (0=guest, 1=read-only, 2=read-write, 3=admin) + - clock, clock sync + - reboot + - ver + """ + require_connected() + + # Get contact from database + contact = await _resolve_contact_or_404(public_key) + _require_repeater(contact) + + async with radio_manager.radio_operation( + "send_repeater_command", + pause_polling=True, + suspend_auto_fetch=True, + ) as mc: + # Add contact to radio with path from DB (non-fatal — contact may already be loaded) + logger.info("Adding repeater %s to radio", contact.public_key[:12]) + await _ensure_on_radio(mc, contact) + await asyncio.sleep(1.0) + + # Send the command + logger.info("Sending command to repeater %s: %s", contact.public_key[:12], request.command) + + send_result = await mc.commands.send_cmd(contact.public_key, request.command) + + if send_result.type == EventType.ERROR: + raise HTTPException( + status_code=500, detail=f"Failed to send command: {send_result.payload}" + ) + + # Wait for response using validated fetch loop + response_event = await _fetch_repeater_response(mc, contact.public_key[:12]) + + if response_event is None: + logger.warning( + "No response from repeater %s for command: %s", + contact.public_key[:12], + request.command, + ) + return CommandResponse( + command=request.command, + response="(no response - command may have been processed)", + ) + + # CONTACT_MSG_RECV payloads use sender_timestamp in meshcore. + response_text = _extract_response_text(response_event) + sender_timestamp = response_event.payload.get( + "sender_timestamp", + response_event.payload.get("timestamp"), + ) + logger.info("Received response from %s: %s", contact.public_key[:12], response_text) + + return CommandResponse( + command=request.command, + response=response_text, + sender_timestamp=sender_timestamp, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 99b32cf..86f2c7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,11 @@ import shutil import tempfile from pathlib import Path +import httpx import pytest +from app.database import Database + # Use an isolated file-backed SQLite DB for tests that import app.main/TestClient. # This must be set before app.config/app.database are imported, otherwise the global # Database instance will bind to the default runtime DB (data/meshcore.db). @@ -20,3 +23,52 @@ def cleanup_test_db_dir(): """Clean up temporary pytest DB directory after the test session.""" yield shutil.rmtree(_TEST_DB_DIR, ignore_errors=True) + + +@pytest.fixture +async def test_db(): + """Create an in-memory test database with schema + migrations.""" + from app.repository import channels, contacts, messages, raw_packets, settings + + db = Database(":memory:") + await db.connect() + + submodules = [contacts, channels, messages, raw_packets, settings] + originals = [(mod, mod.db) for mod in submodules] + + for mod in submodules: + mod.db = db + + # Also patch the db reference used by the packets router for VACUUM + import app.routers.packets as packets_module + + original_packets_db = packets_module.db + packets_module.db = db + + try: + yield db + finally: + for mod, original in originals: + mod.db = original + packets_module.db = original_packets_db + await db.disconnect() + + +@pytest.fixture +def client(): + """Create an httpx AsyncClient for testing the app.""" + from app.main import app + + transport = httpx.ASGITransport(app=app) + return httpx.AsyncClient(transport=transport, base_url="http://test") + + +@pytest.fixture +def captured_broadcasts(): + """Capture WebSocket broadcasts for verification.""" + broadcasts = [] + + def mock_broadcast(event_type: str, data: dict): + broadcasts.append({"type": event_type, "data": data}) + + return broadcasts, mock_broadcast diff --git a/tests/test_api.py b/tests/test_api.py index 66782ab..6f4a37e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,10 +8,8 @@ import hashlib import time from unittest.mock import AsyncMock, MagicMock, patch -import httpx import pytest -from app.database import Database from app.radio import radio_manager from app.repository import ( ChannelRepository, @@ -31,33 +29,6 @@ def _reset_radio_state(): radio_manager._operation_lock = prev_lock -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - -@pytest.fixture -def client(): - """Create an httpx AsyncClient for testing the app.""" - from app.main import app - - transport = httpx.ASGITransport(app=app) - return httpx.AsyncClient(transport=transport, base_url="http://test") - - async def _insert_contact(public_key, name="Alice", **overrides): """Insert a contact into the test database.""" data = { diff --git a/tests/test_channels_router.py b/tests/test_channels_router.py index 13a8149..3cadb77 100644 --- a/tests/test_channels_router.py +++ b/tests/test_channels_router.py @@ -7,33 +7,13 @@ from the radio and upserts them into the database. from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch -import httpx import pytest from meshcore import EventType -from app.database import Database from app.radio import radio_manager from app.repository import ChannelRepository -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - @pytest.fixture(autouse=True) def _reset_radio_state(): """Save/restore radio_manager state so tests don't leak.""" @@ -44,15 +24,6 @@ def _reset_radio_state(): radio_manager._operation_lock = prev_lock -@pytest.fixture -def client(): - """Create an httpx AsyncClient for testing the app.""" - from app.main import app - - transport = httpx.ASGITransport(app=app) - return httpx.AsyncClient(transport=transport, base_url="http://test") - - def _make_channel_info(name: str, secret: bytes): """Create a mock channel info response.""" result = MagicMock() diff --git a/tests/test_contacts_router.py b/tests/test_contacts_router.py index b682200..01eb8c2 100644 --- a/tests/test_contacts_router.py +++ b/tests/test_contacts_router.py @@ -9,11 +9,9 @@ Uses httpx.AsyncClient with real in-memory SQLite database. from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch -import httpx import pytest from meshcore import EventType -from app.database import Database from app.radio import radio_manager from app.repository import ContactAdvertPathRepository, ContactRepository, MessageRepository @@ -43,24 +41,6 @@ def _reset_radio_state(): radio_manager._operation_lock = prev_lock -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - async def _insert_contact(public_key=KEY_A, name="Alice", on_radio=False, **overrides): """Insert a contact into the test database.""" data = { @@ -82,15 +62,6 @@ async def _insert_contact(public_key=KEY_A, name="Alice", on_radio=False, **over await ContactRepository.upsert(data) -@pytest.fixture -def client(): - """Create an httpx AsyncClient for testing the app.""" - from app.main import app - - transport = httpx.ASGITransport(app=app) - return httpx.AsyncClient(transport=transport, base_url="http://test") - - class TestListContacts: """Test GET /api/contacts.""" diff --git a/tests/test_echo_dedup.py b/tests/test_echo_dedup.py index 1a0735d..7475f65 100644 --- a/tests/test_echo_dedup.py +++ b/tests/test_echo_dedup.py @@ -11,7 +11,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from app.database import Database from app.decoder import DecryptedDirectMessage from app.repository import ( ContactRepository, @@ -19,36 +18,6 @@ from app.repository import ( RawPacketRepository, ) - -@pytest.fixture -async def test_db(): - """Create an in-memory test database.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - -@pytest.fixture -def captured_broadcasts(): - """Capture WebSocket broadcasts for verification.""" - broadcasts = [] - - def mock_broadcast(event_type: str, data: dict): - broadcasts.append({"type": event_type, "data": data}) - - return broadcasts, mock_broadcast - - # Shared test constants CHANNEL_KEY = "ABC123DEF456ABC123DEF456ABC12345" CONTACT_PUB = "a1b2c3d3ba9f5fa8705b9845fe11cc6f01d1d49caaf4d122ac7121663c5beec7" diff --git a/tests/test_event_handlers.py b/tests/test_event_handlers.py index d732e3a..6077a9e 100644 --- a/tests/test_event_handlers.py +++ b/tests/test_event_handlers.py @@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from app.database import Database from app.event_handlers import ( _active_subscriptions, _pending_acks, @@ -23,24 +22,6 @@ from app.repository import ( ) -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - @pytest.fixture(autouse=True) def clear_test_state(): """Clear pending ACKs and subscriptions before each test.""" diff --git a/tests/test_key_normalization.py b/tests/test_key_normalization.py index 8017192..af30967 100644 --- a/tests/test_key_normalization.py +++ b/tests/test_key_normalization.py @@ -2,28 +2,9 @@ import pytest -from app.database import Database from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository -@pytest.fixture -async def test_db(): - """Create an in-memory test database.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - @pytest.mark.asyncio async def test_upsert_stores_lowercase_key(test_db): await ContactRepository.upsert( diff --git a/tests/test_message_pagination.py b/tests/test_message_pagination.py index 1256786..dedbfa7 100644 --- a/tests/test_message_pagination.py +++ b/tests/test_message_pagination.py @@ -2,28 +2,8 @@ import pytest -from app.database import Database from app.repository import MessageRepository - -@pytest.fixture -async def test_db(): - """Create an in-memory test database.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - CHAN_KEY = "ABC123DEF456ABC123DEF456ABC12345" DM_KEY = "aa" * 32 diff --git a/tests/test_message_prefix_claim.py b/tests/test_message_prefix_claim.py index e98c4af..b988eac 100644 --- a/tests/test_message_prefix_claim.py +++ b/tests/test_message_prefix_claim.py @@ -2,28 +2,9 @@ import pytest -from app.database import Database from app.repository import ContactRepository, MessageRepository -@pytest.fixture -async def test_db(): - """Create an in-memory test database.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - @pytest.mark.asyncio async def test_claim_prefix_promotes_dm_to_full_key(test_db): full_key = "a1b2c3d3ba9f5fa8705b9845fe11cc6f01d1d49caaf4d122ac7121663c5beec7" diff --git a/tests/test_packet_pipeline.py b/tests/test_packet_pipeline.py index 84e719c..18df349 100644 --- a/tests/test_packet_pipeline.py +++ b/tests/test_packet_pipeline.py @@ -13,7 +13,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from app.database import Database from app.decoder import DecryptedDirectMessage, PacketInfo, ParsedAdvertisement, PayloadType from app.repository import ( ChannelRepository, @@ -28,41 +27,6 @@ with open(FIXTURES_PATH) as f: FIXTURES = json.load(f) -@pytest.fixture -async def test_db(): - """Create an in-memory test database. - - We need to patch the db module-level variable before any repository - methods are called, so they use our test database. - """ - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - # Store original and patch the module attribute directly - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - -@pytest.fixture -def captured_broadcasts(): - """Capture WebSocket broadcasts for verification.""" - broadcasts = [] - - def mock_broadcast(event_type: str, data: dict): - """Synchronous mock that captures broadcasts.""" - broadcasts.append({"type": event_type, "data": data}) - - return broadcasts, mock_broadcast - - class TestChannelMessagePipeline: """Test channel message flow: packet → decrypt → store → broadcast.""" diff --git a/tests/test_packets_router.py b/tests/test_packets_router.py index a24ae95..536a93f 100644 --- a/tests/test_packets_router.py +++ b/tests/test_packets_router.py @@ -7,47 +7,11 @@ undecrypted count endpoint, and the maintenance endpoint. import time from unittest.mock import patch -import httpx import pytest -from app.database import Database from app.repository import ChannelRepository, MessageRepository, RawPacketRepository -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - # Also patch the db reference used by the packets router for VACUUM - import app.routers.packets as packets_module - - original_packets_db = packets_module.db - packets_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - packets_module.db = original_packets_db - await db.disconnect() - - -@pytest.fixture -def client(): - """Create an httpx AsyncClient for testing the app.""" - from app.main import app - - transport = httpx.ASGITransport(app=app) - return httpx.AsyncClient(transport=transport, base_url="http://test") - - async def _insert_raw_packets(count: int, decrypted: bool = False, age_days: int = 0) -> list[int]: """Insert raw packets and return their IDs.""" ids = [] diff --git a/tests/test_radio_sync.py b/tests/test_radio_sync.py index 7da25ac..d413b68 100644 --- a/tests/test_radio_sync.py +++ b/tests/test_radio_sync.py @@ -10,7 +10,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from meshcore import EventType -from app.database import Database from app.models import Favorite from app.radio import RadioManager, radio_manager from app.radio_sync import ( @@ -30,24 +29,6 @@ from app.repository import ( ) -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - @pytest.fixture(autouse=True) def reset_sync_state(): """Reset polling pause state, sync timestamp, and radio_manager before/after each test.""" diff --git a/tests/test_repeater_routes.py b/tests/test_repeater_routes.py index e39f2f8..5238b1f 100644 --- a/tests/test_repeater_routes.py +++ b/tests/test_repeater_routes.py @@ -6,11 +6,11 @@ import pytest from fastapi import HTTPException from meshcore import EventType -from app.database import Database from app.models import CommandRequest, Contact, RepeaterLoginRequest from app.radio import radio_manager from app.repository import ContactRepository -from app.routers.contacts import ( +from app.routers.contacts import request_trace +from app.routers.repeaters import ( _batch_cli_fetch, _fetch_repeater_response, repeater_acl, @@ -21,7 +21,6 @@ from app.routers.contacts import ( repeater_owner_info, repeater_radio_settings, repeater_status, - request_trace, send_repeater_command, ) @@ -29,7 +28,7 @@ KEY_A = "aa" * 32 # Patch target for the wall-clock wrapper used by _fetch_repeater_response. # We patch _monotonic (not time.monotonic) to avoid breaking the asyncio event loop. -_MONOTONIC = "app.routers.contacts._monotonic" +_MONOTONIC = "app.routers.repeaters._monotonic" @pytest.fixture(autouse=True) @@ -42,24 +41,6 @@ def _reset_radio_state(): radio_manager._operation_lock = prev_lock -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - def _radio_result(event_type=EventType.OK, payload=None): result = MagicMock() result.type = event_type @@ -210,7 +191,7 @@ class TestFetchRepeaterResponse: with ( patch(_MONOTONIC, side_effect=_advancing_clock()), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): result = await _fetch_repeater_response(mc, "aaaaaaaaaaaa", timeout=5.0) @@ -229,7 +210,7 @@ class TestFetchRepeaterResponse: with ( patch(_MONOTONIC, side_effect=times), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): result = await _fetch_repeater_response(mc, "aaaaaaaaaaaa", timeout=2.0) @@ -247,7 +228,7 @@ class TestFetchRepeaterResponse: with ( patch(_MONOTONIC, side_effect=_advancing_clock()), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): result = await _fetch_repeater_response(mc, "aaaaaaaaaaaa", timeout=5.0) @@ -290,7 +271,7 @@ class TestRepeaterCommandRoute: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -308,10 +289,10 @@ class TestRepeaterCommandRoute: # Expire the deadline after a couple of ticks with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=[0.0, 5.0, 25.0]), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): response = await send_repeater_command(KEY_A, CommandRequest(command="ver")) @@ -337,7 +318,7 @@ class TestRepeaterCommandRoute: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): @@ -365,7 +346,7 @@ class TestRepeaterCommandRoute: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): @@ -391,7 +372,7 @@ class TestRepeaterCommandRoute: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): @@ -419,7 +400,7 @@ class TestRepeaterCommandRoute: mc.commands.get_msg = AsyncMock(side_effect=[unrelated, expected]) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): @@ -445,7 +426,7 @@ class TestRepeaterCommandRoute: mc.commands.get_msg = AsyncMock(side_effect=[channel_msg, expected]) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): @@ -468,10 +449,10 @@ class TestRepeaterCommandRoute: mc.commands.get_msg = AsyncMock(side_effect=[no_msgs, expected]) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): response = await send_repeater_command(KEY_A, CommandRequest(command="ver")) @@ -548,10 +529,10 @@ class TestRepeaterLogin: await _insert_contact(KEY_A, name="Repeater", contact_type=2) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch( - "app.routers.contacts.prepare_repeater_connection", + "app.routers.repeaters.prepare_repeater_connection", new_callable=AsyncMock, ) as mock_prepare, ): @@ -564,7 +545,7 @@ class TestRepeaterLogin: async def test_404_missing_contact(self, test_db): mc = _mock_mc() with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -576,7 +557,7 @@ class TestRepeaterLogin: mc = _mock_mc() await _insert_contact(KEY_A, name="Client", contact_type=1) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -593,9 +574,9 @@ class TestRepeaterLogin: raise HTTPException(status_code=401, detail="Login failed") with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), - patch("app.routers.contacts.prepare_repeater_connection", side_effect=_prepare_fail), + patch("app.routers.repeaters.prepare_repeater_connection", side_effect=_prepare_fail), ): with pytest.raises(HTTPException) as exc: await repeater_login(KEY_A, RepeaterLoginRequest(password="bad")) @@ -630,7 +611,7 @@ class TestRepeaterStatus: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_status(KEY_A) @@ -653,7 +634,7 @@ class TestRepeaterStatus: mc.commands.req_status_sync = AsyncMock(return_value=None) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -665,7 +646,7 @@ class TestRepeaterStatus: mc = _mock_mc() await _insert_contact(KEY_A, name="Client", contact_type=1) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -691,7 +672,7 @@ class TestRepeaterLppTelemetry: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_lpp_telemetry(KEY_A) @@ -713,7 +694,7 @@ class TestRepeaterLppTelemetry: mc.commands.req_telemetry_sync = AsyncMock(return_value=[]) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_lpp_telemetry(KEY_A) @@ -727,7 +708,7 @@ class TestRepeaterLppTelemetry: mc.commands.req_telemetry_sync = AsyncMock(return_value=None) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -739,7 +720,7 @@ class TestRepeaterLppTelemetry: mc = _mock_mc() await _insert_contact(KEY_A, name="Client", contact_type=1) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -765,7 +746,7 @@ class TestRepeaterNeighbors: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_neighbors(KEY_A) @@ -783,7 +764,7 @@ class TestRepeaterNeighbors: mc.commands.fetch_all_neighbours = AsyncMock(return_value={"neighbours": []}) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_neighbors(KEY_A) @@ -797,7 +778,7 @@ class TestRepeaterNeighbors: mc.commands.fetch_all_neighbours = AsyncMock(return_value=None) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_neighbors(KEY_A) @@ -821,7 +802,7 @@ class TestRepeaterAcl: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_acl(KEY_A) @@ -839,7 +820,7 @@ class TestRepeaterAcl: mc.commands.req_acl_sync = AsyncMock(return_value=[]) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_acl(KEY_A) @@ -853,7 +834,7 @@ class TestRepeaterAcl: mc.commands.req_acl_sync = AsyncMock(return_value=None) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): response = await repeater_acl(KEY_A) @@ -890,7 +871,7 @@ class TestRepeaterRadioSettings: mc.commands.get_msg = AsyncMock(side_effect=get_msg_results) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): @@ -927,10 +908,10 @@ class TestRepeaterRadioSettings: clock_ticks.extend([base, base + 5.0, base + 11.0]) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=clock_ticks), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): response = await repeater_radio_settings(KEY_A) @@ -943,7 +924,7 @@ class TestRepeaterRadioSettings: mc = _mock_mc() await _insert_contact(KEY_A, name="Client", contact_type=1) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -970,7 +951,7 @@ class TestRepeaterAdvertIntervals: mc.commands.get_msg = AsyncMock(side_effect=responses) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): @@ -991,10 +972,10 @@ class TestRepeaterAdvertIntervals: clock_ticks.extend([base, base + 5.0, base + 11.0]) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=clock_ticks), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): response = await repeater_advert_intervals(KEY_A) @@ -1025,7 +1006,7 @@ class TestRepeaterOwnerInfo: mc.commands.get_msg = AsyncMock(side_effect=responses) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), ): @@ -1046,10 +1027,10 @@ class TestRepeaterOwnerInfo: clock_ticks.extend([base, base + 5.0, base + 11.0]) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=clock_ticks), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): response = await repeater_owner_info(KEY_A) @@ -1107,7 +1088,7 @@ class TestBatchCliFetch: with ( patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=_advancing_clock()), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): results = await _batch_cli_fetch( contact, "test_op", [("bad_cmd", "field_a"), ("good_cmd", "field_b")] @@ -1128,7 +1109,7 @@ class TestBatchCliFetch: with ( patch.object(radio_manager, "_meshcore", mc), patch(_MONOTONIC, side_effect=[0.0, 5.0, 11.0]), - patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock), + patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock), ): results = await _batch_cli_fetch(contact, "test_op", [("clock", "clock_output")]) @@ -1147,7 +1128,7 @@ class TestRepeaterAddContactError: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -1165,7 +1146,7 @@ class TestRepeaterAddContactError: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -1183,7 +1164,7 @@ class TestRepeaterAddContactError: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: @@ -1201,7 +1182,7 @@ class TestRepeaterAddContactError: ) with ( - patch("app.routers.contacts.require_connected", return_value=mc), + patch("app.routers.repeaters.require_connected", return_value=mc), patch.object(radio_manager, "_meshcore", mc), ): with pytest.raises(HTTPException) as exc: diff --git a/tests/test_repository.py b/tests/test_repository.py index 3537190..5cbce1a 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -4,7 +4,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from app.database import Database from app.repository import ( ContactAdvertPathRepository, ContactNameHistoryRepository, @@ -13,24 +12,6 @@ from app.repository import ( ) -@pytest.fixture -async def test_db(): - """Create an in-memory test database with the module-level db swapped in.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - async def _create_message(test_db, **overrides) -> int: """Helper to insert a message and return its id.""" defaults = { @@ -90,7 +71,7 @@ class TestMessageRepositoryAddPath: """Adding a path without received_at uses current timestamp.""" msg_id = await _create_message(test_db) - with patch("app.repository.time") as mock_time: + with patch("app.repository.messages.time") as mock_time: mock_time.time.return_value = 1700000500.5 result = await MessageRepository.add_path(message_id=msg_id, path="1A2B") @@ -518,7 +499,7 @@ class TestAppSettingsRepository: mock_db = MagicMock() mock_db.conn = mock_conn - with patch("app.repository.db", mock_db): + with patch("app.repository.settings.db", mock_db): from app.repository import AppSettingsRepository settings = await AppSettingsRepository.get() diff --git a/tests/test_send_messages.py b/tests/test_send_messages.py index 52763bc..265b5f4 100644 --- a/tests/test_send_messages.py +++ b/tests/test_send_messages.py @@ -8,7 +8,6 @@ import pytest from fastapi import HTTPException from meshcore import EventType -from app.database import Database from app.models import ( SendChannelMessageRequest, SendDirectMessageRequest, @@ -36,24 +35,6 @@ def _reset_radio_state(): radio_manager._operation_lock = prev_lock -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - def _make_radio_result(payload=None): """Create a mock radio command result.""" result = MagicMock() diff --git a/tests/test_settings_router.py b/tests/test_settings_router.py index b932c89..c79fce8 100644 --- a/tests/test_settings_router.py +++ b/tests/test_settings_router.py @@ -3,7 +3,6 @@ import pytest from fastapi import HTTPException -from app.database import Database from app.models import AppSettings, BotConfig from app.repository import AppSettingsRepository from app.routers.settings import ( @@ -16,24 +15,6 @@ from app.routers.settings import ( ) -@pytest.fixture -async def test_db(): - """Create an in-memory test database with schema + migrations.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - class TestUpdateSettings: @pytest.mark.asyncio async def test_forwards_only_provided_fields(self, test_db): diff --git a/tests/test_statistics.py b/tests/test_statistics.py index cdb2ba1..acf4ee9 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -4,28 +4,9 @@ import time import pytest -from app.database import Database from app.repository import StatisticsRepository -@pytest.fixture -async def test_db(): - """Create an in-memory test database with the module-level db swapped in.""" - import app.repository as repo_module - - db = Database(":memory:") - await db.connect() - - original_db = repo_module.db - repo_module.db = db - - try: - yield db - finally: - repo_module.db = original_db - await db.disconnect() - - class TestStatisticsEmpty: @pytest.mark.asyncio async def test_empty_database(self, test_db):