Reorganize for great victory and move to blob for payload hasg

This commit is contained in:
Jack Kingsman
2026-02-27 21:03:34 -08:00
parent fc27361e37
commit ce99d63701
27 changed files with 2028 additions and 2303 deletions

View File

@@ -23,6 +23,7 @@ from app.routers import (
packets,
radio,
read_state,
repeaters,
settings,
statistics,
ws,
@@ -106,6 +107,7 @@ async def radio_disconnected_handler(request: Request, exc: RadioDisconnectedErr
app.include_router(health.router, prefix="/api")
app.include_router(radio.router, prefix="/api")
app.include_router(contacts.router, prefix="/api")
app.include_router(repeaters.router, prefix="/api")
app.include_router(channels.router, prefix="/api")
app.include_router(messages.router, prefix="/api")
app.include_router(packets.router, prefix="/api")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,22 @@
from app.repository.channels import ChannelRepository
from app.repository.contacts import (
AmbiguousPublicKeyPrefixError,
ContactAdvertPathRepository,
ContactNameHistoryRepository,
ContactRepository,
)
from app.repository.messages import MessageRepository
from app.repository.raw_packets import RawPacketRepository
from app.repository.settings import AppSettingsRepository, StatisticsRepository
__all__ = [
"AmbiguousPublicKeyPrefixError",
"AppSettingsRepository",
"ChannelRepository",
"ContactAdvertPathRepository",
"ContactNameHistoryRepository",
"ContactRepository",
"MessageRepository",
"RawPacketRepository",
"StatisticsRepository",
]

View File

@@ -0,0 +1,86 @@
import time
from app.database import db
from app.models import Channel
class ChannelRepository:
@staticmethod
async def upsert(key: str, name: str, is_hashtag: bool = False, on_radio: bool = False) -> None:
"""Upsert a channel. Key is 32-char hex string."""
await db.conn.execute(
"""
INSERT INTO channels (key, name, is_hashtag, on_radio)
VALUES (?, ?, ?, ?)
ON CONFLICT(key) DO UPDATE SET
name = excluded.name,
is_hashtag = excluded.is_hashtag,
on_radio = excluded.on_radio
""",
(key.upper(), name, is_hashtag, on_radio),
)
await db.conn.commit()
@staticmethod
async def get_by_key(key: str) -> Channel | None:
"""Get a channel by its key (32-char hex string)."""
cursor = await db.conn.execute(
"SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels WHERE key = ?",
(key.upper(),),
)
row = await cursor.fetchone()
if row:
return Channel(
key=row["key"],
name=row["name"],
is_hashtag=bool(row["is_hashtag"]),
on_radio=bool(row["on_radio"]),
last_read_at=row["last_read_at"],
)
return None
@staticmethod
async def get_all() -> list[Channel]:
cursor = await db.conn.execute(
"SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels ORDER BY name"
)
rows = await cursor.fetchall()
return [
Channel(
key=row["key"],
name=row["name"],
is_hashtag=bool(row["is_hashtag"]),
on_radio=bool(row["on_radio"]),
last_read_at=row["last_read_at"],
)
for row in rows
]
@staticmethod
async def delete(key: str) -> None:
"""Delete a channel by key."""
await db.conn.execute(
"DELETE FROM channels WHERE key = ?",
(key.upper(),),
)
await db.conn.commit()
@staticmethod
async def update_last_read_at(key: str, timestamp: int | None = None) -> bool:
"""Update the last_read_at timestamp for a channel.
Returns True if a row was updated, False if channel not found.
"""
ts = timestamp if timestamp is not None else int(time.time())
cursor = await db.conn.execute(
"UPDATE channels SET last_read_at = ? WHERE key = ?",
(ts, key.upper()),
)
await db.conn.commit()
return cursor.rowcount > 0
@staticmethod
async def mark_all_read(timestamp: int) -> None:
"""Mark all channels as read at the given timestamp."""
await db.conn.execute("UPDATE channels SET last_read_at = ?", (timestamp,))
await db.conn.commit()

412
app/repository/contacts.py Normal file
View File

@@ -0,0 +1,412 @@
import time
from typing import Any
from app.database import db
from app.models import (
Contact,
ContactAdvertPath,
ContactAdvertPathSummary,
ContactNameHistory,
)
class AmbiguousPublicKeyPrefixError(ValueError):
"""Raised when a public key prefix matches multiple contacts."""
def __init__(self, prefix: str, matches: list[str]):
self.prefix = prefix.lower()
self.matches = matches
super().__init__(f"Ambiguous public key prefix '{self.prefix}'")
class ContactRepository:
@staticmethod
async def upsert(contact: dict[str, Any]) -> None:
await db.conn.execute(
"""
INSERT INTO contacts (public_key, name, type, flags, last_path, last_path_len,
last_advert, lat, lon, last_seen, on_radio, last_contacted,
first_seen)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(public_key) DO UPDATE SET
name = COALESCE(excluded.name, contacts.name),
type = CASE WHEN excluded.type = 0 THEN contacts.type ELSE excluded.type END,
flags = excluded.flags,
last_path = COALESCE(excluded.last_path, contacts.last_path),
last_path_len = excluded.last_path_len,
last_advert = COALESCE(excluded.last_advert, contacts.last_advert),
lat = COALESCE(excluded.lat, contacts.lat),
lon = COALESCE(excluded.lon, contacts.lon),
last_seen = excluded.last_seen,
on_radio = COALESCE(excluded.on_radio, contacts.on_radio),
last_contacted = COALESCE(excluded.last_contacted, contacts.last_contacted),
first_seen = COALESCE(contacts.first_seen, excluded.first_seen)
""",
(
contact.get("public_key", "").lower(),
contact.get("name"),
contact.get("type", 0),
contact.get("flags", 0),
contact.get("last_path"),
contact.get("last_path_len", -1),
contact.get("last_advert"),
contact.get("lat"),
contact.get("lon"),
contact.get("last_seen", int(time.time())),
contact.get("on_radio"),
contact.get("last_contacted"),
contact.get("first_seen"),
),
)
await db.conn.commit()
@staticmethod
def _row_to_contact(row) -> Contact:
"""Convert a database row to a Contact model."""
return Contact(
public_key=row["public_key"],
name=row["name"],
type=row["type"],
flags=row["flags"],
last_path=row["last_path"],
last_path_len=row["last_path_len"],
last_advert=row["last_advert"],
lat=row["lat"],
lon=row["lon"],
last_seen=row["last_seen"],
on_radio=bool(row["on_radio"]),
last_contacted=row["last_contacted"],
last_read_at=row["last_read_at"],
first_seen=row["first_seen"],
)
@staticmethod
async def get_by_key(public_key: str) -> Contact | None:
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE public_key = ?", (public_key.lower(),)
)
row = await cursor.fetchone()
return ContactRepository._row_to_contact(row) if row else None
@staticmethod
async def get_by_key_prefix(prefix: str) -> Contact | None:
"""Get a contact by key prefix only if it resolves uniquely.
Returns None when no contacts match OR when multiple contacts match
the prefix (to avoid silently selecting the wrong contact).
"""
normalized_prefix = prefix.lower()
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2",
(f"{normalized_prefix}%",),
)
rows = list(await cursor.fetchall())
if len(rows) != 1:
return None
return ContactRepository._row_to_contact(rows[0])
@staticmethod
async def _get_prefix_matches(prefix: str, limit: int = 2) -> list[Contact]:
"""Get contacts matching a key prefix, up to limit."""
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT ?",
(f"{prefix.lower()}%", limit),
)
rows = list(await cursor.fetchall())
return [ContactRepository._row_to_contact(row) for row in rows]
@staticmethod
async def get_by_key_or_prefix(key_or_prefix: str) -> Contact | None:
"""Get a contact by exact key match, falling back to prefix match.
Useful when the input might be a full 64-char public key or a shorter prefix.
"""
contact = await ContactRepository.get_by_key(key_or_prefix)
if contact:
return contact
matches = await ContactRepository._get_prefix_matches(key_or_prefix, limit=2)
if len(matches) == 1:
return matches[0]
if len(matches) > 1:
raise AmbiguousPublicKeyPrefixError(
key_or_prefix,
[m.public_key for m in matches],
)
return None
@staticmethod
async def get_by_name(name: str) -> list[Contact]:
"""Get all contacts with the given exact name."""
cursor = await db.conn.execute("SELECT * FROM contacts WHERE name = ?", (name,))
rows = await cursor.fetchall()
return [ContactRepository._row_to_contact(row) for row in rows]
@staticmethod
async def resolve_prefixes(prefixes: list[str]) -> dict[str, Contact]:
"""Resolve multiple key prefixes to contacts in a single query.
Returns a dict mapping each prefix to its Contact, only for prefixes
that resolve uniquely (exactly one match). Ambiguous or unmatched
prefixes are omitted.
"""
if not prefixes:
return {}
normalized = [p.lower() for p in prefixes]
conditions = " OR ".join(["public_key LIKE ?"] * len(normalized))
params = [f"{p}%" for p in normalized]
cursor = await db.conn.execute(f"SELECT * FROM contacts WHERE {conditions}", params)
rows = await cursor.fetchall()
# Group by which prefix each row matches
prefix_to_rows: dict[str, list] = {p: [] for p in normalized}
for row in rows:
pk = row["public_key"]
for p in normalized:
if pk.startswith(p):
prefix_to_rows[p].append(row)
# Only include uniquely-resolved prefixes
result: dict[str, Contact] = {}
for p in normalized:
if len(prefix_to_rows[p]) == 1:
result[p] = ContactRepository._row_to_contact(prefix_to_rows[p][0])
return result
@staticmethod
async def get_all(limit: int = 100, offset: int = 0) -> list[Contact]:
cursor = await db.conn.execute(
"SELECT * FROM contacts ORDER BY COALESCE(name, public_key) LIMIT ? OFFSET ?",
(limit, offset),
)
rows = await cursor.fetchall()
return [ContactRepository._row_to_contact(row) for row in rows]
@staticmethod
async def get_recent_non_repeaters(limit: int = 200) -> list[Contact]:
"""Get the most recently active non-repeater contacts.
Orders by most recent activity (last_contacted or last_advert),
excluding repeaters (type=2).
"""
cursor = await db.conn.execute(
"""
SELECT * FROM contacts
WHERE type != 2
ORDER BY COALESCE(last_contacted, 0) DESC, COALESCE(last_advert, 0) DESC
LIMIT ?
""",
(limit,),
)
rows = await cursor.fetchall()
return [ContactRepository._row_to_contact(row) for row in rows]
@staticmethod
async def update_path(public_key: str, path: str, path_len: int) -> None:
await db.conn.execute(
"UPDATE contacts SET last_path = ?, last_path_len = ?, last_seen = ? WHERE public_key = ?",
(path, path_len, int(time.time()), public_key.lower()),
)
await db.conn.commit()
@staticmethod
async def set_on_radio(public_key: str, on_radio: bool) -> None:
await db.conn.execute(
"UPDATE contacts SET on_radio = ? WHERE public_key = ?",
(on_radio, public_key.lower()),
)
await db.conn.commit()
@staticmethod
async def delete(public_key: str) -> None:
normalized = public_key.lower()
await db.conn.execute(
"DELETE FROM contact_name_history WHERE public_key = ?", (normalized,)
)
await db.conn.execute(
"DELETE FROM contact_advert_paths WHERE public_key = ?", (normalized,)
)
await db.conn.execute("DELETE FROM contacts WHERE public_key = ?", (normalized,))
await db.conn.commit()
@staticmethod
async def update_last_contacted(public_key: str, timestamp: int | None = None) -> None:
"""Update the last_contacted timestamp for a contact."""
ts = timestamp if timestamp is not None else int(time.time())
await db.conn.execute(
"UPDATE contacts SET last_contacted = ?, last_seen = ? WHERE public_key = ?",
(ts, ts, public_key.lower()),
)
await db.conn.commit()
@staticmethod
async def update_last_read_at(public_key: str, timestamp: int | None = None) -> bool:
"""Update the last_read_at timestamp for a contact.
Returns True if a row was updated, False if contact not found.
"""
ts = timestamp if timestamp is not None else int(time.time())
cursor = await db.conn.execute(
"UPDATE contacts SET last_read_at = ? WHERE public_key = ?",
(ts, public_key.lower()),
)
await db.conn.commit()
return cursor.rowcount > 0
@staticmethod
async def mark_all_read(timestamp: int) -> None:
"""Mark all contacts as read at the given timestamp."""
await db.conn.execute("UPDATE contacts SET last_read_at = ?", (timestamp,))
await db.conn.commit()
@staticmethod
async def get_by_pubkey_first_byte(hex_byte: str) -> list[Contact]:
"""Get contacts whose public key starts with the given hex byte (2 chars)."""
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE substr(public_key, 1, 2) = ?",
(hex_byte.lower(),),
)
rows = await cursor.fetchall()
return [ContactRepository._row_to_contact(row) for row in rows]
class ContactAdvertPathRepository:
"""Repository for recent unique advertisement paths per contact."""
@staticmethod
def _row_to_path(row) -> ContactAdvertPath:
path = row["path_hex"] or ""
next_hop = path[:2].lower() if len(path) >= 2 else None
return ContactAdvertPath(
path=path,
path_len=row["path_len"],
next_hop=next_hop,
first_seen=row["first_seen"],
last_seen=row["last_seen"],
heard_count=row["heard_count"],
)
@staticmethod
async def record_observation(
public_key: str,
path_hex: str,
timestamp: int,
max_paths: int = 10,
) -> None:
"""
Upsert a unique advert path observation for a contact and prune to N most recent.
"""
if max_paths < 1:
max_paths = 1
normalized_key = public_key.lower()
normalized_path = path_hex.lower()
path_len = len(normalized_path) // 2
await db.conn.execute(
"""
INSERT INTO contact_advert_paths
(public_key, path_hex, path_len, first_seen, last_seen, heard_count)
VALUES (?, ?, ?, ?, ?, 1)
ON CONFLICT(public_key, path_hex) DO UPDATE SET
last_seen = MAX(contact_advert_paths.last_seen, excluded.last_seen),
path_len = excluded.path_len,
heard_count = contact_advert_paths.heard_count + 1
""",
(normalized_key, normalized_path, path_len, timestamp, timestamp),
)
# Keep only the N most recent unique paths per contact.
await db.conn.execute(
"""
DELETE FROM contact_advert_paths
WHERE public_key = ?
AND path_hex NOT IN (
SELECT path_hex
FROM contact_advert_paths
WHERE public_key = ?
ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC
LIMIT ?
)
""",
(normalized_key, normalized_key, max_paths),
)
await db.conn.commit()
@staticmethod
async def get_recent_for_contact(public_key: str, limit: int = 10) -> list[ContactAdvertPath]:
cursor = await db.conn.execute(
"""
SELECT path_hex, path_len, first_seen, last_seen, heard_count
FROM contact_advert_paths
WHERE public_key = ?
ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC
LIMIT ?
""",
(public_key.lower(), limit),
)
rows = await cursor.fetchall()
return [ContactAdvertPathRepository._row_to_path(row) for row in rows]
@staticmethod
async def get_recent_for_all_contacts(
limit_per_contact: int = 10,
) -> list[ContactAdvertPathSummary]:
cursor = await db.conn.execute(
"""
SELECT public_key, path_hex, path_len, first_seen, last_seen, heard_count
FROM contact_advert_paths
ORDER BY public_key ASC, last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC
"""
)
rows = await cursor.fetchall()
grouped: dict[str, list[ContactAdvertPath]] = {}
for row in rows:
key = row["public_key"]
paths = grouped.get(key)
if paths is None:
paths = []
grouped[key] = paths
if len(paths) >= limit_per_contact:
continue
paths.append(ContactAdvertPathRepository._row_to_path(row))
return [
ContactAdvertPathSummary(public_key=key, paths=paths) for key, paths in grouped.items()
]
class ContactNameHistoryRepository:
"""Repository for contact name change history."""
@staticmethod
async def record_name(public_key: str, name: str, timestamp: int) -> None:
"""Record a name observation. Upserts: updates last_seen if name already known."""
await db.conn.execute(
"""
INSERT INTO contact_name_history (public_key, name, first_seen, last_seen)
VALUES (?, ?, ?, ?)
ON CONFLICT(public_key, name) DO UPDATE SET
last_seen = MAX(contact_name_history.last_seen, excluded.last_seen)
""",
(public_key.lower(), name, timestamp, timestamp),
)
await db.conn.commit()
@staticmethod
async def get_history(public_key: str) -> list[ContactNameHistory]:
cursor = await db.conn.execute(
"""
SELECT name, first_seen, last_seen
FROM contact_name_history
WHERE public_key = ?
ORDER BY last_seen DESC
""",
(public_key.lower(),),
)
rows = await cursor.fetchall()
return [
ContactNameHistory(
name=row["name"], first_seen=row["first_seen"], last_seen=row["last_seen"]
)
for row in rows
]

