Files
Remote-Terminal-for-MeshCore/app/repository/contacts.py
2026-03-09 23:42:46 -07:00

558 lines
21 KiB
Python

import time
from collections.abc import Mapping
from typing import Any
from app.database import db
from app.models import (
Contact,
ContactAdvertPath,
ContactAdvertPathSummary,
ContactNameHistory,
ContactUpsert,
)
from app.path_utils import first_hop_hex, normalize_contact_route, normalize_route_override
class AmbiguousPublicKeyPrefixError(ValueError):
    """Signal that a public key prefix resolved to more than one contact.

    Attributes:
        prefix: The offending prefix, lower-cased.
        matches: The public keys that matched, exactly as passed in.
    """

    def __init__(self, prefix: str, matches: list[str]):
        lowered = prefix.lower()
        super().__init__(f"Ambiguous public key prefix '{lowered}'")
        self.prefix = lowered
        self.matches = matches
class ContactRepository:
    """Persistence helpers for rows of the ``contacts`` table.

    All methods are static and operate on the shared ``db.conn`` connection.
    Public keys are lower-cased before they are written or used in a lookup.
    """

    @staticmethod
    def _coerce_contact_upsert(
        contact: ContactUpsert | Contact | Mapping[str, Any],
    ) -> ContactUpsert:
        """Return ``contact`` as a ``ContactUpsert``.

        ``ContactUpsert`` instances pass through untouched, ``Contact``
        instances are converted with ``to_upsert()``, and anything else is
        handed to ``ContactUpsert.model_validate``.
        """
        if isinstance(contact, ContactUpsert):
            return contact
        if isinstance(contact, Contact):
            return contact.to_upsert()
        return ContactUpsert.model_validate(contact)

    @staticmethod
    async def upsert(contact: ContactUpsert | Contact | Mapping[str, Any]) -> None:
        """Insert a contact, or merge it into the existing row with the same key.

        On conflict:
          * name, last_path, the three route_override_* columns, last_advert,
            lat, lon, on_radio and last_contacted keep the stored value when
            the incoming value is NULL;
          * ``type`` keeps the stored value when the incoming type is 0;
          * ``first_seen`` keeps the stored value unless it is NULL;
          * flags, last_path_len, out_path_hash_mode and last_seen are always
            overwritten.

        ``last_seen`` falls back to the current time when the input has none.
        The change is committed before returning.
        """
        contact_row = ContactRepository._coerce_contact_upsert(contact)
        last_path, last_path_len, out_path_hash_mode = normalize_contact_route(
            contact_row.last_path,
            contact_row.last_path_len,
            contact_row.out_path_hash_mode,
        )
        route_override_path, route_override_len, route_override_hash_mode = (
            normalize_route_override(
                contact_row.route_override_path,
                contact_row.route_override_len,
                contact_row.route_override_hash_mode,
            )
        )
        await db.conn.execute(
            """
            INSERT INTO contacts (public_key, name, type, flags, last_path, last_path_len,
            out_path_hash_mode,
            route_override_path, route_override_len,
            route_override_hash_mode,
            last_advert, lat, lon, last_seen,
            on_radio, last_contacted, first_seen)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(public_key) DO UPDATE SET
            name = COALESCE(excluded.name, contacts.name),
            type = CASE WHEN excluded.type = 0 THEN contacts.type ELSE excluded.type END,
            flags = excluded.flags,
            last_path = COALESCE(excluded.last_path, contacts.last_path),
            last_path_len = excluded.last_path_len,
            out_path_hash_mode = excluded.out_path_hash_mode,
            route_override_path = COALESCE(
            excluded.route_override_path, contacts.route_override_path
            ),
            route_override_len = COALESCE(
            excluded.route_override_len, contacts.route_override_len
            ),
            route_override_hash_mode = COALESCE(
            excluded.route_override_hash_mode, contacts.route_override_hash_mode
            ),
            last_advert = COALESCE(excluded.last_advert, contacts.last_advert),
            lat = COALESCE(excluded.lat, contacts.lat),
            lon = COALESCE(excluded.lon, contacts.lon),
            last_seen = excluded.last_seen,
            on_radio = COALESCE(excluded.on_radio, contacts.on_radio),
            last_contacted = COALESCE(excluded.last_contacted, contacts.last_contacted),
            first_seen = COALESCE(contacts.first_seen, excluded.first_seen)
            """,
            (
                contact_row.public_key.lower(),
                contact_row.name,
                contact_row.type,
                contact_row.flags,
                last_path,
                last_path_len,
                out_path_hash_mode,
                route_override_path,
                route_override_len,
                route_override_hash_mode,
                contact_row.last_advert,
                contact_row.lat,
                contact_row.lon,
                contact_row.last_seen if contact_row.last_seen is not None else int(time.time()),
                contact_row.on_radio,
                contact_row.last_contacted,
                contact_row.first_seen,
            ),
        )
        await db.conn.commit()

    @staticmethod
    def _row_to_contact(row) -> Contact:
        """Convert a database row to a Contact model.

        The route columns are passed through ``normalize_contact_route`` and
        ``normalize_route_override``. The three ``route_override_*`` columns
        are treated as None when the row does not carry them.
        """
        last_path, last_path_len, out_path_hash_mode = normalize_contact_route(
            row["last_path"],
            row["last_path_len"],
            row["out_path_hash_mode"],
        )
        available_columns = set(row.keys())
        route_override_path = (
            row["route_override_path"] if "route_override_path" in available_columns else None
        )
        route_override_len = (
            row["route_override_len"] if "route_override_len" in available_columns else None
        )
        route_override_hash_mode = (
            row["route_override_hash_mode"]
            if "route_override_hash_mode" in available_columns
            else None
        )
        route_override_path, route_override_len, route_override_hash_mode = (
            normalize_route_override(
                route_override_path,
                route_override_len,
                route_override_hash_mode,
            )
        )
        return Contact(
            public_key=row["public_key"],
            name=row["name"],
            type=row["type"],
            flags=row["flags"],
            last_path=last_path,
            last_path_len=last_path_len,
            out_path_hash_mode=out_path_hash_mode,
            route_override_path=route_override_path,
            route_override_len=route_override_len,
            route_override_hash_mode=route_override_hash_mode,
            last_advert=row["last_advert"],
            lat=row["lat"],
            lon=row["lon"],
            last_seen=row["last_seen"],
            on_radio=bool(row["on_radio"]),
            last_contacted=row["last_contacted"],
            last_read_at=row["last_read_at"],
            first_seen=row["first_seen"],
        )

    @staticmethod
    async def get_by_key(public_key: str) -> Contact | None:
        """Return the contact whose key equals ``public_key`` (lower-cased), or None."""
        cursor = await db.conn.execute(
            "SELECT * FROM contacts WHERE public_key = ?", (public_key.lower(),)
        )
        row = await cursor.fetchone()
        return ContactRepository._row_to_contact(row) if row else None

    @staticmethod
    async def get_by_key_prefix(prefix: str) -> Contact | None:
        """Get a contact by key prefix only if it resolves uniquely.

        Returns None when no contacts match OR when multiple contacts match
        the prefix (to avoid silently selecting the wrong contact).
        """
        normalized_prefix = prefix.lower()
        # LIMIT 2 is enough to tell "exactly one" from "more than one".
        # NOTE(review): `%` and `_` inside the prefix act as LIKE wildcards and
        # are not escaped here -- confirm callers only pass hex strings.
        cursor = await db.conn.execute(
            "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2",
            (f"{normalized_prefix}%",),
        )
        rows = list(await cursor.fetchall())
        if len(rows) != 1:
            return None
        return ContactRepository._row_to_contact(rows[0])

    @staticmethod
    async def _get_prefix_matches(prefix: str, limit: int = 2) -> list[Contact]:
        """Get contacts matching a key prefix, up to limit, ordered by public key."""
        # NOTE(review): LIKE wildcards in the prefix are not escaped -- see
        # get_by_key_prefix; confirm callers only pass hex strings.
        cursor = await db.conn.execute(
            "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT ?",
            (f"{prefix.lower()}%", limit),
        )
        rows = list(await cursor.fetchall())
        return [ContactRepository._row_to_contact(row) for row in rows]

    @staticmethod
    async def get_by_key_or_prefix(key_or_prefix: str) -> Contact | None:
        """Get a contact by exact key match, falling back to prefix match.

        Useful when the input might be a full 64-char public key or a shorter prefix.

        Returns:
            The exact match if there is one, otherwise the single contact the
            prefix resolves to, otherwise None.

        Raises:
            AmbiguousPublicKeyPrefixError: no exact match exists and the prefix
                matches more than one contact.
        """
        contact = await ContactRepository.get_by_key(key_or_prefix)
        if contact:
            return contact
        matches = await ContactRepository._get_prefix_matches(key_or_prefix, limit=2)
        if len(matches) == 1:
            return matches[0]
        if len(matches) > 1:
            raise AmbiguousPublicKeyPrefixError(
                key_or_prefix,
                [m.public_key for m in matches],
            )
        return None

    @staticmethod
    async def get_by_name(name: str) -> list[Contact]:
        """Get all contacts with the given exact name."""
        cursor = await db.conn.execute("SELECT * FROM contacts WHERE name = ?", (name,))
        rows = await cursor.fetchall()
        return [ContactRepository._row_to_contact(row) for row in rows]

    @staticmethod
    async def resolve_prefixes(prefixes: list[str]) -> dict[str, Contact]:
        """Resolve multiple key prefixes to contacts in a single query.

        Returns a dict mapping each prefix (lower-cased) to its Contact, only
        for prefixes that resolve uniquely (exactly one match). Ambiguous or
        unmatched prefixes are omitted. A prefix that appears more than once in
        ``prefixes`` (including in different letter case) is resolved once.
        """
        if not prefixes:
            return {}
        # De-duplicate while keeping first-seen order. Without this, a repeated
        # prefix is visited twice per matching row in the grouping loop below,
        # so its single match is recorded twice and the prefix is wrongly
        # treated as ambiguous.
        normalized = list(dict.fromkeys(p.lower() for p in prefixes))
        # Only "public_key LIKE ?" fragments are interpolated; values are bound.
        conditions = " OR ".join(["public_key LIKE ?"] * len(normalized))
        params = [f"{p}%" for p in normalized]
        cursor = await db.conn.execute(f"SELECT * FROM contacts WHERE {conditions}", params)
        rows = await cursor.fetchall()
        # Group by which prefix each row matches
        prefix_to_rows: dict[str, list] = {p: [] for p in normalized}
        for row in rows:
            pk = row["public_key"]
            for p in normalized:
                if pk.startswith(p):
                    prefix_to_rows[p].append(row)
        # Only include uniquely-resolved prefixes
        result: dict[str, Contact] = {}
        for p in normalized:
            if len(prefix_to_rows[p]) == 1:
                result[p] = ContactRepository._row_to_contact(prefix_to_rows[p][0])
        return result

    @staticmethod
    async def get_all(limit: int = 100, offset: int = 0) -> list[Contact]:
        """Return one page of contacts ordered by name, or by public key when name is NULL."""
        cursor = await db.conn.execute(
            "SELECT * FROM contacts ORDER BY COALESCE(name, public_key) LIMIT ? OFFSET ?",
            (limit, offset),
        )
        rows = await cursor.fetchall()
        return [ContactRepository._row_to_contact(row) for row in rows]

    @staticmethod
    async def get_recent_non_repeaters(limit: int = 200) -> list[Contact]:
        """Get the most recently active non-repeater contacts.

        Orders by last_contacted, then last_advert (both newest first, NULL
        treated as 0), excluding repeaters (type=2).
        """
        cursor = await db.conn.execute(
            """
            SELECT * FROM contacts
            WHERE type != 2
            ORDER BY COALESCE(last_contacted, 0) DESC, COALESCE(last_advert, 0) DESC
            LIMIT ?
            """,
            (limit,),
        )
        rows = await cursor.fetchall()
        return [ContactRepository._row_to_contact(row) for row in rows]

    @staticmethod
    async def update_path(
        public_key: str,
        path: str,
        path_len: int,
        out_path_hash_mode: int | None = None,
    ) -> None:
        """Store a new last-known path for a contact and bump ``last_seen`` to now.

        The inputs are normalized with ``normalize_contact_route`` first. The
        stored ``out_path_hash_mode`` is kept when the normalized value is NULL.
        """
        normalized_path, normalized_path_len, normalized_hash_mode = normalize_contact_route(
            path,
            path_len,
            out_path_hash_mode,
        )
        await db.conn.execute(
            """UPDATE contacts SET last_path = ?, last_path_len = ?,
            out_path_hash_mode = COALESCE(?, out_path_hash_mode),
            last_seen = ? WHERE public_key = ?""",
            (
                normalized_path,
                normalized_path_len,
                normalized_hash_mode,
                int(time.time()),
                public_key.lower(),
            ),
        )
        await db.conn.commit()

    @staticmethod
    async def set_routing_override(
        public_key: str,
        path: str | None,
        path_len: int | None,
        out_path_hash_mode: int | None = None,
    ) -> None:
        """Overwrite the contact's three route override columns.

        The inputs are normalized with ``normalize_route_override`` and written
        as-is, so NULL results clear the corresponding column.
        """
        normalized_path, normalized_len, normalized_hash_mode = normalize_route_override(
            path,
            path_len,
            out_path_hash_mode,
        )
        await db.conn.execute(
            """
            UPDATE contacts
            SET route_override_path = ?, route_override_len = ?, route_override_hash_mode = ?
            WHERE public_key = ?
            """,
            (
                normalized_path,
                normalized_len,
                normalized_hash_mode,
                public_key.lower(),
            ),
        )
        await db.conn.commit()

    @staticmethod
    async def clear_routing_override(public_key: str) -> None:
        """Set the contact's three route override columns to NULL."""
        await db.conn.execute(
            """
            UPDATE contacts
            SET route_override_path = NULL,
            route_override_len = NULL,
            route_override_hash_mode = NULL
            WHERE public_key = ?
            """,
            (public_key.lower(),),
        )
        await db.conn.commit()

    @staticmethod
    async def set_on_radio(public_key: str, on_radio: bool) -> None:
        """Set the ``on_radio`` flag for one contact."""
        await db.conn.execute(
            "UPDATE contacts SET on_radio = ? WHERE public_key = ?",
            (on_radio, public_key.lower()),
        )
        await db.conn.commit()

    @staticmethod
    async def clear_on_radio_except(keep_keys: list[str]) -> None:
        """Set on_radio=False for all contacts NOT in keep_keys.

        An empty ``keep_keys`` clears the flag on every contact that has it.
        Keys are lower-cased before comparison, matching how every other method
        in this class stores and looks up public keys.
        """
        if not keep_keys:
            await db.conn.execute("UPDATE contacts SET on_radio = 0 WHERE on_radio = 1")
        else:
            # Stored keys are lower-case; a mixed-case key passed here would
            # otherwise never equal its row and the contact would be cleared.
            normalized_keys = [key.lower() for key in keep_keys]
            placeholders = ",".join("?" * len(normalized_keys))
            await db.conn.execute(
                f"UPDATE contacts SET on_radio = 0 WHERE on_radio = 1 AND public_key NOT IN ({placeholders})",
                normalized_keys,
            )
        await db.conn.commit()

    @staticmethod
    async def delete(public_key: str) -> None:
        """Delete a contact together with its name history and advert paths.

        The dependent rows are removed before the contact row; all three
        deletes are committed together.
        """
        normalized = public_key.lower()
        await db.conn.execute(
            "DELETE FROM contact_name_history WHERE public_key = ?", (normalized,)
        )
        await db.conn.execute(
            "DELETE FROM contact_advert_paths WHERE public_key = ?", (normalized,)
        )
        await db.conn.execute("DELETE FROM contacts WHERE public_key = ?", (normalized,))
        await db.conn.commit()

    @staticmethod
    async def update_last_contacted(public_key: str, timestamp: int | None = None) -> None:
        """Update the last_contacted timestamp for a contact.

        ``last_seen`` is set to the same value. The current time is used when
        ``timestamp`` is None.
        """
        ts = timestamp if timestamp is not None else int(time.time())
        await db.conn.execute(
            "UPDATE contacts SET last_contacted = ?, last_seen = ? WHERE public_key = ?",
            (ts, ts, public_key.lower()),
        )
        await db.conn.commit()

    @staticmethod
    async def update_last_read_at(public_key: str, timestamp: int | None = None) -> bool:
        """Update the last_read_at timestamp for a contact.

        The current time is used when ``timestamp`` is None.

        Returns True if a row was updated, False if contact not found.
        """
        ts = timestamp if timestamp is not None else int(time.time())
        cursor = await db.conn.execute(
            "UPDATE contacts SET last_read_at = ? WHERE public_key = ?",
            (ts, public_key.lower()),
        )
        await db.conn.commit()
        return cursor.rowcount > 0

    @staticmethod
    async def mark_all_read(timestamp: int) -> None:
        """Mark all contacts as read at the given timestamp."""
        await db.conn.execute("UPDATE contacts SET last_read_at = ?", (timestamp,))
        await db.conn.commit()

    @staticmethod
    async def get_by_pubkey_first_byte(hex_byte: str) -> list[Contact]:
        """Get contacts whose public key starts with the given hex byte (2 chars)."""
        cursor = await db.conn.execute(
            "SELECT * FROM contacts WHERE substr(public_key, 1, 2) = ?",
            (hex_byte.lower(),),
        )
        rows = await cursor.fetchall()
        return [ContactRepository._row_to_contact(row) for row in rows]
