diff --git a/repeater/web/companion_endpoints.py b/repeater/web/companion_endpoints.py index 0d3c4a8..c99f52a 100644 --- a/repeater/web/companion_endpoints.py +++ b/repeater/web/companion_endpoints.py @@ -17,7 +17,10 @@ from typing import Optional import cherrypy from pymc_core.companion.constants import DEFAULT_OFFLINE_QUEUE_SIZE -from repeater.companion.utils import validate_companion_node_name +from repeater.companion.utils import ( + trim_companion_contacts_to_fit, + validate_companion_node_name, +) from .auth.middleware import require_auth @@ -395,8 +398,11 @@ class CompanionAPIEndpoints: if limit < 1: raise cherrypy.HTTPError(400, "limit must be a positive integer") bridge = self._get_bridge(**self._resolve_bridge_params(body)) + # max_contacts lives on the ContactStore, not the bridge itself; reading it + # from bridge.contacts avoids silently falling back to the 1000 default for + # companions configured with a higher limit. + max_contacts = bridge.contacts.max_contacts if limit is not None: - max_contacts = getattr(bridge, "max_contacts", 1000) limit = min(limit, max_contacts) companion_hash = getattr(bridge, "_companion_hash", None) if not companion_hash: @@ -408,6 +414,16 @@ class CompanionAPIEndpoints: hours=hours, limit=limit, ) + # The bulk import writes directly to SQLite, bypassing the ContactStore cap + # that every other path honors. Trim favourite-aware (oldest non-favourites + # first) so persisted contacts never exceed max_contacts. + try: + removed = trim_companion_contacts_to_fit(sqlite_handler, companion_hash, max_contacts) + except ValueError as exc: + raise cherrypy.HTTPError( + 409, + f"Cannot trim imported contacts to fit max_contacts={max_contacts}: {exc}", + ) contact_rows = sqlite_handler.companion_load_contacts(companion_hash) if contact_rows: records = [] @@ -416,7 +432,7 @@ class CompanionAPIEndpoints: d["public_key"] = d.pop("pubkey", d.get("public_key", b"")) records.append(d) bridge.contacts.load_from_dicts(records) - return self._success({"imported": count}) + return self._success({"imported": count, "removed": removed}) # ----- Channels ----- diff --git a/tests/test_companion_settings.py b/tests/test_companion_settings.py index 92ccdd8..ebca62e 100644 --- a/tests/test_companion_settings.py +++ b/tests/test_companion_settings.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from types import SimpleNamespace from unittest.mock import MagicMock import pytest @@ -268,3 +269,201 @@ class TestPersistSkipWhenOff: fs = self._frame_server(7) asyncio.run(fs._persist_companion_message({"text": "x"})) fs.sqlite_handler.companion_push_message.assert_called_once_with("0x01", {"text": "x"}, 7) + + +class TestImportRepeaterContactsCap: + """The import endpoint must never leave persisted contacts above max_contacts. + + The bulk import writes straight to SQLite, bypassing the ContactStore cap, so the + endpoint trims favourite-aware to fit after the insert. + """ + + _HASH = "0x01" + + @staticmethod + def _handler(tmp_path): + from repeater.data_acquisition.sqlite_handler import SQLiteHandler + + return SQLiteHandler(tmp_path) + + @staticmethod + def _seed_adverts(h, n, start_ts=10_000): + """Seed ``n`` repeater adverts with increasing last_seen (newest = highest i).""" + for i in range(n): + h.store_advert( + { + "timestamp": float(start_ts + i), + "pubkey": f"{i:064x}", + "node_name": f"adv-{i}", + "is_repeater": True, + "route_type": 1, + "contact_type": "repeater", + "latitude": 0.0, + "longitude": 0.0, + } + ) + + @classmethod + def _save_contacts(cls, h, contacts): + assert h.companion_save_contacts(cls._HASH, contacts) + + @staticmethod + def _contact(pk_int, *, flags=0, lastmod=0): + # Pre-existing contacts use a pubkey range disjoint from seeded adverts. + return { + "pubkey": (1_000_000 + pk_int).to_bytes(8, "big"), + "name": f"pre-{pk_int}", + "adv_type": 2, + "flags": flags, + "lastmod": lastmod, + "last_advert_timestamp": lastmod, + } + + @classmethod + def _endpoint(cls, handler, bridge, body): + from repeater.web.companion_endpoints import CompanionAPIEndpoints + + ep = CompanionAPIEndpoints.__new__(CompanionAPIEndpoints) + ep._require_post = lambda: None + ep._get_json_body = lambda: body + ep._resolve_bridge_params = lambda b: {} + ep._get_bridge = lambda **kw: bridge + ep._get_sqlite_handler = lambda: handler + return ep + + @staticmethod + def _invoke(ep): + """Call the endpoint past the @require_auth wrapper (no auth context in tests).""" + from repeater.web.companion_endpoints import CompanionAPIEndpoints + + return CompanionAPIEndpoints.import_repeater_contacts.__wrapped__(ep) + + @classmethod + def _bridge(cls, max_contacts): + contacts = SimpleNamespace(max_contacts=max_contacts, loaded=None) + contacts.load_from_dicts = lambda records: setattr(contacts, "loaded", list(records)) + return SimpleNamespace(_companion_hash=cls._HASH, contacts=contacts) + + def test_import_over_cap_trims_to_fit(self, tmp_path): + h = self._handler(tmp_path) + self._seed_adverts(h, 60) + bridge = self._bridge(max_contacts=50) + ep = self._endpoint(h, bridge, {"companion_name": "c"}) + + resp = self._invoke(ep) + + assert h.companion_count_contacts(self._HASH) == 50 + assert resp["data"] == {"imported": 60, "removed": 10} + assert len(bridge.contacts.loaded) == 50 + + def test_pre_existing_plus_import_accumulation(self, tmp_path): + h = self._handler(tmp_path) + # 40 old pre-existing contacts (lastmod 0..39). + self._save_contacts(h, [self._contact(i, lastmod=i) for i in range(40)]) + # 30 newer imported adverts (last_seen >= 10_000). + self._seed_adverts(h, 30) + bridge = self._bridge(max_contacts=50) + ep = self._endpoint(h, bridge, {"companion_name": "c"}) + + resp = self._invoke(ep) + + assert h.companion_count_contacts(self._HASH) == 50 + assert resp["data"]["imported"] == 30 + # All 30 newer imports survive; oldest pre-existing are evicted. + kept = {row["pubkey"] for row in h.companion_load_contacts(self._HASH)} + for i in range(30): + assert bytes.fromhex(f"{i:064x}") in kept + + def test_favourites_protected(self, tmp_path): + h = self._handler(tmp_path) + # 5 favourites that are also the oldest (lastmod 0..4). + favourites = [self._contact(i, flags=1, lastmod=i) for i in range(5)] + self._save_contacts(h, favourites) + self._seed_adverts(h, 60) + bridge = self._bridge(max_contacts=50) + ep = self._endpoint(h, bridge, {"companion_name": "c"}) + + self._invoke(ep) + + assert h.companion_count_contacts(self._HASH) == 50 + kept = {row["pubkey"] for row in h.companion_load_contacts(self._HASH)} + for fav in favourites: + assert fav["pubkey"] in kept + + def test_favourites_exceed_cap_returns_409(self, tmp_path): + import cherrypy + + h = self._handler(tmp_path) + self._save_contacts(h, [self._contact(i, flags=1, lastmod=i) for i in range(51)]) + self._seed_adverts(h, 1) + bridge = self._bridge(max_contacts=50) + ep = self._endpoint(h, bridge, {"companion_name": "c"}) + + with pytest.raises(cherrypy.HTTPError) as exc_info: + self._invoke(ep) + assert exc_info.value.code == 409 + + def test_cap_source_is_contacts_not_default(self, tmp_path): + # A companion configured above the 1000 default must not be silently clamped. + h = self._handler(tmp_path) + captured = {} + real_import = h.companion_import_repeater_contacts + + def _spy(companion_hash, **kwargs): + captured["limit"] = kwargs.get("limit") + return real_import(companion_hash, **kwargs) + + h.companion_import_repeater_contacts = _spy + bridge = self._bridge(max_contacts=1200) + ep = self._endpoint(h, bridge, {"companion_name": "c", "limit": 1100}) + + self._invoke(ep) + + # min(limit=1100, max_contacts=1200) -> 1100, proving the cap came from + # bridge.contacts.max_contacts (1200), not the old 1000 fallback. + assert captured["limit"] == 1100 + + def test_under_cap_import_is_noop_trim(self, tmp_path): + # Happy path: an import that fits leaves everything and trims nothing. + h = self._handler(tmp_path) + self._seed_adverts(h, 10) + bridge = self._bridge(max_contacts=50) + ep = self._endpoint(h, bridge, {"companion_name": "c"}) + + resp = self._invoke(ep) + + assert h.companion_count_contacts(self._HASH) == 10 + assert resp["data"] == {"imported": 10, "removed": 0} + assert len(bridge.contacts.loaded) == 10 + + def test_incident_scale_default_cap(self, tmp_path): + # Reproduces the reported incident: an oversized import at the real 1000 + # default must end at exactly the cap, not 1062. + h = self._handler(tmp_path) + self._seed_adverts(h, 1062) + bridge = self._bridge(max_contacts=_DEFAULT_MAX_CONTACTS) + ep = self._endpoint(h, bridge, {"companion_name": "c"}) + + resp = self._invoke(ep) + + assert h.companion_count_contacts(self._HASH) == _DEFAULT_MAX_CONTACTS + assert resp["data"] == {"imported": 1062, "removed": 62} + assert len(bridge.contacts.loaded) == _DEFAULT_MAX_CONTACTS + + def test_repeated_import_stays_within_cap(self, tmp_path): + # Repeated imports (a plausible cause of the original overflow) must never + # accumulate past the cap. + h = self._handler(tmp_path) + self._seed_adverts(h, 60) + bridge = self._bridge(max_contacts=50) + ep = self._endpoint(h, bridge, {"companion_name": "c"}) + + first = self._invoke(ep) + assert h.companion_count_contacts(self._HASH) == 50 + assert first["data"]["removed"] == 10 + + # Second call re-imports the same adverts (the 10 trimmed are still in the + # adverts table) and must trim back to the cap again, not climb to 60. + second = self._invoke(ep) + assert h.companion_count_contacts(self._HASH) == 50 + assert second["data"]["removed"] == 10