411
app/repository/messages.py Normal file
View File

@@ -0,0 +1,411 @@
import json
import time
from typing import Any
from app.database import db
from app.models import Message, MessagePath
class MessageRepository:
@staticmethod
def _parse_paths(paths_json: str | None) -> list[MessagePath] | None:
"""Parse paths JSON string to list of MessagePath objects."""
if not paths_json:
return None
try:
paths_data = json.loads(paths_json)
return [MessagePath(**p) for p in paths_data]
except (json.JSONDecodeError, TypeError, KeyError):
return None
@staticmethod
async def create(
msg_type: str,
text: str,
received_at: int,
conversation_key: str,
sender_timestamp: int | None = None,
path: str | None = None,
txt_type: int = 0,
signature: str | None = None,
outgoing: bool = False,
sender_name: str | None = None,
sender_key: str | None = None,
) -> int | None:
"""Create a message, returning the ID or None if duplicate.
Uses INSERT OR IGNORE to handle the UNIQUE constraint on
(type, conversation_key, text, sender_timestamp). This prevents
duplicate messages when the same message arrives via multiple RF paths.
The path parameter is converted to the paths JSON array format.
"""
# Convert single path to paths array format
paths_json = None
if path is not None:
paths_json = json.dumps([{"path": path, "received_at": received_at}])
cursor = await db.conn.execute(
"""
INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp,
received_at, paths, txt_type, signature, outgoing,
sender_name, sender_key)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
msg_type,
conversation_key,
text,
sender_timestamp,
received_at,
paths_json,
txt_type,
signature,
outgoing,
sender_name,
sender_key,
),
)
await db.conn.commit()
# rowcount is 0 if INSERT was ignored due to UNIQUE constraint violation
if cursor.rowcount == 0:
return None
return cursor.lastrowid
@staticmethod
async def add_path(
message_id: int, path: str, received_at: int | None = None
) -> list[MessagePath]:
"""Add a new path to an existing message.
This is used when a repeat/echo of a message arrives via a different route.
Returns the updated list of paths.
"""
ts = received_at if received_at is not None else int(time.time())
# Atomic append: use json_insert to avoid read-modify-write race when
# multiple duplicate packets arrive concurrently for the same message.
new_entry = json.dumps({"path": path, "received_at": ts})
await db.conn.execute(
"""UPDATE messages SET paths = json_insert(
COALESCE(paths, '[]'), '$[#]', json(?)
) WHERE id = ?""",
(new_entry, message_id),
)
await db.conn.commit()
# Read back the full list for the return value
cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
row = await cursor.fetchone()
if not row or not row["paths"]:
return []
try:
all_paths = json.loads(row["paths"])
except json.JSONDecodeError:
return []
return [MessagePath(**p) for p in all_paths]
@staticmethod
async def claim_prefix_messages(full_key: str) -> int:
"""Promote prefix-stored messages to the full conversation key.
When a full key becomes known for a contact, any messages stored with
only a prefix as conversation_key are updated to use the full key.
"""
lower_key = full_key.lower()
cursor = await db.conn.execute(
"""UPDATE messages SET conversation_key = ?
WHERE type = 'PRIV' AND length(conversation_key) < 64
AND ? LIKE conversation_key || '%'
AND (
SELECT COUNT(*) FROM contacts
WHERE public_key LIKE messages.conversation_key || '%'
) = 1""",
(lower_key, lower_key),
)
await db.conn.commit()
return cursor.rowcount
@staticmethod
async def get_all(
limit: int = 100,
offset: int = 0,
msg_type: str | None = None,
conversation_key: str | None = None,
before: int | None = None,
before_id: int | None = None,
) -> list[Message]:
query = "SELECT * FROM messages WHERE 1=1"
params: list[Any] = []
if msg_type:
query += " AND type = ?"
params.append(msg_type)
if conversation_key:
normalized_key = conversation_key
# Prefer exact matching for full keys.
if len(conversation_key) == 64:
normalized_key = conversation_key.lower()
query += " AND conversation_key = ?"
params.append(normalized_key)
elif len(conversation_key) == 32:
normalized_key = conversation_key.upper()
query += " AND conversation_key = ?"
params.append(normalized_key)
else:
# Prefix match is only for legacy/partial key callers.
query += " AND conversation_key LIKE ?"
params.append(f"{conversation_key}%")
if before is not None and before_id is not None:
query += " AND (received_at < ? OR (received_at = ? AND id < ?))"
params.extend([before, before, before_id])
query += " ORDER BY received_at DESC, id DESC LIMIT ?"
params.append(limit)
if before is None or before_id is None:
query += " OFFSET ?"
params.append(offset)
cursor = await db.conn.execute(query, params)
rows = await cursor.fetchall()
return [
Message(
id=row["id"],
type=row["type"],
conversation_key=row["conversation_key"],
text=row["text"],
sender_timestamp=row["sender_timestamp"],
received_at=row["received_at"],
paths=MessageRepository._parse_paths(row["paths"]),
txt_type=row["txt_type"],
signature=row["signature"],
outgoing=bool(row["outgoing"]),
acked=row["acked"],
)
for row in rows
]
@staticmethod
async def increment_ack_count(message_id: int) -> int:
"""Increment ack count and return the new value."""
await db.conn.execute("UPDATE messages SET acked = acked + 1 WHERE id = ?", (message_id,))
await db.conn.commit()
cursor = await db.conn.execute("SELECT acked FROM messages WHERE id = ?", (message_id,))
row = await cursor.fetchone()
return row["acked"] if row else 1
@staticmethod
async def get_ack_and_paths(message_id: int) -> tuple[int, list[MessagePath] | None]:
"""Get the current ack count and paths for a message."""
cursor = await db.conn.execute(
"SELECT acked, paths FROM messages WHERE id = ?", (message_id,)
)
row = await cursor.fetchone()
if not row:
return 0, None
return row["acked"], MessageRepository._parse_paths(row["paths"])
@staticmethod
async def get_by_id(message_id: int) -> "Message | None":
"""Look up a message by its ID."""
cursor = await db.conn.execute(
"""
SELECT id, type, conversation_key, text, sender_timestamp, received_at,
paths, txt_type, signature, outgoing, acked
FROM messages
WHERE id = ?
""",
(message_id,),
)
row = await cursor.fetchone()
if not row:
return None
return Message(
id=row["id"],
type=row["type"],
conversation_key=row["conversation_key"],
text=row["text"],
sender_timestamp=row["sender_timestamp"],
received_at=row["received_at"],
paths=MessageRepository._parse_paths(row["paths"]),
txt_type=row["txt_type"],
signature=row["signature"],
outgoing=bool(row["outgoing"]),
acked=row["acked"],
)
@staticmethod
async def get_by_content(
msg_type: str,
conversation_key: str,
text: str,
sender_timestamp: int | None,
) -> "Message | None":
"""Look up a message by its unique content fields."""
cursor = await db.conn.execute(
"""
SELECT id, type, conversation_key, text, sender_timestamp, received_at,
paths, txt_type, signature, outgoing, acked
FROM messages
WHERE type = ? AND conversation_key = ? AND text = ?
AND (sender_timestamp = ? OR (sender_timestamp IS NULL AND ? IS NULL))
""",
(msg_type, conversation_key, text, sender_timestamp, sender_timestamp),
)
row = await cursor.fetchone()
if not row:
return None
paths = None
if row["paths"]:
try:
paths_data = json.loads(row["paths"])
paths = [
MessagePath(path=p["path"], received_at=p["received_at"]) for p in paths_data
]
except (json.JSONDecodeError, KeyError):
pass
return Message(
id=row["id"],
type=row["type"],
conversation_key=row["conversation_key"],
text=row["text"],
sender_timestamp=row["sender_timestamp"],
received_at=row["received_at"],
paths=paths,
txt_type=row["txt_type"],
signature=row["signature"],
outgoing=bool(row["outgoing"]),
acked=row["acked"],
)
@staticmethod
async def get_unread_counts(name: str | None = None) -> dict:
"""Get unread message counts, mention flags, and last message times for all conversations.
Args:
name: User's display name for @[name] mention detection. If None, mentions are skipped.
Returns:
Dict with 'counts', 'mentions', and 'last_message_times' keys.
"""
counts: dict[str, int] = {}
mention_flags: dict[str, bool] = {}
last_message_times: dict[str, int] = {}
mention_token = f"@[{name}]" if name else None
# Channel unreads
cursor = await db.conn.execute(
"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
SUM(CASE
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
ELSE 0
END) > 0 as has_mention
FROM messages m
JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.outgoing = 0
AND m.received_at > COALESCE(c.last_read_at, 0)
GROUP BY m.conversation_key
""",
(mention_token or "", mention_token or ""),
)
rows = await cursor.fetchall()
for row in rows:
state_key = f"channel-{row['conversation_key']}"
counts[state_key] = row["unread_count"]
if mention_token and row["has_mention"]:
mention_flags[state_key] = True
# Contact unreads
cursor = await db.conn.execute(
"""
SELECT m.conversation_key,
COUNT(*) as unread_count,
SUM(CASE
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
ELSE 0
END) > 0 as has_mention
FROM messages m
JOIN contacts ct ON m.conversation_key = ct.public_key
WHERE m.type = 'PRIV' AND m.outgoing = 0
AND m.received_at > COALESCE(ct.last_read_at, 0)
GROUP BY m.conversation_key
""",
(mention_token or "", mention_token or ""),
)
rows = await cursor.fetchall()
for row in rows:
state_key = f"contact-{row['conversation_key']}"
counts[state_key] = row["unread_count"]
if mention_token and row["has_mention"]:
mention_flags[state_key] = True
# Last message times for all conversations (including read ones)
cursor = await db.conn.execute(
"""
SELECT type, conversation_key, MAX(received_at) as last_message_time
FROM messages
GROUP BY type, conversation_key
"""
)
rows = await cursor.fetchall()
for row in rows:
prefix = "channel" if row["type"] == "CHAN" else "contact"
state_key = f"{prefix}-{row['conversation_key']}"
last_message_times[state_key] = row["last_message_time"]
return {
"counts": counts,
"mentions": mention_flags,
"last_message_times": last_message_times,
}
@staticmethod
async def count_dm_messages(contact_key: str) -> int:
"""Count total DM messages for a contact."""
cursor = await db.conn.execute(
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?",
(contact_key.lower(),),
)
row = await cursor.fetchone()
return row["cnt"] if row else 0
@staticmethod
async def count_channel_messages_by_sender(sender_key: str) -> int:
"""Count channel messages sent by a specific contact."""
cursor = await db.conn.execute(
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?",
(sender_key.lower(),),
)
row = await cursor.fetchone()
return row["cnt"] if row else 0
@staticmethod
async def get_most_active_rooms(sender_key: str, limit: int = 5) -> list[tuple[str, str, int]]:
"""Get channels where a contact has sent the most messages.
Returns list of (channel_key, channel_name, message_count) tuples.
"""
cursor = await db.conn.execute(
"""
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
COUNT(*) AS cnt
FROM messages m
LEFT JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.sender_key = ?
GROUP BY m.conversation_key
ORDER BY cnt DESC
LIMIT ?
""",
(sender_key.lower(), limit),
)
rows = await cursor.fetchall()
return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows]