class ContactAdvertPathRepository:
    """Data access for the recent unique advertisement paths kept per contact."""

    @staticmethod
    def _row_to_path(row) -> ContactAdvertPath:
        """Build a ``ContactAdvertPath`` from a ``contact_advert_paths`` row.

        A NULL/empty ``path_hex`` becomes the empty string; ``next_hop`` is
        derived with ``first_hop_hex``.
        """
        hex_path = row["path_hex"] or ""
        hops = row["path_len"]
        return ContactAdvertPath(
            path=hex_path,
            path_len=hops,
            next_hop=first_hop_hex(hex_path, hops),
            first_seen=row["first_seen"],
            last_seen=row["last_seen"],
            heard_count=row["heard_count"],
        )

    @staticmethod
    async def record_observation(
        public_key: str,
        path_hex: str,
        timestamp: int,
        max_paths: int = 10,
        hop_count: int | None = None,
    ) -> None:
        """Record one sighting of an advert path, then trim old paths.

        A new (public_key, path_hex, path_len) row starts with heard_count 1.
        A repeat sighting increments heard_count and moves last_seen forward
        (never backward). Afterwards only the ``max_paths`` best-ranked rows
        for the contact survive; values below 1 are treated as 1. When
        ``hop_count`` is None the path length is half the hex string length.
        """
        keep = max(max_paths, 1)
        contact_key = public_key.lower()
        route_hex = path_hex.lower()
        hops = len(route_hex) // 2 if hop_count is None else hop_count

        upsert_params = (contact_key, route_hex, hops, timestamp, timestamp)
        await db.conn.execute(
            """
            INSERT INTO contact_advert_paths
            (public_key, path_hex, path_len, first_seen, last_seen, heard_count)
            VALUES (?, ?, ?, ?, ?, 1)
            ON CONFLICT(public_key, path_hex, path_len) DO UPDATE SET
            last_seen = MAX(contact_advert_paths.last_seen, excluded.last_seen),
            heard_count = contact_advert_paths.heard_count + 1
            """,
            upsert_params,
        )

        # Prune: drop everything for this contact outside the top `keep` rows.
        prune_params = (contact_key, contact_key, keep)
        await db.conn.execute(
            """
            DELETE FROM contact_advert_paths
            WHERE public_key = ?
            AND id NOT IN (
            SELECT id
            FROM contact_advert_paths
            WHERE public_key = ?
            ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC
            LIMIT ?
            )
            """,
            prune_params,
        )
        await db.conn.commit()

    @staticmethod
    async def get_recent_for_contact(public_key: str, limit: int = 10) -> list[ContactAdvertPath]:
        """Return up to ``limit`` advert paths for one contact, most recent first."""
        cursor = await db.conn.execute(
            """
            SELECT path_hex, path_len, first_seen, last_seen, heard_count
            FROM contact_advert_paths
            WHERE public_key = ?
            ORDER BY last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC
            LIMIT ?
            """,
            (public_key.lower(), limit),
        )
        to_path = ContactAdvertPathRepository._row_to_path
        return [to_path(record) for record in await cursor.fetchall()]

    @staticmethod
    async def get_recent_for_all_contacts(
        limit_per_contact: int = 10,
    ) -> list[ContactAdvertPathSummary]:
        """Return one summary per contact holding its first ``limit_per_contact`` paths.

        Rows arrive sorted by contact and then by recency, so taking the first
        rows seen for each key yields that contact's most recent paths. The
        per-contact cap is applied here rather than in SQL.
        """
        cursor = await db.conn.execute(
            """
            SELECT public_key, path_hex, path_len, first_seen, last_seen, heard_count
            FROM contact_advert_paths
            ORDER BY public_key ASC, last_seen DESC, heard_count DESC, path_len ASC, path_hex ASC
            """
        )
        by_contact: dict[str, list[ContactAdvertPath]] = {}
        for record in await cursor.fetchall():
            bucket = by_contact.setdefault(record["public_key"], [])
            if len(bucket) < limit_per_contact:
                bucket.append(ContactAdvertPathRepository._row_to_path(record))
        return [
            ContactAdvertPathSummary(public_key=contact_key, paths=contact_paths)
            for contact_key, contact_paths in by_contact.items()
        ]
