From a13e241636c7cde7411cb53fd7118eceb957e3f9 Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Wed, 11 Mar 2026 22:02:05 -0700 Subject: [PATCH] Other connected clients get new chans over WS --- app/routers/channels.py | 20 ++++++++++++------ app/routers/contacts.py | 3 +++ tests/test_channels_router.py | 40 +++++++++++++++++++++++++++++++++++ tests/test_contacts_router.py | 21 +++++++++++++----- 4 files changed, 72 insertions(+), 12 deletions(-) diff --git a/app/routers/channels.py b/app/routers/channels.py index 807318a..7403137 100644 --- a/app/routers/channels.py +++ b/app/routers/channels.py @@ -17,6 +17,10 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/channels", tags=["channels"]) +def _broadcast_channel_update(channel: Channel) -> None: + broadcast_event("channel", channel.model_dump()) + + class CreateChannelRequest(BaseModel): name: str = Field(min_length=1, max_length=32) key: str | None = Field( @@ -98,13 +102,12 @@ async def create_channel(request: CreateChannelRequest) -> Channel: on_radio=False, ) - return Channel( - key=key_hex, - name=request.name, - is_hashtag=is_hashtag, - on_radio=False, - flood_scope_override=None, - ) + stored = await ChannelRepository.get_by_key(key_hex) + if stored is None: + raise HTTPException(status_code=500, detail="Channel was created but could not be reloaded") + + _broadcast_channel_update(stored) + return stored @router.post("/sync") @@ -123,6 +126,9 @@ async def sync_channels_from_radio(max_channels: int = Query(default=40, ge=1, l key_hex = await upsert_channel_from_radio_slot(result.payload, on_radio=True) if key_hex is not None: count += 1 + stored = await ChannelRepository.get_by_key(key_hex) + if stored is not None: + _broadcast_channel_update(stored) logger.debug( "Synced channel %s: %s", key_hex, result.payload.get("channel_name") ) diff --git a/app/routers/contacts.py b/app/routers/contacts.py index c75b235..0569fbb 100644 --- a/app/routers/contacts.py +++ b/app/routers/contacts.py @@ -289,6 +289,7 @@ async def create_contact( background_tasks, request.public_key, request.name or existing.name ) + await _broadcast_contact_update(existing) return existing # Create new contact @@ -319,6 +320,7 @@ async def create_contact( stored = await ContactRepository.get_by_key(lower_key) if stored is None: raise HTTPException(status_code=500, detail="Contact was created but could not be reloaded") + await _broadcast_contact_update(stored) await _broadcast_contact_resolution(promoted_keys, stored) return stored @@ -411,6 +413,7 @@ async def sync_contacts_from_radio() -> dict: ) stored = await ContactRepository.get_by_key(lower_key) if stored is not None: + await _broadcast_contact_update(stored) await _broadcast_contact_resolution(promoted_keys, stored) count += 1 diff --git a/tests/test_channels_router.py b/tests/test_channels_router.py index 382985c..407594a 100644 --- a/tests/test_channels_router.py +++ b/tests/test_channels_router.py @@ -112,6 +112,33 @@ class TestSyncChannelsFromRadio: assert secret_a.hex().upper() in keys assert secret_b.hex().upper() in keys + @pytest.mark.asyncio + async def test_sync_broadcasts_channel_updates(self, test_db, client): + secret = bytes.fromhex("0123456789abcdef0123456789abcdef") + mock_mc = MagicMock() + + async def mock_get_channel(idx): + if idx == 0: + return _make_channel_info("#general", secret) + return _make_empty_channel() + + mock_mc.commands.get_channel = AsyncMock(side_effect=mock_get_channel) + radio_manager._meshcore = mock_mc + + with ( + _patch_require_connected(mock_mc), + patch("app.routers.channels.radio_manager") as mock_ch_rm, + patch("app.routers.channels.broadcast_event") as mock_broadcast, + ): + mock_ch_rm.radio_operation = lambda desc: _noop_radio_operation(mock_mc) + + response = await client.post("/api/channels/sync?max_channels=3") + + assert response.status_code == 200 + mock_broadcast.assert_called_once() + assert mock_broadcast.call_args.args[0] == "channel" + assert mock_broadcast.call_args.args[1]["key"] == secret.hex().upper() + @pytest.mark.asyncio async def test_sync_skips_empty_channels(self, test_db, client): """Empty channel slots are skipped during sync.""" @@ -278,6 +305,19 @@ class TestChannelFloodScopeOverride: mock_broadcast.assert_called_once() assert mock_broadcast.call_args.args[0] == "channel" + +class TestCreateChannel: + @pytest.mark.asyncio + async def test_create_broadcasts_channel_update(self, test_db): + from app.routers.channels import CreateChannelRequest, create_channel + + with patch("app.routers.channels.broadcast_event") as mock_broadcast: + result = await create_channel(CreateChannelRequest(name="#mychannel")) + + mock_broadcast.assert_called_once() + assert mock_broadcast.call_args.args[0] == "channel" + assert mock_broadcast.call_args.args[1]["key"] == result.key + @pytest.mark.asyncio async def test_existing_hash_is_not_doubled(self, test_db, client): key = "CC" * 16 diff --git a/tests/test_contacts_router.py b/tests/test_contacts_router.py index 5e8b5b1..4206f43 100644 --- a/tests/test_contacts_router.py +++ b/tests/test_contacts_router.py @@ -108,10 +108,11 @@ class TestCreateContact: @pytest.mark.asyncio async def test_create_new_contact(self, test_db, client): - response = await client.post( - "/api/contacts", - json={"public_key": KEY_A, "name": "NewContact"}, - ) + with patch("app.websocket.broadcast_event") as mock_broadcast: + response = await client.post( + "/api/contacts", + json={"public_key": KEY_A, "name": "NewContact"}, + ) assert response.status_code == 200 data = response.json() @@ -124,6 +125,7 @@ class TestCreateContact: assert contact is not None assert contact.name == "NewContact" assert data["last_seen"] == contact.last_seen + mock_broadcast.assert_called_once_with("contact", contact.model_dump()) @pytest.mark.asyncio async def test_create_invalid_hex(self, test_db, client): @@ -662,7 +664,10 @@ class TestSyncContacts: mock_mc.commands.get_contacts = AsyncMock(return_value=mock_result) radio_manager._meshcore = mock_mc - with _patch_require_connected(mock_mc): + with ( + _patch_require_connected(mock_mc), + patch("app.websocket.broadcast_event") as mock_broadcast, + ): response = await client.post("/api/contacts/sync") assert response.status_code == 200 @@ -672,6 +677,12 @@ class TestSyncContacts: alice = await ContactRepository.get_by_key(KEY_A) assert alice is not None assert alice.name == "Alice" + assert mock_broadcast.call_count == 2 + assert [call.args[0] for call in mock_broadcast.call_args_list] == ["contact", "contact"] + assert {call.args[1]["public_key"] for call in mock_broadcast.call_args_list} == { + KEY_A, + KEY_B, + } @pytest.mark.asyncio async def test_sync_requires_connection(self, test_db, client):