View File

@@ -0,0 +1,150 @@
import logging
import sqlite3
import time
from hashlib import sha256
from app.database import db
from app.decoder import PayloadType, extract_payload, get_packet_payload_type
logger = logging.getLogger(__name__)
class RawPacketRepository:
@staticmethod
async def create(data: bytes, timestamp: int | None = None) -> tuple[int, bool]:
"""
Create a raw packet with payload-based deduplication.
Returns (packet_id, is_new) tuple:
- is_new=True: New packet stored, packet_id is the new row ID
- is_new=False: Duplicate payload detected, packet_id is the existing row ID
Deduplication is based on the SHA-256 hash of the packet payload
(excluding routing/path information).
"""
ts = timestamp if timestamp is not None else int(time.time())
# Compute payload hash for deduplication
payload = extract_payload(data)
if payload:
payload_hash = sha256(payload).digest()
else:
# For malformed packets, hash the full data
payload_hash = sha256(data).digest()
# Check if this payload already exists
cursor = await db.conn.execute(
"SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,)
)
existing = await cursor.fetchone()
if existing:
# Duplicate - return existing packet ID
logger.debug(
"Duplicate payload detected (hash=%s..., existing_id=%d)",
payload_hash.hex()[:12],
existing["id"],
)
return (existing["id"], False)
# New packet - insert with hash
try:
cursor = await db.conn.execute(
"INSERT INTO raw_packets (timestamp, data, payload_hash) VALUES (?, ?, ?)",
(ts, data, payload_hash),
)
await db.conn.commit()
assert cursor.lastrowid is not None # INSERT always returns a row ID
return (cursor.lastrowid, True)
except sqlite3.IntegrityError:
# Race condition: another insert with same payload_hash happened between
# our SELECT and INSERT. This is expected for duplicate packets arriving
# close together. Query again to get the existing ID.
logger.debug(
"Duplicate packet detected via race condition (payload_hash=%s), dropping",
payload_hash.hex()[:16],
)
cursor = await db.conn.execute(
"SELECT id FROM raw_packets WHERE payload_hash = ?", (payload_hash,)
)
existing = await cursor.fetchone()
if existing:
return (existing["id"], False)
# This shouldn't happen, but if it does, re-raise
raise
@staticmethod
async def get_undecrypted_count() -> int:
"""Get count of undecrypted packets (those without a linked message)."""
cursor = await db.conn.execute(
"SELECT COUNT(*) as count FROM raw_packets WHERE message_id IS NULL"
)
row = await cursor.fetchone()
return row["count"] if row else 0
@staticmethod
async def get_oldest_undecrypted() -> int | None:
"""Get timestamp of oldest undecrypted packet, or None if none exist."""
cursor = await db.conn.execute(
"SELECT MIN(timestamp) as oldest FROM raw_packets WHERE message_id IS NULL"
)
row = await cursor.fetchone()
return row["oldest"] if row and row["oldest"] is not None else None
@staticmethod
async def get_all_undecrypted() -> list[tuple[int, bytes, int]]:
"""Get all undecrypted packets as (id, data, timestamp) tuples."""
cursor = await db.conn.execute(
"SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC"
)
rows = await cursor.fetchall()
return [(row["id"], bytes(row["data"]), row["timestamp"]) for row in rows]
@staticmethod
async def mark_decrypted(packet_id: int, message_id: int) -> None:
"""Link a raw packet to its decrypted message."""
await db.conn.execute(
"UPDATE raw_packets SET message_id = ? WHERE id = ?",
(message_id, packet_id),
)
await db.conn.commit()
@staticmethod
async def prune_old_undecrypted(max_age_days: int) -> int:
"""Delete undecrypted packets older than max_age_days. Returns count deleted."""
cutoff = int(time.time()) - (max_age_days * 86400)
cursor = await db.conn.execute(
"DELETE FROM raw_packets WHERE message_id IS NULL AND timestamp < ?",
(cutoff,),
)
await db.conn.commit()
return cursor.rowcount
@staticmethod
async def purge_linked_to_messages() -> int:
"""Delete raw packets that are already linked to a stored message."""
cursor = await db.conn.execute("DELETE FROM raw_packets WHERE message_id IS NOT NULL")
await db.conn.commit()
return cursor.rowcount
@staticmethod
async def get_undecrypted_text_messages() -> list[tuple[int, bytes, int]]:
"""Get all undecrypted TEXT_MESSAGE packets as (id, data, timestamp) tuples.
Filters raw packets to only include those with PayloadType.TEXT_MESSAGE (0x02).
These are direct messages that can be decrypted with contact ECDH keys.
"""
cursor = await db.conn.execute(
"SELECT id, data, timestamp FROM raw_packets WHERE message_id IS NULL ORDER BY timestamp ASC"
)
rows = await cursor.fetchall()
# Filter for TEXT_MESSAGE packets
result = []
for row in rows:
data = bytes(row["data"])
payload_type = get_packet_payload_type(data)
if payload_type == PayloadType.TEXT_MESSAGE:
result.append((row["id"], data, row["timestamp"]))
return result

330
app/repository/settings.py Normal file
View File

