diff --git a/AGENTS.md b/AGENTS.md index b3a684a..f2b7173 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -443,11 +443,11 @@ Community MQTT forwards raw packets only. Its derived `path` field, when present ### Web Push Notifications -Web Push is a standalone subsystem (`app/push/`) that sends browser push notifications for incoming messages even when the browser tab is closed. It is **not** a fanout module — it manages its own per-browser subscriptions with server-side filter preferences. +Web Push is a standalone subsystem (`app/push/`) that sends browser push notifications for incoming messages even when the browser tab is closed. It is **not** a fanout module — it manages its own per-browser subscriptions, while the set of push-enabled conversations is stored once per server instance. - **Requires HTTPS** (self-signed certificates work) and outbound internet from the server to reach browser push services (Google FCM, Mozilla autopush). - VAPID key pair is auto-generated on first startup and stored in `app_settings`. -- Each browser subscription is stored in `push_subscriptions` with per-conversation filter preferences (`all_messages`, `all_dms`, or `selected` conversations). +- Each browser subscription is stored in `push_subscriptions` with device identity and delivery state. The set of push-enabled conversations is stored globally in `app_settings.push_conversations`, so all subscribed browsers receive the same configured rooms/DMs. - `broadcast_event()` in `websocket.py` dispatches to `push_manager.dispatch_message()` alongside fanout for `message` events. - Expired subscriptions (HTTP 404/410 from push service) are auto-deleted. - Frontend: service worker (`sw.js`) handles push display and notification click navigation. The `BellRing` icon in `ChatHeader` toggles per-conversation push. Device management lives in Settings > Local. diff --git a/app/AGENTS.md b/app/AGENTS.md index 68fd09a..ae650cf 100644 --- a/app/AGENTS.md +++ b/app/AGENTS.md @@ -177,9 +177,9 @@ app/ Web Push is a standalone subsystem in `app/push/`, separate from the fanout module system. It sends browser push notifications for incoming messages even when the tab is closed. -- **Not a fanout module** — Web Push manages per-browser subscriptions (N browsers, each with own endpoint and preferences), unlike fanout which is one-config-to-one-destination. +- **Not a fanout module** — Web Push manages per-browser subscriptions (N browsers, each with its own endpoint and delivery state), unlike fanout which is one-config-to-one-destination. - **VAPID keys**: auto-generated P-256 key pair on first startup, stored in `app_settings.vapid_private_key` / `vapid_public_key`. Cached in-module by `app/push/vapid.py`. -- **Dispatch**: `broadcast_event()` in `websocket.py` fires `push_manager.dispatch_message(data)` alongside fanout for `message` events. The manager loads all subscriptions, filters each by its `filter_mode` (`all_messages`, `all_dms`, `selected`), builds a notification payload, and sends concurrently via `pywebpush` (run in thread executor). +- **Dispatch**: `broadcast_event()` in `websocket.py` fires `push_manager.dispatch_message(data)` alongside fanout for `message` events. The manager checks the global `app_settings.push_conversations` list, then sends to all currently registered subscriptions via `pywebpush` (run in a thread executor). - **Stale cleanup**: HTTP 404/410 from the push service triggers immediate subscription deletion. - **Subscriptions stored** in `push_subscriptions` table with `UNIQUE(endpoint)` for upsert semantics. - Requires HTTPS (self-signed OK) and outbound internet to reach browser push services. @@ -314,7 +314,7 @@ Main tables: - `contact_name_history` (tracks name changes over time) - `repeater_telemetry_history` (time-series telemetry snapshots for tracked repeaters) - `fanout_configs` (MQTT, bot, webhook, Apprise, SQS integration configs) -- `push_subscriptions` (Web Push browser subscriptions with per-conversation filter preferences; UNIQUE on endpoint) +- `push_subscriptions` (Web Push browser subscriptions with delivery metadata; UNIQUE on endpoint) - `app_settings` (includes `vapid_private_key` and `vapid_public_key` for Web Push VAPID signing) Contact route state is canonicalized on the backend: diff --git a/app/migrations/_057_web_push.py b/app/migrations/_058_web_push.py similarity index 75% rename from app/migrations/_057_web_push.py rename to app/migrations/_058_web_push.py index 72914ac..93ca6bf 100644 --- a/app/migrations/_057_web_push.py +++ b/app/migrations/_058_web_push.py @@ -6,9 +6,9 @@ logger = logging.getLogger(__name__) async def migrate(conn: aiosqlite.Connection) -> None: - """Add VAPID key columns and push_subscriptions table for Web Push.""" + """Add Web Push support: VAPID keys, push subscriptions table, and global conversation list.""" - # VAPID key pair stored in app_settings (one per instance) + # VAPID key pair + global push conversation list in app_settings table_check = await conn.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name='app_settings'" ) @@ -24,8 +24,12 @@ async def migrate(conn: aiosqlite.Connection) -> None: await conn.execute( "ALTER TABLE app_settings ADD COLUMN vapid_public_key TEXT DEFAULT ''" ) + if "push_conversations" not in columns: + await conn.execute( + "ALTER TABLE app_settings ADD COLUMN push_conversations TEXT DEFAULT '[]'" + ) - # Push subscriptions — one row per browser + # Push subscriptions — one row per browser/device await conn.execute( """ CREATE TABLE IF NOT EXISTS push_subscriptions ( @@ -34,8 +38,6 @@ async def migrate(conn: aiosqlite.Connection) -> None: p256dh TEXT NOT NULL, auth TEXT NOT NULL, label TEXT NOT NULL DEFAULT '', - filter_mode TEXT NOT NULL DEFAULT 'all_messages', - filter_conversations TEXT NOT NULL DEFAULT '[]', created_at INTEGER NOT NULL, last_success_at INTEGER, failure_count INTEGER DEFAULT 0, diff --git a/app/push/manager.py b/app/push/manager.py index eb362fe..b9dad29 100644 --- a/app/push/manager.py +++ b/app/push/manager.py @@ -1,22 +1,25 @@ """Web Push dispatch manager. -Handles filtering subscriptions by their preferences and sending push -notifications concurrently when a new message arrives. +Checks the global push-enabled conversation list (stored in app_settings) +and sends push notifications to ALL registered devices when a matching +incoming message arrives. """ import asyncio import json import logging +from dataclasses import dataclass from pywebpush import WebPushException from app.push.send import send_push from app.push.vapid import get_vapid_private_key from app.repository.push_subscriptions import PushSubscriptionRepository +from app.repository.settings import AppSettingsRepository logger = logging.getLogger(__name__) -_SEND_TIMEOUT = 10 # seconds per push send +_SEND_TIMEOUT = 15 # seconds per push send _VAPID_CLAIMS = {"sub": "mailto:noreply@meshcore.local"} @@ -29,19 +32,6 @@ def _state_key_for_message(data: dict) -> str: return f"channel-{conversation_key}" -def _matches_filter(sub: dict, data: dict) -> bool: - """Check whether a message event matches a subscription's filter.""" - mode = sub.get("filter_mode", "all_messages") - if mode == "all_messages": - return True - if mode == "all_dms": - return data.get("type") == "PRIV" - if mode == "selected": - key = _state_key_for_message(data) - return key in (sub.get("filter_conversations") or []) - return False - - def _build_payload(data: dict) -> str: """Build the push notification JSON payload from a message event.""" msg_type = data.get("type", "") @@ -53,11 +43,11 @@ def _build_payload(data: dict) -> str: title = f"Message from {sender_name}" if sender_name else "New direct message" body = text else: - # Channel messages include "SenderName: text" in the text field - title = f"#{channel_name}" if channel_name else "Channel message" + title = channel_name if channel_name else "Channel message" body = text conversation_key = data.get("conversation_key", "") + state_key = _state_key_for_message(data) if msg_type == "PRIV": url_hash = f"#contact/{conversation_key}" else: @@ -67,7 +57,10 @@ def _build_payload(data: dict) -> str: { "title": title, "body": body, - "tag": f"meshcore-{data.get('id', '')}", + # Tag per conversation so different conversations coexist in the + # notification tray, while repeated messages in the same + # conversation replace each other. + "tag": f"meshcore-{state_key}", "url_hash": url_hash, } ) @@ -84,13 +77,31 @@ def _subscription_info(sub: dict) -> dict: } +@dataclass +class _SendResult: + sub_id: str + success: bool = False + expired: bool = False + + class PushManager: async def dispatch_message(self, data: dict) -> None: - """Send push notifications for a message event to matching subscriptions.""" + """Send push notifications for a message event to all devices.""" # Don't notify for messages the operator just sent themselves if data.get("outgoing"): return + # Check the global conversation list + state_key = _state_key_for_message(data) + try: + push_conversations = await AppSettingsRepository.get_push_conversations() + except Exception: + logger.debug("Push dispatch: failed to load push_conversations", exc_info=True) + return + + if state_key not in push_conversations: + return + try: subs = await PushSubscriptionRepository.get_all() except Exception: @@ -100,21 +111,40 @@ class PushManager: if not subs: return - matching = [s for s in subs if _matches_filter(s, data)] - if not matching: - return - payload = _build_payload(data) vapid_key = get_vapid_private_key() if not vapid_key: logger.debug("Push dispatch: no VAPID key configured, skipping") return - tasks = [self._send_one(sub, payload, vapid_key) for sub in matching] - await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather( + *(self._send_one(sub, payload, vapid_key) for sub in subs), + return_exceptions=True, + ) - async def _send_one(self, sub: dict, payload: str, vapid_key: str) -> None: + # Batch-update all delivery outcomes in one transaction. + success_ids: list[str] = [] + failure_ids: list[str] = [] + remove_ids: list[str] = [] + for r in results: + if isinstance(r, _SendResult): + if r.expired: + remove_ids.append(r.sub_id) + elif r.success: + success_ids.append(r.sub_id) + else: + failure_ids.append(r.sub_id) + if success_ids or failure_ids or remove_ids: + try: + await PushSubscriptionRepository.batch_record_outcomes( + success_ids, failure_ids, remove_ids + ) + except Exception: + logger.debug("Push dispatch: failed to record outcomes", exc_info=True) + + async def _send_one(self, sub: dict, payload: str, vapid_key: str) -> _SendResult: sub_id = sub["id"] + result = _SendResult(sub_id=sub_id) try: async with asyncio.timeout(_SEND_TIMEOUT): await send_push( @@ -123,26 +153,20 @@ class PushManager: vapid_private_key=vapid_key, vapid_claims=_VAPID_CLAIMS, ) - await PushSubscriptionRepository.record_success(sub_id) + result.success = True except WebPushException as e: status = getattr(e, "response", None) status_code = getattr(status, "status_code", 0) if status else 0 - if status_code in (404, 410): - logger.info( - "Push subscription expired (HTTP %d), removing %s", - status_code, - sub_id, - ) - await PushSubscriptionRepository.delete(sub_id) + if status_code in (403, 404, 410): + logger.info("Push subscription expired (HTTP %d), removing %s", status_code, sub_id) + result.expired = True else: logger.warning("Push send failed for %s: %s", sub_id, e) - await PushSubscriptionRepository.record_failure(sub_id) except TimeoutError: logger.warning("Push send timed out for %s", sub_id) - await PushSubscriptionRepository.record_failure(sub_id) except Exception: logger.debug("Push send error for %s", sub_id, exc_info=True) - await PushSubscriptionRepository.record_failure(sub_id) + return result push_manager = PushManager() diff --git a/app/push/send.py b/app/push/send.py index 843884e..2af8759 100644 --- a/app/push/send.py +++ b/app/push/send.py @@ -6,11 +6,201 @@ a thread executor to avoid blocking the event loop. import asyncio import logging +import socket +from typing import Any, cast +import requests +import urllib3.connection +import urllib3.connectionpool from pywebpush import webpush +from requests.adapters import HTTPAdapter +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import ConnectTimeout as RequestsConnectTimeout +from urllib3.exceptions import ConnectTimeoutError, NameResolutionError, NewConnectionError logger = logging.getLogger(__name__) +DEFAULT_TIMEOUT = object() +DEFAULT_PUSH_CONNECT_TIMEOUT_SECONDS = 3 +IPV4_FALLBACK_CONNECT_TIMEOUT_SECONDS = 10 +DEFAULT_PUSH_READ_TIMEOUT_SECONDS = 10 + + +def _create_ipv4_connection( + address: tuple[str, int], + timeout: float | None | object = DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + socket_options=None, +) -> socket.socket: + """Create a socket connection using IPv4 only.""" + host, port = address + if host.startswith("["): + host = host.strip("[]") + + err: OSError | None = None + for res in socket.getaddrinfo(host, port, socket.AF_INET, socket.SOCK_STREAM): + af, socktype, proto, _, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + if socket_options: + for opt in socket_options: + sock.setsockopt(*opt) + if timeout is not DEFAULT_TIMEOUT: + sock.settimeout(cast(float | None, timeout)) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + except OSError as exc: + err = exc + if sock is not None: + sock.close() + + if err is not None: + raise err + raise OSError("getaddrinfo returns an empty list") + + +class IPv4HTTPConnection(urllib3.connection.HTTPConnection): + """urllib3 HTTP connection that resolves and connects via IPv4 only.""" + + def _new_conn(self) -> socket.socket: + try: + return _create_ipv4_connection( + (self._dns_host, self.port), + self.timeout, + source_address=self.source_address, + socket_options=self.socket_options, + ) + except socket.gaierror as exc: + raise NameResolutionError(self.host, self, exc) from exc + except TimeoutError as exc: + raise ConnectTimeoutError( + self, + f"Connection to {self.host} timed out. (connect timeout={self.timeout})", + ) from exc + except OSError as exc: + raise NewConnectionError(self, f"Failed to establish a new connection: {exc}") from exc + + +class IPv4HTTPSConnection(urllib3.connection.HTTPSConnection): + """urllib3 HTTPS connection that resolves and connects via IPv4 only.""" + + def _new_conn(self) -> socket.socket: + try: + return _create_ipv4_connection( + (self._dns_host, self.port), + self.timeout, + source_address=self.source_address, + socket_options=self.socket_options, + ) + except socket.gaierror as exc: + raise NameResolutionError(self.host, self, exc) from exc + except TimeoutError as exc: + raise ConnectTimeoutError( + self, + f"Connection to {self.host} timed out. (connect timeout={self.timeout})", + ) from exc + except OSError as exc: + raise NewConnectionError(self, f"Failed to establish a new connection: {exc}") from exc + + +class IPv4HTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): + ConnectionCls = cast(Any, IPv4HTTPConnection) + + +class IPv4HTTPSConnectionPool(urllib3.connectionpool.HTTPSConnectionPool): + ConnectionCls = cast(Any, IPv4HTTPSConnection) + + +def _configure_pool_manager_for_ipv4(manager: Any) -> None: + manager.pool_classes_by_scheme = manager.pool_classes_by_scheme.copy() + manager.pool_classes_by_scheme["http"] = IPv4HTTPConnectionPool + manager.pool_classes_by_scheme["https"] = IPv4HTTPSConnectionPool + + +class IPv4HTTPAdapter(HTTPAdapter): + """requests adapter that uses IPv4-only urllib3 connection pools.""" + + def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs): + super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs) + _configure_pool_manager_for_ipv4(self.poolmanager) + + def proxy_manager_for(self, *args, **kwargs): + manager = super().proxy_manager_for(*args, **kwargs) + _configure_pool_manager_for_ipv4(manager) + return manager + + +def _build_default_requests_session() -> requests.Session: + return requests.Session() + + +def _build_ipv4_requests_session() -> requests.Session: + session = requests.Session() + adapter = IPv4HTTPAdapter() + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + +def _send_push_with_session( + *, + subscription_info: dict, + payload: str, + vapid_private_key: str, + vapid_claims: dict, + session: requests.Session, + connect_timeout_seconds: int, +) -> int: + response = webpush( + subscription_info=subscription_info, + data=payload, + vapid_private_key=vapid_private_key, + vapid_claims=vapid_claims, + content_encoding="aes128gcm", + timeout=cast(Any, (connect_timeout_seconds, DEFAULT_PUSH_READ_TIMEOUT_SECONDS)), + requests_session=session, + ) + return response.status_code # type: ignore[union-attr] + + +def _send_push_with_fallback( + subscription_info: dict, + payload: str, + vapid_private_key: str, + vapid_claims: dict, +) -> int: + """Send using normal dual-stack resolution, then retry with IPv4-only on connect failures.""" + session = _build_default_requests_session() + try: + return _send_push_with_session( + subscription_info=subscription_info, + payload=payload, + vapid_private_key=vapid_private_key, + vapid_claims=vapid_claims, + session=session, + connect_timeout_seconds=DEFAULT_PUSH_CONNECT_TIMEOUT_SECONDS, + ) + except (RequestsConnectTimeout, RequestsConnectionError) as exc: + logger.info("Push delivery retrying via IPv4 after initial network failure: %s", exc) + finally: + session.close() + + session = _build_ipv4_requests_session() + try: + return _send_push_with_session( + subscription_info=subscription_info, + payload=payload, + vapid_private_key=vapid_private_key, + vapid_claims=vapid_claims, + session=session, + connect_timeout_seconds=IPV4_FALLBACK_CONNECT_TIMEOUT_SECONDS, + ) + finally: + session.close() + async def send_push( subscription_info: dict, @@ -23,7 +213,7 @@ async def send_push( Args: subscription_info: {"endpoint": ..., "keys": {"p256dh": ..., "auth": ...}} payload: JSON string to encrypt and send - vapid_private_key: PEM-encoded VAPID private key + vapid_private_key: base64url-encoded raw EC private key scalar vapid_claims: {"sub": "mailto:..."} or {"sub": "https://..."} Returns: @@ -33,13 +223,9 @@ async def send_push( WebPushException: on push service error (caller handles 404/410 cleanup). """ loop = asyncio.get_running_loop() - response = await loop.run_in_executor( + return await loop.run_in_executor( None, - lambda: webpush( - subscription_info=subscription_info, - data=payload, - vapid_private_key=vapid_private_key, - vapid_claims=vapid_claims, + lambda: _send_push_with_fallback( + subscription_info, payload, vapid_private_key, vapid_claims ), ) - return response.status_code # type: ignore[union-attr] diff --git a/app/push/vapid.py b/app/push/vapid.py index 706778f..cf0ef9f 100644 --- a/app/push/vapid.py +++ b/app/push/vapid.py @@ -1,7 +1,8 @@ """VAPID key management for Web Push. -Generates a P-256 key pair on first use and caches it in app_settings. -The public key is served to browsers for PushManager.subscribe(). +Generates a P-256 key pair on first use and caches it in app_settings +via ``AppSettingsRepository``. The public key is served to browsers +for ``PushManager.subscribe()``. """ import base64 @@ -10,7 +11,7 @@ import logging from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from py_vapid import Vapid -from app.database import db +from app.repository.settings import AppSettingsRepository logger = logging.getLogger(__name__) @@ -22,14 +23,10 @@ async def ensure_vapid_keys() -> tuple[str, str]: """Read or generate VAPID keys. Call once at startup after DB connect.""" global _cached_private_key, _cached_public_key - cursor = await db.conn.execute( - "SELECT vapid_private_key, vapid_public_key FROM app_settings WHERE id = 1" - ) - row = await cursor.fetchone() - - if row and row["vapid_private_key"] and row["vapid_public_key"]: - _cached_private_key = row["vapid_private_key"] - _cached_public_key = row["vapid_public_key"] + private, public = await AppSettingsRepository.get_vapid_keys() + if private and public: + _cached_private_key = private + _cached_public_key = public logger.info("VAPID keys loaded from database") return _cached_private_key, _cached_public_key @@ -37,19 +34,17 @@ async def ensure_vapid_keys() -> tuple[str, str]: vapid = Vapid() vapid.generate_keys() - # Private key as PEM for pywebpush - _cached_private_key = vapid.private_pem().decode("utf-8") + # Private key as base64url-encoded raw 32-byte EC scalar — the format + # that pywebpush passes to ``Vapid.from_string()``. + raw_priv = vapid.private_key.private_numbers().private_value.to_bytes(32, "big") # type: ignore[union-attr] + _cached_private_key = base64.urlsafe_b64encode(raw_priv).rstrip(b"=").decode("ascii") # Public key as uncompressed P-256 point, base64url-encoded (no padding) # for the browser Push API's applicationServerKey raw_pub = vapid.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) # type: ignore[union-attr] _cached_public_key = base64.urlsafe_b64encode(raw_pub).rstrip(b"=").decode("ascii") - await db.conn.execute( - "UPDATE app_settings SET vapid_private_key = ?, vapid_public_key = ? WHERE id = 1", - (_cached_private_key, _cached_public_key), - ) - await db.conn.commit() + await AppSettingsRepository.set_vapid_keys(_cached_private_key, _cached_public_key) logger.info("Generated and stored new VAPID key pair") return _cached_private_key, _cached_public_key @@ -61,5 +56,5 @@ def get_vapid_public_key() -> str: def get_vapid_private_key() -> str: - """Return the cached VAPID private key (PEM). Must call ensure_vapid_keys() first.""" + """Return the cached VAPID private key (base64url). Must call ensure_vapid_keys() first.""" return _cached_private_key diff --git a/app/repository/push_subscriptions.py b/app/repository/push_subscriptions.py index 104e178..b8d1fec 100644 --- a/app/repository/push_subscriptions.py +++ b/app/repository/push_subscriptions.py @@ -1,6 +1,5 @@ """Repository for push_subscriptions table.""" -import json import logging import time import uuid @@ -10,23 +9,22 @@ from app.database import db logger = logging.getLogger(__name__) +# Auto-delete subscriptions that have failed this many times consecutively +# without any successful delivery in between. +MAX_CONSECUTIVE_FAILURES = 15 + def _row_to_dict(row: Any) -> dict[str, Any]: - result = { + return { "id": row["id"], "endpoint": row["endpoint"], "p256dh": row["p256dh"], "auth": row["auth"], "label": row["label"] or "", - "filter_mode": row["filter_mode"] or "all_messages", - "filter_conversations": json.loads(row["filter_conversations"]) - if row["filter_conversations"] - else [], "created_at": row["created_at"] or 0, "last_success_at": row["last_success_at"], "failure_count": row["failure_count"] or 0, } - return result class PushSubscriptionRepository: @@ -36,54 +34,58 @@ class PushSubscriptionRepository: p256dh: str, auth: str, label: str = "", - filter_mode: str = "all_messages", - filter_conversations: list[str] | None = None, ) -> dict[str, Any]: """Create or upsert a push subscription (keyed by endpoint).""" sub_id = str(uuid.uuid4()) now = int(time.time()) - convos_json = json.dumps(filter_conversations or []) - # Upsert: if endpoint already exists, update keys/label but keep the ID - await db.conn.execute( - """ - INSERT INTO push_subscriptions - (id, endpoint, p256dh, auth, label, filter_mode, - filter_conversations, created_at, failure_count) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0) - ON CONFLICT(endpoint) DO UPDATE SET - p256dh = excluded.p256dh, - auth = excluded.auth, - label = CASE WHEN excluded.label != '' THEN excluded.label ELSE push_subscriptions.label END, - failure_count = 0 - """, - (sub_id, endpoint, p256dh, auth, label, filter_mode, convos_json, now), - ) - await db.conn.commit() + async with db.tx() as conn: + await conn.execute( + """ + INSERT INTO push_subscriptions + (id, endpoint, p256dh, auth, label, created_at, failure_count) + VALUES (?, ?, ?, ?, ?, ?, 0) + ON CONFLICT(endpoint) DO UPDATE SET + p256dh = excluded.p256dh, + auth = excluded.auth, + label = CASE WHEN excluded.label != '' THEN excluded.label + ELSE push_subscriptions.label END, + failure_count = 0 + """, + (sub_id, endpoint, p256dh, auth, label, now), + ) + async with conn.execute( + "SELECT * FROM push_subscriptions WHERE endpoint = ?", (endpoint,) + ) as cursor: + row = await cursor.fetchone() - # Return the actual row (may be existing on upsert) - return await PushSubscriptionRepository.get_by_endpoint(endpoint) # type: ignore[return-value] + return _row_to_dict(row) if row else {"id": sub_id} # type: ignore[arg-type] @staticmethod async def get(subscription_id: str) -> dict[str, Any] | None: - cursor = await db.conn.execute( - "SELECT * FROM push_subscriptions WHERE id = ?", (subscription_id,) - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM push_subscriptions WHERE id = ?", (subscription_id,) + ) as cursor: + row = await cursor.fetchone() return _row_to_dict(row) if row else None @staticmethod async def get_by_endpoint(endpoint: str) -> dict[str, Any] | None: - cursor = await db.conn.execute( - "SELECT * FROM push_subscriptions WHERE endpoint = ?", (endpoint,) - ) - row = await cursor.fetchone() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM push_subscriptions WHERE endpoint = ?", (endpoint,) + ) as cursor: + row = await cursor.fetchone() return _row_to_dict(row) if row else None @staticmethod async def get_all() -> list[dict[str, Any]]: - cursor = await db.conn.execute("SELECT * FROM push_subscriptions ORDER BY created_at DESC") - rows = await cursor.fetchall() + async with db.readonly() as conn: + async with conn.execute( + "SELECT * FROM push_subscriptions ORDER BY created_at DESC" + ) as cursor: + rows = await cursor.fetchall() return [_row_to_dict(row) for row in rows] @staticmethod @@ -91,55 +93,70 @@ class PushSubscriptionRepository: updates: list[str] = [] params: list[Any] = [] - for key in ("label", "filter_mode"): - if key in fields: - updates.append(f"{key} = ?") - params.append(fields[key]) - - if "filter_conversations" in fields: - updates.append("filter_conversations = ?") - params.append(json.dumps(fields["filter_conversations"])) + if "label" in fields: + updates.append("label = ?") + params.append(fields["label"]) if not updates: return await PushSubscriptionRepository.get(subscription_id) params.append(subscription_id) - await db.conn.execute( - f"UPDATE push_subscriptions SET {', '.join(updates)} WHERE id = ?", - params, - ) - await db.conn.commit() - return await PushSubscriptionRepository.get(subscription_id) + async with db.tx() as conn: + await conn.execute( + f"UPDATE push_subscriptions SET {', '.join(updates)} WHERE id = ?", + params, + ) + async with conn.execute( + "SELECT * FROM push_subscriptions WHERE id = ?", (subscription_id,) + ) as cursor: + row = await cursor.fetchone() + return _row_to_dict(row) if row else None @staticmethod async def delete(subscription_id: str) -> bool: - cursor = await db.conn.execute( - "DELETE FROM push_subscriptions WHERE id = ?", (subscription_id,) - ) - await db.conn.commit() - return cursor.rowcount > 0 + async with db.tx() as conn: + async with conn.execute( + "DELETE FROM push_subscriptions WHERE id = ?", (subscription_id,) + ) as cursor: + return cursor.rowcount > 0 @staticmethod async def delete_by_endpoint(endpoint: str) -> bool: - cursor = await db.conn.execute( - "DELETE FROM push_subscriptions WHERE endpoint = ?", (endpoint,) - ) - await db.conn.commit() - return cursor.rowcount > 0 + async with db.tx() as conn: + async with conn.execute( + "DELETE FROM push_subscriptions WHERE endpoint = ?", (endpoint,) + ) as cursor: + return cursor.rowcount > 0 @staticmethod - async def record_success(subscription_id: str) -> None: + async def batch_record_outcomes( + success_ids: list[str], failure_ids: list[str], remove_ids: list[str] + ) -> None: + """Batch-update delivery outcomes in a single transaction.""" now = int(time.time()) - await db.conn.execute( - "UPDATE push_subscriptions SET last_success_at = ?, failure_count = 0 WHERE id = ?", - (now, subscription_id), - ) - await db.conn.commit() - - @staticmethod - async def record_failure(subscription_id: str) -> None: - await db.conn.execute( - "UPDATE push_subscriptions SET failure_count = failure_count + 1 WHERE id = ?", - (subscription_id,), - ) - await db.conn.commit() + async with db.tx() as conn: + if remove_ids: + placeholders = ",".join("?" for _ in remove_ids) + await conn.execute( + f"DELETE FROM push_subscriptions WHERE id IN ({placeholders})", + remove_ids, + ) + if success_ids: + placeholders = ",".join("?" for _ in success_ids) + await conn.execute( + f"UPDATE push_subscriptions SET last_success_at = ?, failure_count = 0 " + f"WHERE id IN ({placeholders})", + [now, *success_ids], + ) + if failure_ids: + placeholders = ",".join("?" for _ in failure_ids) + await conn.execute( + f"UPDATE push_subscriptions SET failure_count = failure_count + 1 " + f"WHERE id IN ({placeholders})", + failure_ids, + ) + # Evict subscriptions that have exceeded the failure threshold + await conn.execute( + "DELETE FROM push_subscriptions WHERE failure_count >= ?", + (MAX_CONSECUTIVE_FAILURES,), + ) diff --git a/app/repository/settings.py b/app/repository/settings.py index 38bd087..7405eaf 100644 --- a/app/repository/settings.py +++ b/app/repository/settings.py @@ -282,6 +282,85 @@ class AppSettingsRepository: await AppSettingsRepository._apply_updates(conn, blocked_names=new_names) return await AppSettingsRepository._get_in_conn(conn) + @staticmethod + async def get_vapid_keys() -> tuple[str, str]: + """Return (private_key_pem, public_key_b64url) from app_settings. + + These are internal-only columns not exposed via the AppSettings model. + """ + async with db.readonly() as conn: + async with conn.execute( + "SELECT vapid_private_key, vapid_public_key FROM app_settings WHERE id = 1" + ) as cursor: + row = await cursor.fetchone() + if row and row["vapid_private_key"] and row["vapid_public_key"]: + return row["vapid_private_key"], row["vapid_public_key"] + return "", "" + + @staticmethod + async def set_vapid_keys(private_key: str, public_key: str) -> None: + """Persist auto-generated VAPID key pair to app_settings.""" + async with db.tx() as conn: + await conn.execute( + "UPDATE app_settings SET vapid_private_key = ?, vapid_public_key = ? WHERE id = 1", + (private_key, public_key), + ) + + @staticmethod + async def get_push_conversations() -> list[str]: + """Return the global list of push-enabled conversation state keys. + + Internal-only column, not exposed via the AppSettings model. + """ + async with db.readonly() as conn: + async with conn.execute( + "SELECT push_conversations FROM app_settings WHERE id = 1" + ) as cursor: + row = await cursor.fetchone() + if row and row["push_conversations"]: + try: + return json.loads(row["push_conversations"]) + except (json.JSONDecodeError, TypeError): + return [] + return [] + + @staticmethod + async def set_push_conversations(conversations: list[str]) -> list[str]: + """Replace the global push-enabled conversation list.""" + async with db.tx() as conn: + await conn.execute( + "UPDATE app_settings SET push_conversations = ? WHERE id = 1", + (json.dumps(conversations),), + ) + return conversations + + @staticmethod + async def toggle_push_conversation(key: str) -> list[str]: + """Add or remove a conversation state key from the global push list. + + Atomic read-modify-write under a single ``db.tx()`` lock. + """ + async with db.tx() as conn: + async with conn.execute( + "SELECT push_conversations FROM app_settings WHERE id = 1" + ) as cursor: + row = await cursor.fetchone() + current: list[str] = [] + if row and row["push_conversations"]: + try: + current = json.loads(row["push_conversations"]) + except (json.JSONDecodeError, TypeError): + current = [] + if key in current: + current = [k for k in current if k != key] + else: + current.append(key) + await conn.execute( + "UPDATE app_settings SET push_conversations = ? WHERE id = 1", + (json.dumps(current),), + ) + return current + class StatisticsRepository: @staticmethod diff --git a/app/routers/push.py b/app/routers/push.py index 8e8299b..942976c 100644 --- a/app/routers/push.py +++ b/app/routers/push.py @@ -1,13 +1,17 @@ """Web Push subscription management endpoints.""" +import asyncio +import json import logging from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field +from pywebpush import WebPushException from app.push.send import send_push from app.push.vapid import get_vapid_private_key, get_vapid_public_key from app.repository.push_subscriptions import PushSubscriptionRepository +from app.repository.settings import AppSettingsRepository logger = logging.getLogger(__name__) @@ -30,11 +34,13 @@ class PushSubscribeRequest(BaseModel): class PushSubscriptionUpdate(BaseModel): label: str | None = None - filter_mode: str | None = None - filter_conversations: list[str] | None = None -# ── Endpoints ──────────────────────────────────────────────────────────── +class PushConversationToggle(BaseModel): + key: str = Field(min_length=1) + + +# ─��� Endpoints ──────────────────────────────────────────────────────────── @router.get("/vapid-public-key", response_model=VapidPublicKeyResponse) @@ -48,7 +54,7 @@ async def vapid_public_key() -> VapidPublicKeyResponse: @router.post("/subscribe") async def subscribe(body: PushSubscribeRequest) -> dict: - """Register or update a push subscription. Upserts by endpoint.""" + """Register or update a push subscription (device). Upserts by endpoint.""" sub = await PushSubscriptionRepository.create( endpoint=body.endpoint, p256dh=body.p256dh, @@ -60,13 +66,13 @@ async def subscribe(body: PushSubscribeRequest) -> dict: @router.get("/subscriptions") async def list_subscriptions() -> list[dict]: - """List all push subscriptions.""" + """List all push subscriptions (devices).""" return await PushSubscriptionRepository.get_all() @router.patch("/subscriptions/{subscription_id}") async def update_subscription(subscription_id: str, body: PushSubscriptionUpdate) -> dict: - """Update a subscription's label or filter preferences.""" + """Update a subscription's label.""" existing = await PushSubscriptionRepository.get(subscription_id) if not existing: raise HTTPException(status_code=404, detail="Subscription not found") @@ -74,12 +80,6 @@ async def update_subscription(subscription_id: str, body: PushSubscriptionUpdate updates = {} if body.label is not None: updates["label"] = body.label - if body.filter_mode is not None: - if body.filter_mode not in ("all_messages", "all_dms", "selected"): - raise HTTPException(status_code=400, detail="Invalid filter_mode") - updates["filter_mode"] = body.filter_mode - if body.filter_conversations is not None: - updates["filter_conversations"] = body.filter_conversations result = await PushSubscriptionRepository.update(subscription_id, **updates) return result or existing @@ -87,7 +87,7 @@ async def update_subscription(subscription_id: str, body: PushSubscriptionUpdate @router.delete("/subscriptions/{subscription_id}") async def unsubscribe(subscription_id: str) -> dict: - """Delete a push subscription.""" + """Delete a push subscription (device).""" deleted = await PushSubscriptionRepository.delete(subscription_id) if not deleted: raise HTTPException(status_code=404, detail="Subscription not found") @@ -105,8 +105,6 @@ async def test_push(subscription_id: str) -> dict: if not vapid_key: raise HTTPException(status_code=503, detail="VAPID keys not initialized") - import json - payload = json.dumps( { "title": "RemoteTerm Test", @@ -117,16 +115,50 @@ async def test_push(subscription_id: str) -> dict: ) try: - await send_push( - subscription_info={ - "endpoint": sub["endpoint"], - "keys": {"p256dh": sub["p256dh"], "auth": sub["auth"]}, - }, - payload=payload, - vapid_private_key=vapid_key, - vapid_claims={"sub": "mailto:noreply@meshcore.local"}, - ) + async with asyncio.timeout(15): + await send_push( + subscription_info={ + "endpoint": sub["endpoint"], + "keys": {"p256dh": sub["p256dh"], "auth": sub["auth"]}, + }, + payload=payload, + vapid_private_key=vapid_key, + vapid_claims={"sub": "mailto:noreply@meshcore.local"}, + ) return {"status": "sent"} + except TimeoutError: + raise HTTPException(status_code=504, detail="Push delivery timed out") from None + except WebPushException as e: + status_code = getattr(getattr(e, "response", None), "status_code", 0) + if status_code in (403, 404, 410): + logger.info( + "Test push: subscription stale (HTTP %d), removing %s", + status_code, + subscription_id, + ) + await PushSubscriptionRepository.delete(subscription_id) + raise HTTPException( + status_code=410, + detail="Subscription is stale (VAPID key mismatch or expired). " + "Re-enable push from a conversation header.", + ) from None + logger.warning("Test push failed: %s", e) + raise HTTPException(status_code=502, detail=f"Push delivery failed: {e}") from None except Exception as e: logger.warning("Test push failed: %s", e) raise HTTPException(status_code=502, detail=f"Push delivery failed: {e}") from None + + +# ── Global push conversation management ────────────────────────────────── + + +@router.get("/conversations") +async def get_push_conversations() -> list[str]: + """Return the global list of push-enabled conversation state keys.""" + return await AppSettingsRepository.get_push_conversations() + + +@router.post("/conversations/toggle") +async def toggle_push_conversation(body: PushConversationToggle) -> list[str]: + """Add or remove a conversation from the global push list.""" + return await AppSettingsRepository.toggle_push_conversation(body.key) diff --git a/frontend/AGENTS.md b/frontend/AGENTS.md index 084f51d..35d8ffd 100644 --- a/frontend/AGENTS.md +++ b/frontend/AGENTS.md @@ -435,7 +435,7 @@ The `SearchView` component (`components/SearchView.tsx`) provides full-text sear Web Push allows notifications even when the browser tab is closed. Requires HTTPS (self-signed OK). - **Service worker**: `frontend/public/sw.js` handles `push` events (show notification) and `notificationclick` (focus/open tab, navigate via `url_hash`). Registered in `main.tsx` on secure contexts only. -- **`usePushSubscription` hook**: manages the full subscription lifecycle — subscribe (register SW → `PushManager.subscribe()` → POST to backend), unsubscribe, per-conversation filter management (`addConversation`/`removeConversation`), device listing and deletion. +- **`usePushSubscription` hook**: manages the full subscription lifecycle — subscribe (register SW → `PushManager.subscribe()` → POST to backend), unsubscribe, global push-conversation toggles, device listing, and deletion. - **ChatHeader integration**: `BellRing` icon (amber when active) appears next to the existing desktop notification `Bell` on secure contexts. First click subscribes the browser and enables push for that conversation; subsequent clicks toggle the conversation on/off. - **Settings > Local**: `PushDeviceManagement` component shows subscription status, lists all registered devices with test/delete buttons. Uses `usePushSubscription` hook directly. - Auto-generates device labels from User-Agent (e.g., "Chrome on macOS"). diff --git a/frontend/index.html b/frontend/index.html index 2a1470f..027c16e 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -15,10 +15,8 @@