class ContactNameHistoryRepository:
    """Data access for the names each contact has been observed using."""

    @staticmethod
    async def record_name(public_key: str, name: str, timestamp: int) -> None:
        """Store a sighting of ``name`` for a contact.

        A first sighting inserts a row with first_seen and last_seen both set
        to ``timestamp``. A repeat sighting only moves last_seen forward
        (never backward).
        """
        params = (public_key.lower(), name, timestamp, timestamp)
        await db.conn.execute(
            """
            INSERT INTO contact_name_history (public_key, name, first_seen, last_seen)
            VALUES (?, ?, ?, ?)
            ON CONFLICT(public_key, name) DO UPDATE SET
            last_seen = MAX(contact_name_history.last_seen, excluded.last_seen)
            """,
            params,
        )
        await db.conn.commit()

    @staticmethod
    async def get_history(public_key: str) -> list[ContactNameHistory]:
        """Return every recorded name for a contact, most recently seen first."""
        cursor = await db.conn.execute(
            """
            SELECT name, first_seen, last_seen
            FROM contact_name_history
            WHERE public_key = ?
            ORDER BY last_seen DESC
            """,
            (public_key.lower(),),
        )
        records = await cursor.fetchall()
        return [
            ContactNameHistory(
                name=record["name"],
                first_seen=record["first_seen"],
                last_seen=record["last_seen"],
            )
            for record in records
        ]