@@ -0,0 +1,330 @@
import json
import logging
import time
from typing import Any, Literal
from app.database import db
from app.models import AppSettings, BotConfig, Favorite
logger = logging.getLogger(__name__)
SECONDS_1H = 3600
SECONDS_24H = 86400
SECONDS_7D = 604800
class AppSettingsRepository:
"""Repository for app_settings table (single-row pattern)."""
@staticmethod
async def get() -> AppSettings:
"""Get the current app settings.
Always returns settings - creates default row if needed (migration handles initial row).
"""
cursor = await db.conn.execute(
"""
SELECT max_radio_contacts, favorites, auto_decrypt_dm_on_advert,
sidebar_sort_order, last_message_times, preferences_migrated,
advert_interval, last_advert_time, bots
FROM app_settings WHERE id = 1
"""
)
row = await cursor.fetchone()
if not row:
# Should not happen after migration, but handle gracefully
return AppSettings()
# Parse favorites JSON
favorites = []
if row["favorites"]:
try:
favorites_data = json.loads(row["favorites"])
favorites = [Favorite(**f) for f in favorites_data]
except (json.JSONDecodeError, TypeError, KeyError) as e:
logger.warning(
"Failed to parse favorites JSON, using empty list: %s (data=%r)",
e,
row["favorites"][:100] if row["favorites"] else None,
)
favorites = []
# Parse last_message_times JSON
last_message_times: dict[str, int] = {}
if row["last_message_times"]:
try:
last_message_times = json.loads(row["last_message_times"])
except (json.JSONDecodeError, TypeError) as e:
logger.warning(
"Failed to parse last_message_times JSON, using empty dict: %s",
e,
)
last_message_times = {}
# Parse bots JSON
bots: list[BotConfig] = []
if row["bots"]:
try:
bots_data = json.loads(row["bots"])
bots = [BotConfig(**b) for b in bots_data]
except (json.JSONDecodeError, TypeError, KeyError) as e:
logger.warning(
"Failed to parse bots JSON, using empty list: %s (data=%r)",
e,
row["bots"][:100] if row["bots"] else None,
)
bots = []
# Validate sidebar_sort_order (fallback to "recent" if invalid)
sort_order = row["sidebar_sort_order"]
if sort_order not in ("recent", "alpha"):
sort_order = "recent"
return AppSettings(
max_radio_contacts=row["max_radio_contacts"],
favorites=favorites,
auto_decrypt_dm_on_advert=bool(row["auto_decrypt_dm_on_advert"]),
sidebar_sort_order=sort_order,
last_message_times=last_message_times,
preferences_migrated=bool(row["preferences_migrated"]),
advert_interval=row["advert_interval"] or 0,
last_advert_time=row["last_advert_time"] or 0,
bots=bots,
)
@staticmethod
async def update(
max_radio_contacts: int | None = None,
favorites: list[Favorite] | None = None,
auto_decrypt_dm_on_advert: bool | None = None,
sidebar_sort_order: str | None = None,
last_message_times: dict[str, int] | None = None,
preferences_migrated: bool | None = None,
advert_interval: int | None = None,
last_advert_time: int | None = None,
bots: list[BotConfig] | None = None,
) -> AppSettings:
"""Update app settings. Only provided fields are updated."""
updates = []
params: list[Any] = []
if max_radio_contacts is not None:
updates.append("max_radio_contacts = ?")
params.append(max_radio_contacts)
if favorites is not None:
updates.append("favorites = ?")
favorites_json = json.dumps([f.model_dump() for f in favorites])
params.append(favorites_json)
if auto_decrypt_dm_on_advert is not None:
updates.append("auto_decrypt_dm_on_advert = ?")
params.append(1 if auto_decrypt_dm_on_advert else 0)
if sidebar_sort_order is not None:
updates.append("sidebar_sort_order = ?")
params.append(sidebar_sort_order)
if last_message_times is not None:
updates.append("last_message_times = ?")
params.append(json.dumps(last_message_times))
if preferences_migrated is not None:
updates.append("preferences_migrated = ?")
params.append(1 if preferences_migrated else 0)
if advert_interval is not None:
updates.append("advert_interval = ?")
params.append(advert_interval)
if last_advert_time is not None:
updates.append("last_advert_time = ?")
params.append(last_advert_time)
if bots is not None:
updates.append("bots = ?")
bots_json = json.dumps([b.model_dump() for b in bots])
params.append(bots_json)
if updates:
query = f"UPDATE app_settings SET {', '.join(updates)} WHERE id = 1"
await db.conn.execute(query, params)
await db.conn.commit()
return await AppSettingsRepository.get()
@staticmethod
async def add_favorite(fav_type: Literal["channel", "contact"], fav_id: str) -> AppSettings:
"""Add a favorite, avoiding duplicates."""
settings = await AppSettingsRepository.get()
# Check if already favorited
if any(f.type == fav_type and f.id == fav_id for f in settings.favorites):
return settings
new_favorites = settings.favorites + [Favorite(type=fav_type, id=fav_id)]
return await AppSettingsRepository.update(favorites=new_favorites)
@staticmethod
async def remove_favorite(fav_type: Literal["channel", "contact"], fav_id: str) -> AppSettings:
"""Remove a favorite."""
settings = await AppSettingsRepository.get()
new_favorites = [
f for f in settings.favorites if not (f.type == fav_type and f.id == fav_id)
]
return await AppSettingsRepository.update(favorites=new_favorites)
@staticmethod
async def migrate_preferences_from_frontend(
favorites: list[dict],
sort_order: str,
last_message_times: dict[str, int],
) -> tuple[AppSettings, bool]:
"""Migrate all preferences from frontend localStorage.
This is a one-time migration. If already migrated, returns current settings
without overwriting. Returns (settings, did_migrate) tuple.
"""
settings = await AppSettingsRepository.get()
if settings.preferences_migrated:
# Already migrated, don't overwrite
return settings, False
# Convert frontend favorites format to Favorite objects
new_favorites = []
for f in favorites:
if f.get("type") in ("channel", "contact") and f.get("id"):
new_favorites.append(Favorite(type=f["type"], id=f["id"]))
# Update with migrated preferences and mark as migrated
settings = await AppSettingsRepository.update(
favorites=new_favorites,
sidebar_sort_order=sort_order if sort_order in ("recent", "alpha") else "recent",
last_message_times=last_message_times,
preferences_migrated=True,
)
return settings, True
class StatisticsRepository:
@staticmethod
async def _activity_counts(*, contact_type: int, exclude: bool = False) -> dict[str, int]:
"""Get time-windowed counts for contacts/repeaters heard."""
now = int(time.time())
op = "!=" if exclude else "="
cursor = await db.conn.execute(
f"""
SELECT
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_hour,
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_24_hours,
SUM(CASE WHEN last_seen >= ? THEN 1 ELSE 0 END) AS last_week
FROM contacts
WHERE type {op} ? AND last_seen IS NOT NULL
""",
(now - SECONDS_1H, now - SECONDS_24H, now - SECONDS_7D, contact_type),
)
row = await cursor.fetchone()
assert row is not None # Aggregate query always returns a row
return {
"last_hour": row["last_hour"] or 0,
"last_24_hours": row["last_24_hours"] or 0,
"last_week": row["last_week"] or 0,
}
@staticmethod
async def get_all() -> dict:
"""Aggregate all statistics from existing tables."""
now = int(time.time())
# Top 5 busiest channels in last 24h
cursor = await db.conn.execute(
"""
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
COUNT(*) AS message_count
FROM messages m
LEFT JOIN channels c ON m.conversation_key = c.key
WHERE m.type = 'CHAN' AND m.received_at >= ?
GROUP BY m.conversation_key
ORDER BY COUNT(*) DESC
LIMIT 5
""",
(now - SECONDS_24H,),
)
rows = await cursor.fetchall()
busiest_channels_24h = [
{
"channel_key": row["conversation_key"],
"channel_name": row["channel_name"],
"message_count": row["message_count"],
}
for row in rows
]
# Entity counts
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type != 2")
row = await cursor.fetchone()
assert row is not None
contact_count: int = row["cnt"]
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM contacts WHERE type = 2")
row = await cursor.fetchone()
assert row is not None
repeater_count: int = row["cnt"]
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM channels")
row = await cursor.fetchone()
assert row is not None
channel_count: int = row["cnt"]
# Packet split
cursor = await db.conn.execute(
"""
SELECT COUNT(*) AS total,
SUM(CASE WHEN message_id IS NOT NULL THEN 1 ELSE 0 END) AS decrypted
FROM raw_packets
"""
)
pkt_row = await cursor.fetchone()
assert pkt_row is not None
total_packets = pkt_row["total"] or 0
decrypted_packets = pkt_row["decrypted"] or 0
undecrypted_packets = total_packets - decrypted_packets
# Message type counts
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'PRIV'")
row = await cursor.fetchone()
assert row is not None
total_dms: int = row["cnt"]
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE type = 'CHAN'")
row = await cursor.fetchone()
assert row is not None
total_channel_messages: int = row["cnt"]
# Outgoing count
cursor = await db.conn.execute("SELECT COUNT(*) AS cnt FROM messages WHERE outgoing = 1")
row = await cursor.fetchone()
assert row is not None
total_outgoing: int = row["cnt"]
# Activity windows
contacts_heard = await StatisticsRepository._activity_counts(contact_type=2, exclude=True)
repeaters_heard = await StatisticsRepository._activity_counts(contact_type=2)
return {
"busiest_channels_24h": busiest_channels_24h,
"contact_count": contact_count,
"repeater_count": repeater_count,
"channel_count": channel_count,
"total_packets": total_packets,
"decrypted_packets": decrypted_packets,
"undecrypted_packets": undecrypted_packets,
"total_dms": total_dms,
"total_channel_messages": total_channel_messages,
"total_outgoing": total_outgoing,
"contacts_heard": contacts_heard,
"repeaters_heard": repeaters_heard,
}

View File

