From b1e3e71b68e776a987916170b30c15af2eb28446 Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Mon, 9 Mar 2026 17:03:07 -0700 Subject: [PATCH] extract dm ack tracker service --- app/event_handlers.py | 31 +++++++----------------- app/services/dm_ack_tracker.py | 43 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 23 deletions(-) create mode 100644 app/services/dm_ack_tracker.py diff --git a/app/event_handlers.py b/app/event_handlers.py index 0e1db4b..526e8fb 100644 --- a/app/event_handlers.py +++ b/app/event_handlers.py @@ -12,6 +12,7 @@ from app.repository import ( ContactRepository, MessageRepository, ) +from app.services import dm_ack_tracker from app.services.messages import create_fallback_direct_message, increment_ack_and_broadcast from app.websocket import broadcast_event @@ -23,33 +24,17 @@ logger = logging.getLogger(__name__) # Track active subscriptions so we can unsubscribe before re-registering # This prevents handler duplication after reconnects _active_subscriptions: list["Subscription"] = [] - - -# Track pending ACKs: expected_ack_code -> (message_id, timestamp, timeout_ms) -_pending_acks: dict[str, tuple[int, float, int]] = {} +_pending_acks = dm_ack_tracker._pending_acks def track_pending_ack(expected_ack: str, message_id: int, timeout_ms: int) -> None: - """Track a pending ACK for a direct message.""" - _pending_acks[expected_ack] = (message_id, time.time(), timeout_ms) - logger.debug( - "Tracking pending ACK %s for message %d (timeout %dms)", - expected_ack, - message_id, - timeout_ms, - ) + """Compatibility wrapper for pending DM ACK tracking.""" + dm_ack_tracker.track_pending_ack(expected_ack, message_id, timeout_ms) def cleanup_expired_acks() -> None: - """Remove expired pending ACKs.""" - now = time.time() - expired = [] - for code, (_msg_id, created_at, timeout_ms) in _pending_acks.items(): - if now - created_at > (timeout_ms / 1000) * 2: # 2x timeout as buffer - expired.append(code) - for code in expired: - del _pending_acks[code] - logger.debug("Expired pending ACK %s", code) + """Compatibility wrapper for expiring stale DM ACK entries.""" + dm_ack_tracker.cleanup_expired_acks() async def on_contact_message(event: "Event") -> None: @@ -280,8 +265,8 @@ async def on_ack(event: "Event") -> None: cleanup_expired_acks() - if ack_code in _pending_acks: - message_id, _, _ = _pending_acks.pop(ack_code) + message_id = dm_ack_tracker.pop_pending_ack(ack_code) + if message_id is not None: logger.info("ACK received for message %d", message_id) # DM ACKs don't carry path data, so paths is intentionally omitted. # The frontend's mergePendingAck handles the missing field correctly, diff --git a/app/services/dm_ack_tracker.py b/app/services/dm_ack_tracker.py new file mode 100644 index 0000000..b882073 --- /dev/null +++ b/app/services/dm_ack_tracker.py @@ -0,0 +1,43 @@ +"""Shared pending ACK tracking for outgoing direct messages.""" + +import logging +import time + +logger = logging.getLogger(__name__) + +PendingAck = tuple[int, float, int] + +_pending_acks: dict[str, PendingAck] = {} + + +def track_pending_ack(expected_ack: str, message_id: int, timeout_ms: int) -> None: + """Track an expected ACK code for an outgoing direct message.""" + _pending_acks[expected_ack] = (message_id, time.time(), timeout_ms) + logger.debug( + "Tracking pending ACK %s for message %d (timeout %dms)", + expected_ack, + message_id, + timeout_ms, + ) + + +def cleanup_expired_acks() -> None: + """Remove stale pending ACK entries.""" + now = time.time() + expired_codes = [ + code + for code, (_message_id, created_at, timeout_ms) in _pending_acks.items() + if now - created_at > (timeout_ms / 1000) * 2 + ] + for code in expired_codes: + del _pending_acks[code] + logger.debug("Expired pending ACK %s", code) + + +def pop_pending_ack(ack_code: str) -> int | None: + """Claim the tracked message ID for an ACK code if present.""" + pending = _pending_acks.pop(ack_code, None) + if pending is None: + return None + message_id, _, _ = pending + return message_id