@@ -1,36 +1,18 @@
import asyncio
import logging
import random
import time
from typing import TYPE_CHECKING
from fastapi import APIRouter, BackgroundTasks, HTTPException, Query
from meshcore import EventType
from app.dependencies import require_connected
from app.models import (
CONTACT_TYPE_REPEATER,
AclEntry,
CommandRequest,
CommandResponse,
Contact,
ContactActiveRoom,
ContactAdvertPath,
ContactAdvertPathSummary,
ContactDetail,
CreateContactRequest,
LppSensor,
NearestRepeater,
NeighborInfo,
RepeaterAclResponse,
RepeaterAdvertIntervalsResponse,
RepeaterLoginRequest,
RepeaterLoginResponse,
RepeaterLppTelemetryResponse,
RepeaterNeighborsResponse,
RepeaterOwnerInfoResponse,
RepeaterRadioSettingsResponse,
RepeaterStatusResponse,
TraceResponse,
)
from app.packet_processor import start_historical_dm_decryption
@@ -43,111 +25,10 @@ from app.repository import (
MessageRepository,
)
if TYPE_CHECKING:
from meshcore.events import Event
logger = logging.getLogger(__name__)
# ACL permission level names
ACL_PERMISSION_NAMES = {
0: "Guest",
1: "Read-only",
2: "Read-write",
3: "Admin",
}
router = APIRouter(prefix="/contacts", tags=["contacts"])
# Delay between repeater radio operations to allow key exchange and path establishment
REPEATER_OP_DELAY_SECONDS = 2.0
def _monotonic() -> float:
"""Wrapper around time.monotonic() for testability.
Patching time.monotonic directly breaks the asyncio event loop which also
uses it. This indirection allows tests to control the clock safely.
"""
return time.monotonic()
def _extract_response_text(event) -> str:
"""Extract text from a CLI response event, stripping the firmware '> ' prefix."""
text = event.payload.get("text", str(event.payload))
if text.startswith("> "):
text = text[2:]
return text
async def _fetch_repeater_response(
mc,
target_pubkey_prefix: str,
timeout: float = 20.0,
) -> "Event | None":
"""Fetch a CLI response from a specific repeater via a validated get_msg() loop.
Calls get_msg() repeatedly until a matching CLI response (txt_type=1) from the
target repeater arrives or the wall-clock deadline expires. Unrelated messages
are safe to skip — meshcore's event dispatcher already delivers them to the
normal subscription handlers (on_contact_message, etc.) when get_msg() returns.
Args:
mc: MeshCore instance
target_pubkey_prefix: 12-char hex prefix of the repeater's public key
timeout: Wall-clock seconds before giving up
Returns:
The matching Event, or None if no response arrived before the deadline.
"""
deadline = _monotonic() + timeout
while _monotonic() < deadline:
try:
result = await mc.commands.get_msg(timeout=2.0)
except asyncio.TimeoutError:
continue
except Exception as e:
logger.debug("get_msg() exception: %s", e)
await asyncio.sleep(1.0)
continue
if result.type == EventType.NO_MORE_MSGS:
# No messages queued yet — wait and retry
await asyncio.sleep(1.0)
continue
if result.type == EventType.ERROR:
logger.debug("get_msg() error: %s", result.payload)
await asyncio.sleep(1.0)
continue
if result.type == EventType.CONTACT_MSG_RECV:
msg_prefix = result.payload.get("pubkey_prefix", "")
txt_type = result.payload.get("txt_type", 0)
if msg_prefix == target_pubkey_prefix and txt_type == 1:
return result
# Not our target — already dispatched to subscribers by meshcore,
# so just continue draining the queue.
logger.debug(
"Skipping non-target message (from=%s, txt_type=%d) while waiting for %s",
msg_prefix,
txt_type,
target_pubkey_prefix,
)
continue
if result.type == EventType.CHANNEL_MSG_RECV:
# Already dispatched to subscribers by meshcore; skip.
logger.debug(
"Skipping channel message (channel_idx=%s) during repeater fetch",
result.payload.get("channel_idx"),
)
continue
logger.debug("Unexpected event type %s during repeater fetch, skipping", result.type)
logger.warning("No CLI response from repeater %s within %.1fs", target_pubkey_prefix, timeout)
return None
def _ambiguous_contact_detail(err: AmbiguousPublicKeyPrefixError) -> str:
sample = ", ".join(key[:12] for key in err.matches[:2])
@@ -169,42 +50,6 @@ async def _resolve_contact_or_404(
return contact
async def prepare_repeater_connection(mc, contact: Contact, password: str) -> None:
"""Prepare connection to a repeater by adding to radio and logging in.
Args:
mc: MeshCore instance
contact: The repeater contact
password: Password for login (empty string for no password)
Raises:
HTTPException: If login fails
"""
# Add contact to radio with path from DB (non-fatal — contact may already be loaded)
logger.info("Adding repeater %s to radio", contact.public_key[:12])
await _ensure_on_radio(mc, contact)
# Send login with password
logger.info("Sending login to repeater %s", contact.public_key[:12])
login_result = await mc.commands.send_login(contact.public_key, password)
if login_result.type == EventType.ERROR:
raise HTTPException(status_code=401, detail=f"Login failed: {login_result.payload}")
# Wait for key exchange to complete before sending requests
logger.debug("Waiting %.1fs for key exchange to complete", REPEATER_OP_DELAY_SECONDS)
await asyncio.sleep(REPEATER_OP_DELAY_SECONDS)
def _require_repeater(contact: Contact) -> None:
"""Raise 400 if contact is not a repeater."""
if contact.type != CONTACT_TYPE_REPEATER:
raise HTTPException(
status_code=400,
detail=f"Contact is not a repeater (type={contact.type}, expected {CONTACT_TYPE_REPEATER})",
)
async def _ensure_on_radio(mc, contact: Contact) -> None:
"""Add a contact to the radio for routing, raising 500 on failure."""
add_result = await mc.commands.add_contact(contact.to_radio_dict())
@@ -214,272 +59,6 @@ async def _ensure_on_radio(mc, contact: Contact) -> None:
)
# ---------------------------------------------------------------------------
# Granular repeater endpoints — one attempt, no server-side retries.
# Frontend manages retry logic for better UX control.
# ---------------------------------------------------------------------------
@router.post("/{public_key}/repeater/login", response_model=RepeaterLoginResponse)
async def repeater_login(public_key: str, request: RepeaterLoginRequest) -> RepeaterLoginResponse:
"""Log in to a repeater. Adds contact to radio, sends login, waits for key exchange."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_login",
pause_polling=True,
suspend_auto_fetch=True,
) as mc:
await prepare_repeater_connection(mc, contact, request.password)
return RepeaterLoginResponse(status="ok")
@router.post("/{public_key}/repeater/status", response_model=RepeaterStatusResponse)
async def repeater_status(public_key: str) -> RepeaterStatusResponse:
"""Fetch status telemetry from a repeater (single attempt, 10s timeout)."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_status", pause_polling=True, suspend_auto_fetch=True
) as mc:
# Ensure contact is on radio for routing
await _ensure_on_radio(mc, contact)
status = await mc.commands.req_status_sync(contact.public_key, timeout=10, min_timeout=5)
if status is None:
raise HTTPException(status_code=504, detail="No status response from repeater")
return RepeaterStatusResponse(
battery_volts=status.get("bat", 0) / 1000.0,
tx_queue_len=status.get("tx_queue_len", 0),
noise_floor_dbm=status.get("noise_floor", 0),
last_rssi_dbm=status.get("last_rssi", 0),
last_snr_db=status.get("last_snr", 0.0),
packets_received=status.get("nb_recv", 0),
packets_sent=status.get("nb_sent", 0),
airtime_seconds=status.get("airtime", 0),
rx_airtime_seconds=status.get("rx_airtime", 0),
uptime_seconds=status.get("uptime", 0),
sent_flood=status.get("sent_flood", 0),
sent_direct=status.get("sent_direct", 0),
recv_flood=status.get("recv_flood", 0),
recv_direct=status.get("recv_direct", 0),
flood_dups=status.get("flood_dups", 0),
direct_dups=status.get("direct_dups", 0),
full_events=status.get("full_evts", 0),
)
@router.post("/{public_key}/repeater/lpp-telemetry", response_model=RepeaterLppTelemetryResponse)
async def repeater_lpp_telemetry(public_key: str) -> RepeaterLppTelemetryResponse:
"""Fetch CayenneLPP sensor telemetry from a repeater (single attempt, 10s timeout)."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_lpp_telemetry", pause_polling=True, suspend_auto_fetch=True
) as mc:
await _ensure_on_radio(mc, contact)
telemetry = await mc.commands.req_telemetry_sync(
contact.public_key, timeout=10, min_timeout=5
)
if telemetry is None:
raise HTTPException(status_code=504, detail="No telemetry response from repeater")
sensors: list[LppSensor] = []
for entry in telemetry:
channel = entry.get("channel", 0)
type_name = str(entry.get("type", "unknown"))
value = entry.get("value", 0)
sensors.append(LppSensor(channel=channel, type_name=type_name, value=value))
return RepeaterLppTelemetryResponse(sensors=sensors)
@router.post("/{public_key}/repeater/neighbors", response_model=RepeaterNeighborsResponse)
async def repeater_neighbors(public_key: str) -> RepeaterNeighborsResponse:
"""Fetch neighbors from a repeater (single attempt, 10s timeout)."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_neighbors", pause_polling=True, suspend_auto_fetch=True
) as mc:
# Ensure contact is on radio for routing
await _ensure_on_radio(mc, contact)
neighbors_data = await mc.commands.fetch_all_neighbours(
contact.public_key, timeout=10, min_timeout=5
)
neighbors: list[NeighborInfo] = []
if neighbors_data and "neighbours" in neighbors_data:
for n in neighbors_data["neighbours"]:
pubkey_prefix = n.get("pubkey", "")
resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix)
neighbors.append(
NeighborInfo(
pubkey_prefix=pubkey_prefix,
name=resolved_contact.name if resolved_contact else None,
snr=n.get("snr", 0.0),
last_heard_seconds=n.get("secs_ago", 0),
)
)
return RepeaterNeighborsResponse(neighbors=neighbors)
@router.post("/{public_key}/repeater/acl", response_model=RepeaterAclResponse)
async def repeater_acl(public_key: str) -> RepeaterAclResponse:
"""Fetch ACL from a repeater (single attempt, 10s timeout)."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_acl", pause_polling=True, suspend_auto_fetch=True
) as mc:
# Ensure contact is on radio for routing
await _ensure_on_radio(mc, contact)
acl_data = await mc.commands.req_acl_sync(contact.public_key, timeout=10, min_timeout=5)
acl_entries: list[AclEntry] = []
if acl_data and isinstance(acl_data, list):
for entry in acl_data:
pubkey_prefix = entry.get("key", "")
perm = entry.get("perm", 0)
resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix)
acl_entries.append(
AclEntry(
pubkey_prefix=pubkey_prefix,
name=resolved_contact.name if resolved_contact else None,
permission=perm,
permission_name=ACL_PERMISSION_NAMES.get(perm, f"Unknown({perm})"),
)
)
return RepeaterAclResponse(acl=acl_entries)
async def _batch_cli_fetch(
contact: Contact,
operation_name: str,
commands: list[tuple[str, str]],
) -> dict[str, str | None]:
"""Send a batch of CLI commands to a repeater and collect responses.
Opens a radio operation with polling paused and auto-fetch suspended (since
we call get_msg() directly via _fetch_repeater_response), adds the contact
to the radio for routing, then sends each command sequentially with a 1-second
gap between them.
Returns a dict mapping field names to response strings (or None on timeout).
"""
results: dict[str, str | None] = {field: None for _, field in commands}
async with radio_manager.radio_operation(
operation_name,
pause_polling=True,
suspend_auto_fetch=True,
) as mc:
await _ensure_on_radio(mc, contact)
await asyncio.sleep(1.0)
for i, (cmd, field) in enumerate(commands):
if i > 0:
await asyncio.sleep(1.0)
send_result = await mc.commands.send_cmd(contact.public_key, cmd)
if send_result.type == EventType.ERROR:
logger.debug("Command '%s' send error: %s", cmd, send_result.payload)
continue
response_event = await _fetch_repeater_response(
mc, contact.public_key[:12], timeout=10.0
)
if response_event is not None:
results[field] = _extract_response_text(response_event)
else:
logger.warning("No response for command '%s' (%s)", cmd, field)
return results
@router.post("/{public_key}/repeater/radio-settings", response_model=RepeaterRadioSettingsResponse)
async def repeater_radio_settings(public_key: str) -> RepeaterRadioSettingsResponse:
"""Fetch radio settings from a repeater via batch CLI commands."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
results = await _batch_cli_fetch(
contact,
"repeater_radio_settings",
[
("ver", "firmware_version"),
("get radio", "radio"),
("get tx", "tx_power"),
("get af", "airtime_factor"),
("get repeat", "repeat_enabled"),
("get flood.max", "flood_max"),
("get name", "name"),
("get lat", "lat"),
("get lon", "lon"),
("clock", "clock_utc"),
],
)
return RepeaterRadioSettingsResponse(**results)
@router.post(
"/{public_key}/repeater/advert-intervals", response_model=RepeaterAdvertIntervalsResponse
)
async def repeater_advert_intervals(public_key: str) -> RepeaterAdvertIntervalsResponse:
"""Fetch advertisement intervals from a repeater via CLI commands."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
results = await _batch_cli_fetch(
contact,
"repeater_advert_intervals",
[
("get advert.interval", "advert_interval"),
("get flood.advert.interval", "flood_advert_interval"),
],
)
return RepeaterAdvertIntervalsResponse(**results)
@router.post("/{public_key}/repeater/owner-info", response_model=RepeaterOwnerInfoResponse)
async def repeater_owner_info(public_key: str) -> RepeaterOwnerInfoResponse:
"""Fetch owner info and guest password from a repeater via CLI commands."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
results = await _batch_cli_fetch(
contact,
"repeater_owner_info",
[
("get owner.info", "owner_info"),
("get guest.password", "guest_password"),
],
)
return RepeaterOwnerInfoResponse(**results)
@router.get("", response_model=list[Contact])
async def list_contacts(
limit: int = Query(default=100, ge=1, le=1000),
@@ -792,79 +371,6 @@ async def delete_contact(public_key: str) -> dict:
return {"status": "ok"}
@router.post("/{public_key}/command", response_model=CommandResponse)
async def send_repeater_command(public_key: str, request: CommandRequest) -> CommandResponse:
"""Send a CLI command to a repeater.
The contact must be a repeater (type=2). The user must have already logged in
via the repeater/login endpoint. This endpoint ensures the contact is on the
radio before sending commands (the repeater remembers ACL permissions after login).
Common commands:
- get name, set name <value>
- get tx, set tx <dbm>
- get radio, set radio <freq,bw,sf,cr>
- tempradio <freq,bw,sf,cr,minutes>
- setperm <pubkey> <permission> (0=guest, 1=read-only, 2=read-write, 3=admin)
- clock, clock sync
- reboot
- ver
"""
require_connected()
# Get contact from database
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"send_repeater_command",
pause_polling=True,
suspend_auto_fetch=True,
) as mc:
# Add contact to radio with path from DB (non-fatal — contact may already be loaded)
logger.info("Adding repeater %s to radio", contact.public_key[:12])
await _ensure_on_radio(mc, contact)
await asyncio.sleep(1.0)
# Send the command
logger.info("Sending command to repeater %s: %s", contact.public_key[:12], request.command)
send_result = await mc.commands.send_cmd(contact.public_key, request.command)
if send_result.type == EventType.ERROR:
raise HTTPException(
status_code=500, detail=f"Failed to send command: {send_result.payload}"
)
# Wait for response using validated fetch loop
response_event = await _fetch_repeater_response(mc, contact.public_key[:12])
if response_event is None:
logger.warning(
"No response from repeater %s for command: %s",
contact.public_key[:12],
request.command,
)
return CommandResponse(
command=request.command,
response="(no response - command may have been processed)",
)
# CONTACT_MSG_RECV payloads use sender_timestamp in meshcore.
response_text = _extract_response_text(response_event)
sender_timestamp = response_event.payload.get(
"sender_timestamp",
response_event.payload.get("timestamp"),
)
logger.info("Received response from %s: %s", contact.public_key[:12], response_text)
return CommandResponse(
command=request.command,
response=response_text,
sender_timestamp=sender_timestamp,
)
@router.post("/{public_key}/trace", response_model=TraceResponse)
async def request_trace(public_key: str) -> TraceResponse:
"""Send a single-hop trace to a contact and wait for the result.

510
app/routers/repeaters.py Normal file
View File

@@ -0,0 +1,510 @@
import asyncio
import logging
import time
from typing import TYPE_CHECKING
from fastapi import APIRouter, HTTPException
from meshcore import EventType
from app.dependencies import require_connected
from app.models import (
CONTACT_TYPE_REPEATER,
AclEntry,
CommandRequest,
CommandResponse,
Contact,
LppSensor,
NeighborInfo,
RepeaterAclResponse,
RepeaterAdvertIntervalsResponse,
RepeaterLoginRequest,
RepeaterLoginResponse,
RepeaterLppTelemetryResponse,
RepeaterNeighborsResponse,
RepeaterOwnerInfoResponse,
RepeaterRadioSettingsResponse,
RepeaterStatusResponse,
)
from app.radio import radio_manager
from app.repository import ContactRepository
from app.routers.contacts import _ensure_on_radio, _resolve_contact_or_404
if TYPE_CHECKING:
from meshcore.events import Event
logger = logging.getLogger(__name__)
# ACL permission level names
ACL_PERMISSION_NAMES = {
0: "Guest",
1: "Read-only",
2: "Read-write",
3: "Admin",
}
router = APIRouter(prefix="/contacts", tags=["repeaters"])
# Delay between repeater radio operations to allow key exchange and path establishment
REPEATER_OP_DELAY_SECONDS = 2.0
def _monotonic() -> float:
"""Wrapper around time.monotonic() for testability.
Patching time.monotonic directly breaks the asyncio event loop which also
uses it. This indirection allows tests to control the clock safely.
"""
return time.monotonic()
def _extract_response_text(event) -> str:
"""Extract text from a CLI response event, stripping the firmware '> ' prefix."""
text = event.payload.get("text", str(event.payload))
if text.startswith("> "):
text = text[2:]
return text
async def _fetch_repeater_response(
mc,
target_pubkey_prefix: str,
timeout: float = 20.0,
) -> "Event | None":
"""Fetch a CLI response from a specific repeater via a validated get_msg() loop.
Calls get_msg() repeatedly until a matching CLI response (txt_type=1) from the
target repeater arrives or the wall-clock deadline expires. Unrelated messages
are safe to skip — meshcore's event dispatcher already delivers them to the
normal subscription handlers (on_contact_message, etc.) when get_msg() returns.
Args:
mc: MeshCore instance
target_pubkey_prefix: 12-char hex prefix of the repeater's public key
timeout: Wall-clock seconds before giving up
Returns:
The matching Event, or None if no response arrived before the deadline.
"""
deadline = _monotonic() + timeout
while _monotonic() < deadline:
try:
result = await mc.commands.get_msg(timeout=2.0)
except asyncio.TimeoutError:
continue
except Exception as e:
logger.debug("get_msg() exception: %s", e)
await asyncio.sleep(1.0)
continue
if result.type == EventType.NO_MORE_MSGS:
# No messages queued yet — wait and retry
await asyncio.sleep(1.0)
continue
if result.type == EventType.ERROR:
logger.debug("get_msg() error: %s", result.payload)
await asyncio.sleep(1.0)
continue
if result.type == EventType.CONTACT_MSG_RECV:
msg_prefix = result.payload.get("pubkey_prefix", "")
txt_type = result.payload.get("txt_type", 0)
if msg_prefix == target_pubkey_prefix and txt_type == 1:
return result
# Not our target — already dispatched to subscribers by meshcore,
# so just continue draining the queue.
logger.debug(
"Skipping non-target message (from=%s, txt_type=%d) while waiting for %s",
msg_prefix,
txt_type,
target_pubkey_prefix,
)
continue
if result.type == EventType.CHANNEL_MSG_RECV:
# Already dispatched to subscribers by meshcore; skip.
logger.debug(
"Skipping channel message (channel_idx=%s) during repeater fetch",
result.payload.get("channel_idx"),
)
continue
logger.debug("Unexpected event type %s during repeater fetch, skipping", result.type)
logger.warning("No CLI response from repeater %s within %.1fs", target_pubkey_prefix, timeout)
return None
async def prepare_repeater_connection(mc, contact: Contact, password: str) -> None:
"""Prepare connection to a repeater by adding to radio and logging in.
Args:
mc: MeshCore instance
contact: The repeater contact
password: Password for login (empty string for no password)
Raises:
HTTPException: If login fails
"""
# Add contact to radio with path from DB (non-fatal — contact may already be loaded)
logger.info("Adding repeater %s to radio", contact.public_key[:12])
await _ensure_on_radio(mc, contact)
# Send login with password
logger.info("Sending login to repeater %s", contact.public_key[:12])
login_result = await mc.commands.send_login(contact.public_key, password)
if login_result.type == EventType.ERROR:
raise HTTPException(status_code=401, detail=f"Login failed: {login_result.payload}")
# Wait for key exchange to complete before sending requests
logger.debug("Waiting %.1fs for key exchange to complete", REPEATER_OP_DELAY_SECONDS)
await asyncio.sleep(REPEATER_OP_DELAY_SECONDS)
def _require_repeater(contact: Contact) -> None:
"""Raise 400 if contact is not a repeater."""
if contact.type != CONTACT_TYPE_REPEATER:
raise HTTPException(
status_code=400,
detail=f"Contact is not a repeater (type={contact.type}, expected {CONTACT_TYPE_REPEATER})",
)
# ---------------------------------------------------------------------------
# Granular repeater endpoints — one attempt, no server-side retries.
# Frontend manages retry logic for better UX control.
# ---------------------------------------------------------------------------
@router.post("/{public_key}/repeater/login", response_model=RepeaterLoginResponse)
async def repeater_login(public_key: str, request: RepeaterLoginRequest) -> RepeaterLoginResponse:
"""Log in to a repeater. Adds contact to radio, sends login, waits for key exchange."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_login",
pause_polling=True,
suspend_auto_fetch=True,
) as mc:
await prepare_repeater_connection(mc, contact, request.password)
return RepeaterLoginResponse(status="ok")
@router.post("/{public_key}/repeater/status", response_model=RepeaterStatusResponse)
async def repeater_status(public_key: str) -> RepeaterStatusResponse:
"""Fetch status telemetry from a repeater (single attempt, 10s timeout)."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_status", pause_polling=True, suspend_auto_fetch=True
) as mc:
# Ensure contact is on radio for routing
await _ensure_on_radio(mc, contact)
status = await mc.commands.req_status_sync(contact.public_key, timeout=10, min_timeout=5)
if status is None:
raise HTTPException(status_code=504, detail="No status response from repeater")
return RepeaterStatusResponse(
battery_volts=status.get("bat", 0) / 1000.0,
tx_queue_len=status.get("tx_queue_len", 0),
noise_floor_dbm=status.get("noise_floor", 0),
last_rssi_dbm=status.get("last_rssi", 0),
last_snr_db=status.get("last_snr", 0.0),
packets_received=status.get("nb_recv", 0),
packets_sent=status.get("nb_sent", 0),
airtime_seconds=status.get("airtime", 0),
rx_airtime_seconds=status.get("rx_airtime", 0),
uptime_seconds=status.get("uptime", 0),
sent_flood=status.get("sent_flood", 0),
sent_direct=status.get("sent_direct", 0),
recv_flood=status.get("recv_flood", 0),
recv_direct=status.get("recv_direct", 0),
flood_dups=status.get("flood_dups", 0),
direct_dups=status.get("direct_dups", 0),
full_events=status.get("full_evts", 0),
)
@router.post("/{public_key}/repeater/lpp-telemetry", response_model=RepeaterLppTelemetryResponse)
async def repeater_lpp_telemetry(public_key: str) -> RepeaterLppTelemetryResponse:
"""Fetch CayenneLPP sensor telemetry from a repeater (single attempt, 10s timeout)."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_lpp_telemetry", pause_polling=True, suspend_auto_fetch=True
) as mc:
await _ensure_on_radio(mc, contact)
telemetry = await mc.commands.req_telemetry_sync(
contact.public_key, timeout=10, min_timeout=5
)
if telemetry is None:
raise HTTPException(status_code=504, detail="No telemetry response from repeater")
sensors: list[LppSensor] = []
for entry in telemetry:
channel = entry.get("channel", 0)
type_name = str(entry.get("type", "unknown"))
value = entry.get("value", 0)
sensors.append(LppSensor(channel=channel, type_name=type_name, value=value))
return RepeaterLppTelemetryResponse(sensors=sensors)
@router.post("/{public_key}/repeater/neighbors", response_model=RepeaterNeighborsResponse)
async def repeater_neighbors(public_key: str) -> RepeaterNeighborsResponse:
"""Fetch neighbors from a repeater (single attempt, 10s timeout)."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_neighbors", pause_polling=True, suspend_auto_fetch=True
) as mc:
# Ensure contact is on radio for routing
await _ensure_on_radio(mc, contact)
neighbors_data = await mc.commands.fetch_all_neighbours(
contact.public_key, timeout=10, min_timeout=5
)
neighbors: list[NeighborInfo] = []
if neighbors_data and "neighbours" in neighbors_data:
for n in neighbors_data["neighbours"]:
pubkey_prefix = n.get("pubkey", "")
resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix)
neighbors.append(
NeighborInfo(
pubkey_prefix=pubkey_prefix,
name=resolved_contact.name if resolved_contact else None,
snr=n.get("snr", 0.0),
last_heard_seconds=n.get("secs_ago", 0),
)
)
return RepeaterNeighborsResponse(neighbors=neighbors)
@router.post("/{public_key}/repeater/acl", response_model=RepeaterAclResponse)
async def repeater_acl(public_key: str) -> RepeaterAclResponse:
"""Fetch ACL from a repeater (single attempt, 10s timeout)."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"repeater_acl", pause_polling=True, suspend_auto_fetch=True
) as mc:
# Ensure contact is on radio for routing
await _ensure_on_radio(mc, contact)
acl_data = await mc.commands.req_acl_sync(contact.public_key, timeout=10, min_timeout=5)
acl_entries: list[AclEntry] = []
if acl_data and isinstance(acl_data, list):
for entry in acl_data:
pubkey_prefix = entry.get("key", "")
perm = entry.get("perm", 0)
resolved_contact = await ContactRepository.get_by_key_prefix(pubkey_prefix)
acl_entries.append(
AclEntry(
pubkey_prefix=pubkey_prefix,
name=resolved_contact.name if resolved_contact else None,
permission=perm,
permission_name=ACL_PERMISSION_NAMES.get(perm, f"Unknown({perm})"),
)
)
return RepeaterAclResponse(acl=acl_entries)
async def _batch_cli_fetch(
contact: Contact,
operation_name: str,
commands: list[tuple[str, str]],
) -> dict[str, str | None]:
"""Send a batch of CLI commands to a repeater and collect responses.
Opens a radio operation with polling paused and auto-fetch suspended (since
we call get_msg() directly via _fetch_repeater_response), adds the contact
to the radio for routing, then sends each command sequentially with a 1-second
gap between them.
Returns a dict mapping field names to response strings (or None on timeout).
"""
results: dict[str, str | None] = {field: None for _, field in commands}
async with radio_manager.radio_operation(
operation_name,
pause_polling=True,
suspend_auto_fetch=True,
) as mc:
await _ensure_on_radio(mc, contact)
await asyncio.sleep(1.0)
for i, (cmd, field) in enumerate(commands):
if i > 0:
await asyncio.sleep(1.0)
send_result = await mc.commands.send_cmd(contact.public_key, cmd)
if send_result.type == EventType.ERROR:
logger.debug("Command '%s' send error: %s", cmd, send_result.payload)
continue
response_event = await _fetch_repeater_response(
mc, contact.public_key[:12], timeout=10.0
)
if response_event is not None:
results[field] = _extract_response_text(response_event)
else:
logger.warning("No response for command '%s' (%s)", cmd, field)
return results
@router.post("/{public_key}/repeater/radio-settings", response_model=RepeaterRadioSettingsResponse)
async def repeater_radio_settings(public_key: str) -> RepeaterRadioSettingsResponse:
"""Fetch radio settings from a repeater via batch CLI commands."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
results = await _batch_cli_fetch(
contact,
"repeater_radio_settings",
[
("ver", "firmware_version"),
("get radio", "radio"),
("get tx", "tx_power"),
("get af", "airtime_factor"),
("get repeat", "repeat_enabled"),
("get flood.max", "flood_max"),
("get name", "name"),
("get lat", "lat"),
("get lon", "lon"),
("clock", "clock_utc"),
],
)
return RepeaterRadioSettingsResponse(**results)
@router.post(
"/{public_key}/repeater/advert-intervals", response_model=RepeaterAdvertIntervalsResponse
)
async def repeater_advert_intervals(public_key: str) -> RepeaterAdvertIntervalsResponse:
"""Fetch advertisement intervals from a repeater via CLI commands."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
results = await _batch_cli_fetch(
contact,
"repeater_advert_intervals",
[
("get advert.interval", "advert_interval"),
("get flood.advert.interval", "flood_advert_interval"),
],
)
return RepeaterAdvertIntervalsResponse(**results)
@router.post("/{public_key}/repeater/owner-info", response_model=RepeaterOwnerInfoResponse)
async def repeater_owner_info(public_key: str) -> RepeaterOwnerInfoResponse:
"""Fetch owner info and guest password from a repeater via CLI commands."""
require_connected()
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
results = await _batch_cli_fetch(
contact,
"repeater_owner_info",
[
("get owner.info", "owner_info"),
("get guest.password", "guest_password"),
],
)
return RepeaterOwnerInfoResponse(**results)
@router.post("/{public_key}/command", response_model=CommandResponse)
async def send_repeater_command(public_key: str, request: CommandRequest) -> CommandResponse:
"""Send a CLI command to a repeater.
The contact must be a repeater (type=2). The user must have already logged in
via the repeater/login endpoint. This endpoint ensures the contact is on the
radio before sending commands (the repeater remembers ACL permissions after login).
Common commands:
- get name, set name <value>
- get tx, set tx <dbm>
- get radio, set radio <freq,bw,sf,cr>
- tempradio <freq,bw,sf,cr,minutes>
- setperm <pubkey> <permission> (0=guest, 1=read-only, 2=read-write, 3=admin)
- clock, clock sync
- reboot
- ver
"""
require_connected()
# Get contact from database
contact = await _resolve_contact_or_404(public_key)
_require_repeater(contact)
async with radio_manager.radio_operation(
"send_repeater_command",
pause_polling=True,
suspend_auto_fetch=True,
) as mc:
# Add contact to radio with path from DB (non-fatal — contact may already be loaded)
logger.info("Adding repeater %s to radio", contact.public_key[:12])
await _ensure_on_radio(mc, contact)
await asyncio.sleep(1.0)
# Send the command
logger.info("Sending command to repeater %s: %s", contact.public_key[:12], request.command)
send_result = await mc.commands.send_cmd(contact.public_key, request.command)
if send_result.type == EventType.ERROR:
raise HTTPException(
status_code=500, detail=f"Failed to send command: {send_result.payload}"
)
# Wait for response using validated fetch loop
response_event = await _fetch_repeater_response(mc, contact.public_key[:12])
if response_event is None:
logger.warning(
"No response from repeater %s for command: %s",
contact.public_key[:12],
request.command,
)
return CommandResponse(
command=request.command,
response="(no response - command may have been processed)",
)
# CONTACT_MSG_RECV payloads use sender_timestamp in meshcore.
response_text = _extract_response_text(response_event)
sender_timestamp = response_event.payload.get(
"sender_timestamp",
response_event.payload.get("timestamp"),
)
logger.info("Received response from %s: %s", contact.public_key[:12], response_text)
return CommandResponse(
command=request.command,
response=response_text,
sender_timestamp=sender_timestamp,
)

View File

@@ -5,8 +5,11 @@ import shutil
import tempfile
from pathlib import Path
import httpx
import pytest
from app.database import Database
# Use an isolated file-backed SQLite DB for tests that import app.main/TestClient.
# This must be set before app.config/app.database are imported, otherwise the global
# Database instance will bind to the default runtime DB (data/meshcore.db).
@@ -20,3 +23,52 @@ def cleanup_test_db_dir():
"""Clean up temporary pytest DB directory after the test session."""
yield
shutil.rmtree(_TEST_DB_DIR, ignore_errors=True)
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
from app.repository import channels, contacts, messages, raw_packets, settings
db = Database(":memory:")
await db.connect()
submodules = [contacts, channels, messages, raw_packets, settings]
originals = [(mod, mod.db) for mod in submodules]
for mod in submodules:
mod.db = db
# Also patch the db reference used by the packets router for VACUUM
import app.routers.packets as packets_module
original_packets_db = packets_module.db
packets_module.db = db
try:
yield db
finally:
for mod, original in originals:
mod.db = original
packets_module.db = original_packets_db
await db.disconnect()
@pytest.fixture
def client():
"""Create an httpx AsyncClient for testing the app."""
from app.main import app
transport = httpx.ASGITransport(app=app)
return httpx.AsyncClient(transport=transport, base_url="http://test")
@pytest.fixture
def captured_broadcasts():
"""Capture WebSocket broadcasts for verification."""
broadcasts = []
def mock_broadcast(event_type: str, data: dict):
broadcasts.append({"type": event_type, "data": data})
return broadcasts, mock_broadcast

View File

@@ -8,10 +8,8 @@ import hashlib
import time
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from app.database import Database
from app.radio import radio_manager
from app.repository import (
ChannelRepository,
@@ -31,33 +29,6 @@ def _reset_radio_state():
radio_manager._operation_lock = prev_lock
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.fixture
def client():
"""Create an httpx AsyncClient for testing the app."""
from app.main import app
transport = httpx.ASGITransport(app=app)
return httpx.AsyncClient(transport=transport, base_url="http://test")
async def _insert_contact(public_key, name="Alice", **overrides):
"""Insert a contact into the test database."""
data = {

View File

@@ -7,33 +7,13 @@ from the radio and upserts them into the database.
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from meshcore import EventType
from app.database import Database
from app.radio import radio_manager
from app.repository import ChannelRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.fixture(autouse=True)
def _reset_radio_state():
"""Save/restore radio_manager state so tests don't leak."""
@@ -44,15 +24,6 @@ def _reset_radio_state():
radio_manager._operation_lock = prev_lock
@pytest.fixture
def client():
"""Create an httpx AsyncClient for testing the app."""
from app.main import app
transport = httpx.ASGITransport(app=app)
return httpx.AsyncClient(transport=transport, base_url="http://test")
def _make_channel_info(name: str, secret: bytes):
"""Create a mock channel info response."""
result = MagicMock()

View File

@@ -9,11 +9,9 @@ Uses httpx.AsyncClient with real in-memory SQLite database.
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from meshcore import EventType
from app.database import Database
from app.radio import radio_manager
from app.repository import ContactAdvertPathRepository, ContactRepository, MessageRepository
@@ -43,24 +41,6 @@ def _reset_radio_state():
radio_manager._operation_lock = prev_lock
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
async def _insert_contact(public_key=KEY_A, name="Alice", on_radio=False, **overrides):
"""Insert a contact into the test database."""
data = {
@@ -82,15 +62,6 @@ async def _insert_contact(public_key=KEY_A, name="Alice", on_radio=False, **over
await ContactRepository.upsert(data)
@pytest.fixture
def client():
"""Create an httpx AsyncClient for testing the app."""
from app.main import app
transport = httpx.ASGITransport(app=app)
return httpx.AsyncClient(transport=transport, base_url="http://test")
class TestListContacts:
"""Test GET /api/contacts."""

View File

@@ -11,7 +11,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.database import Database
from app.decoder import DecryptedDirectMessage
from app.repository import (
ContactRepository,
@@ -19,36 +18,6 @@ from app.repository import (
RawPacketRepository,
)
@pytest.fixture
async def test_db():
"""Create an in-memory test database."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.fixture
def captured_broadcasts():
"""Capture WebSocket broadcasts for verification."""
broadcasts = []
def mock_broadcast(event_type: str, data: dict):
broadcasts.append({"type": event_type, "data": data})
return broadcasts, mock_broadcast
# Shared test constants
CHANNEL_KEY = "ABC123DEF456ABC123DEF456ABC12345"
CONTACT_PUB = "a1b2c3d3ba9f5fa8705b9845fe11cc6f01d1d49caaf4d122ac7121663c5beec7"

View File

@@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.database import Database
from app.event_handlers import (
_active_subscriptions,
_pending_acks,
@@ -23,24 +22,6 @@ from app.repository import (
)
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.fixture(autouse=True)
def clear_test_state():
"""Clear pending ACKs and subscriptions before each test."""

View File

@@ -2,28 +2,9 @@
import pytest
from app.database import Database
from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.mark.asyncio
async def test_upsert_stores_lowercase_key(test_db):
await ContactRepository.upsert(

View File

@@ -2,28 +2,8 @@
import pytest
from app.database import Database
from app.repository import MessageRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
CHAN_KEY = "ABC123DEF456ABC123DEF456ABC12345"
DM_KEY = "aa" * 32

View File

@@ -2,28 +2,9 @@
import pytest
from app.database import Database
from app.repository import ContactRepository, MessageRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.mark.asyncio
async def test_claim_prefix_promotes_dm_to_full_key(test_db):
full_key = "a1b2c3d3ba9f5fa8705b9845fe11cc6f01d1d49caaf4d122ac7121663c5beec7"

View File

@@ -13,7 +13,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.database import Database
from app.decoder import DecryptedDirectMessage, PacketInfo, ParsedAdvertisement, PayloadType
from app.repository import (
ChannelRepository,
@@ -28,41 +27,6 @@ with open(FIXTURES_PATH) as f:
FIXTURES = json.load(f)
@pytest.fixture
async def test_db():
"""Create an in-memory test database.
We need to patch the db module-level variable before any repository
methods are called, so they use our test database.
"""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
# Store original and patch the module attribute directly
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.fixture
def captured_broadcasts():
"""Capture WebSocket broadcasts for verification."""
broadcasts = []
def mock_broadcast(event_type: str, data: dict):
"""Synchronous mock that captures broadcasts."""
broadcasts.append({"type": event_type, "data": data})
return broadcasts, mock_broadcast
class TestChannelMessagePipeline:
"""Test channel message flow: packet → decrypt → store → broadcast."""

View File

@@ -7,47 +7,11 @@ undecrypted count endpoint, and the maintenance endpoint.
import time
from unittest.mock import patch
import httpx
import pytest
from app.database import Database
from app.repository import ChannelRepository, MessageRepository, RawPacketRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
# Also patch the db reference used by the packets router for VACUUM
import app.routers.packets as packets_module
original_packets_db = packets_module.db
packets_module.db = db
try:
yield db
finally:
repo_module.db = original_db
packets_module.db = original_packets_db
await db.disconnect()
@pytest.fixture
def client():
"""Create an httpx AsyncClient for testing the app."""
from app.main import app
transport = httpx.ASGITransport(app=app)
return httpx.AsyncClient(transport=transport, base_url="http://test")
async def _insert_raw_packets(count: int, decrypted: bool = False, age_days: int = 0) -> list[int]:
"""Insert raw packets and return their IDs."""
ids = []

View File

@@ -10,7 +10,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from meshcore import EventType
from app.database import Database
from app.models import Favorite
from app.radio import RadioManager, radio_manager
from app.radio_sync import (
@@ -30,24 +29,6 @@ from app.repository import (
)
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.fixture(autouse=True)
def reset_sync_state():
"""Reset polling pause state, sync timestamp, and radio_manager before/after each test."""

View File

@@ -6,11 +6,11 @@ import pytest
from fastapi import HTTPException
from meshcore import EventType
from app.database import Database
from app.models import CommandRequest, Contact, RepeaterLoginRequest
from app.radio import radio_manager
from app.repository import ContactRepository
from app.routers.contacts import (
from app.routers.contacts import request_trace
from app.routers.repeaters import (
_batch_cli_fetch,
_fetch_repeater_response,
repeater_acl,
@@ -21,7 +21,6 @@ from app.routers.contacts import (
repeater_owner_info,
repeater_radio_settings,
repeater_status,
request_trace,
send_repeater_command,
)
@@ -29,7 +28,7 @@ KEY_A = "aa" * 32
# Patch target for the wall-clock wrapper used by _fetch_repeater_response.
# We patch _monotonic (not time.monotonic) to avoid breaking the asyncio event loop.
_MONOTONIC = "app.routers.contacts._monotonic"
_MONOTONIC = "app.routers.repeaters._monotonic"
@pytest.fixture(autouse=True)
@@ -42,24 +41,6 @@ def _reset_radio_state():
radio_manager._operation_lock = prev_lock
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
def _radio_result(event_type=EventType.OK, payload=None):
result = MagicMock()
result.type = event_type
@@ -210,7 +191,7 @@ class TestFetchRepeaterResponse:
with (
patch(_MONOTONIC, side_effect=_advancing_clock()),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
result = await _fetch_repeater_response(mc, "aaaaaaaaaaaa", timeout=5.0)
@@ -229,7 +210,7 @@ class TestFetchRepeaterResponse:
with (
patch(_MONOTONIC, side_effect=times),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
result = await _fetch_repeater_response(mc, "aaaaaaaaaaaa", timeout=2.0)
@@ -247,7 +228,7 @@ class TestFetchRepeaterResponse:
with (
patch(_MONOTONIC, side_effect=_advancing_clock()),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
result = await _fetch_repeater_response(mc, "aaaaaaaaaaaa", timeout=5.0)
@@ -290,7 +271,7 @@ class TestRepeaterCommandRoute:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -308,10 +289,10 @@ class TestRepeaterCommandRoute:
# Expire the deadline after a couple of ticks
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=[0.0, 5.0, 25.0]),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
response = await send_repeater_command(KEY_A, CommandRequest(command="ver"))
@@ -337,7 +318,7 @@ class TestRepeaterCommandRoute:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
@@ -365,7 +346,7 @@ class TestRepeaterCommandRoute:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
@@ -391,7 +372,7 @@ class TestRepeaterCommandRoute:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
@@ -419,7 +400,7 @@ class TestRepeaterCommandRoute:
mc.commands.get_msg = AsyncMock(side_effect=[unrelated, expected])
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
@@ -445,7 +426,7 @@ class TestRepeaterCommandRoute:
mc.commands.get_msg = AsyncMock(side_effect=[channel_msg, expected])
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
@@ -468,10 +449,10 @@ class TestRepeaterCommandRoute:
mc.commands.get_msg = AsyncMock(side_effect=[no_msgs, expected])
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
response = await send_repeater_command(KEY_A, CommandRequest(command="ver"))
@@ -548,10 +529,10 @@ class TestRepeaterLogin:
await _insert_contact(KEY_A, name="Repeater", contact_type=2)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(
"app.routers.contacts.prepare_repeater_connection",
"app.routers.repeaters.prepare_repeater_connection",
new_callable=AsyncMock,
) as mock_prepare,
):
@@ -564,7 +545,7 @@ class TestRepeaterLogin:
async def test_404_missing_contact(self, test_db):
mc = _mock_mc()
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -576,7 +557,7 @@ class TestRepeaterLogin:
mc = _mock_mc()
await _insert_contact(KEY_A, name="Client", contact_type=1)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -593,9 +574,9 @@ class TestRepeaterLogin:
raise HTTPException(status_code=401, detail="Login failed")
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch("app.routers.contacts.prepare_repeater_connection", side_effect=_prepare_fail),
patch("app.routers.repeaters.prepare_repeater_connection", side_effect=_prepare_fail),
):
with pytest.raises(HTTPException) as exc:
await repeater_login(KEY_A, RepeaterLoginRequest(password="bad"))
@@ -630,7 +611,7 @@ class TestRepeaterStatus:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_status(KEY_A)
@@ -653,7 +634,7 @@ class TestRepeaterStatus:
mc.commands.req_status_sync = AsyncMock(return_value=None)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -665,7 +646,7 @@ class TestRepeaterStatus:
mc = _mock_mc()
await _insert_contact(KEY_A, name="Client", contact_type=1)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -691,7 +672,7 @@ class TestRepeaterLppTelemetry:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_lpp_telemetry(KEY_A)
@@ -713,7 +694,7 @@ class TestRepeaterLppTelemetry:
mc.commands.req_telemetry_sync = AsyncMock(return_value=[])
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_lpp_telemetry(KEY_A)
@@ -727,7 +708,7 @@ class TestRepeaterLppTelemetry:
mc.commands.req_telemetry_sync = AsyncMock(return_value=None)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -739,7 +720,7 @@ class TestRepeaterLppTelemetry:
mc = _mock_mc()
await _insert_contact(KEY_A, name="Client", contact_type=1)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -765,7 +746,7 @@ class TestRepeaterNeighbors:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_neighbors(KEY_A)
@@ -783,7 +764,7 @@ class TestRepeaterNeighbors:
mc.commands.fetch_all_neighbours = AsyncMock(return_value={"neighbours": []})
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_neighbors(KEY_A)
@@ -797,7 +778,7 @@ class TestRepeaterNeighbors:
mc.commands.fetch_all_neighbours = AsyncMock(return_value=None)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_neighbors(KEY_A)
@@ -821,7 +802,7 @@ class TestRepeaterAcl:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_acl(KEY_A)
@@ -839,7 +820,7 @@ class TestRepeaterAcl:
mc.commands.req_acl_sync = AsyncMock(return_value=[])
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_acl(KEY_A)
@@ -853,7 +834,7 @@ class TestRepeaterAcl:
mc.commands.req_acl_sync = AsyncMock(return_value=None)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
response = await repeater_acl(KEY_A)
@@ -890,7 +871,7 @@ class TestRepeaterRadioSettings:
mc.commands.get_msg = AsyncMock(side_effect=get_msg_results)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
@@ -927,10 +908,10 @@ class TestRepeaterRadioSettings:
clock_ticks.extend([base, base + 5.0, base + 11.0])
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=clock_ticks),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
response = await repeater_radio_settings(KEY_A)
@@ -943,7 +924,7 @@ class TestRepeaterRadioSettings:
mc = _mock_mc()
await _insert_contact(KEY_A, name="Client", contact_type=1)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -970,7 +951,7 @@ class TestRepeaterAdvertIntervals:
mc.commands.get_msg = AsyncMock(side_effect=responses)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
@@ -991,10 +972,10 @@ class TestRepeaterAdvertIntervals:
clock_ticks.extend([base, base + 5.0, base + 11.0])
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=clock_ticks),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
response = await repeater_advert_intervals(KEY_A)
@@ -1025,7 +1006,7 @@ class TestRepeaterOwnerInfo:
mc.commands.get_msg = AsyncMock(side_effect=responses)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
):
@@ -1046,10 +1027,10 @@ class TestRepeaterOwnerInfo:
clock_ticks.extend([base, base + 5.0, base + 11.0])
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=clock_ticks),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
response = await repeater_owner_info(KEY_A)
@@ -1107,7 +1088,7 @@ class TestBatchCliFetch:
with (
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=_advancing_clock()),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
results = await _batch_cli_fetch(
contact, "test_op", [("bad_cmd", "field_a"), ("good_cmd", "field_b")]
@@ -1128,7 +1109,7 @@ class TestBatchCliFetch:
with (
patch.object(radio_manager, "_meshcore", mc),
patch(_MONOTONIC, side_effect=[0.0, 5.0, 11.0]),
patch("app.routers.contacts.asyncio.sleep", new_callable=AsyncMock),
patch("app.routers.repeaters.asyncio.sleep", new_callable=AsyncMock),
):
results = await _batch_cli_fetch(contact, "test_op", [("clock", "clock_output")])
@@ -1147,7 +1128,7 @@ class TestRepeaterAddContactError:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -1165,7 +1146,7 @@ class TestRepeaterAddContactError:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -1183,7 +1164,7 @@ class TestRepeaterAddContactError:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:
@@ -1201,7 +1182,7 @@ class TestRepeaterAddContactError:
)
with (
patch("app.routers.contacts.require_connected", return_value=mc),
patch("app.routers.repeaters.require_connected", return_value=mc),
patch.object(radio_manager, "_meshcore", mc),
):
with pytest.raises(HTTPException) as exc:

View File

@@ -4,7 +4,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.database import Database
from app.repository import (
ContactAdvertPathRepository,
ContactNameHistoryRepository,
@@ -13,24 +12,6 @@ from app.repository import (
)
@pytest.fixture
async def test_db():
"""Create an in-memory test database with the module-level db swapped in."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
async def _create_message(test_db, **overrides) -> int:
"""Helper to insert a message and return its id."""
defaults = {
@@ -90,7 +71,7 @@ class TestMessageRepositoryAddPath:
"""Adding a path without received_at uses current timestamp."""
msg_id = await _create_message(test_db)
with patch("app.repository.time") as mock_time:
with patch("app.repository.messages.time") as mock_time:
mock_time.time.return_value = 1700000500.5
result = await MessageRepository.add_path(message_id=msg_id, path="1A2B")
@@ -518,7 +499,7 @@ class TestAppSettingsRepository:
mock_db = MagicMock()
mock_db.conn = mock_conn
with patch("app.repository.db", mock_db):
with patch("app.repository.settings.db", mock_db):
from app.repository import AppSettingsRepository
settings = await AppSettingsRepository.get()

View File

@@ -8,7 +8,6 @@ import pytest
from fastapi import HTTPException
from meshcore import EventType
from app.database import Database
from app.models import (
SendChannelMessageRequest,
SendDirectMessageRequest,
@@ -36,24 +35,6 @@ def _reset_radio_state():
radio_manager._operation_lock = prev_lock
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
def _make_radio_result(payload=None):
"""Create a mock radio command result."""
result = MagicMock()

View File

@@ -3,7 +3,6 @@
import pytest
from fastapi import HTTPException
from app.database import Database
from app.models import AppSettings, BotConfig
from app.repository import AppSettingsRepository
from app.routers.settings import (
@@ -16,24 +15,6 @@ from app.routers.settings import (
)
@pytest.fixture
async def test_db():
"""Create an in-memory test database with schema + migrations."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
class TestUpdateSettings:
@pytest.mark.asyncio
async def test_forwards_only_provided_fields(self, test_db):

View File

@@ -4,28 +4,9 @@ import time
import pytest
from app.database import Database
from app.repository import StatisticsRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database with the module-level db swapped in."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
class TestStatisticsEmpty:
@pytest.mark.asyncio
async def test_empty_database(self, test_db):