Other connected clients get new chans over WS

This commit is contained in:
Jack Kingsman
2026-03-11 22:02:05 -07:00
parent 8c1a58b293
commit a13e241636
4 changed files with 72 additions and 12 deletions

View File

@@ -17,6 +17,10 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/channels", tags=["channels"])
def _broadcast_channel_update(channel: Channel) -> None:
broadcast_event("channel", channel.model_dump())
class CreateChannelRequest(BaseModel):
name: str = Field(min_length=1, max_length=32)
key: str | None = Field(
@@ -98,13 +102,12 @@ async def create_channel(request: CreateChannelRequest) -> Channel:
on_radio=False,
)
return Channel(
key=key_hex,
name=request.name,
is_hashtag=is_hashtag,
on_radio=False,
flood_scope_override=None,
)
stored = await ChannelRepository.get_by_key(key_hex)
if stored is None:
raise HTTPException(status_code=500, detail="Channel was created but could not be reloaded")
_broadcast_channel_update(stored)
return stored
@router.post("/sync")
@@ -123,6 +126,9 @@ async def sync_channels_from_radio(max_channels: int = Query(default=40, ge=1, l
key_hex = await upsert_channel_from_radio_slot(result.payload, on_radio=True)
if key_hex is not None:
count += 1
stored = await ChannelRepository.get_by_key(key_hex)
if stored is not None:
_broadcast_channel_update(stored)
logger.debug(
"Synced channel %s: %s", key_hex, result.payload.get("channel_name")
)

View File

@@ -289,6 +289,7 @@ async def create_contact(
background_tasks, request.public_key, request.name or existing.name
)
await _broadcast_contact_update(existing)
return existing
# Create new contact
@@ -319,6 +320,7 @@ async def create_contact(
stored = await ContactRepository.get_by_key(lower_key)
if stored is None:
raise HTTPException(status_code=500, detail="Contact was created but could not be reloaded")
await _broadcast_contact_update(stored)
await _broadcast_contact_resolution(promoted_keys, stored)
return stored
@@ -411,6 +413,7 @@ async def sync_contacts_from_radio() -> dict:
)
stored = await ContactRepository.get_by_key(lower_key)
if stored is not None:
await _broadcast_contact_update(stored)
await _broadcast_contact_resolution(promoted_keys, stored)
count += 1

View File

@@ -112,6 +112,33 @@ class TestSyncChannelsFromRadio:
assert secret_a.hex().upper() in keys
assert secret_b.hex().upper() in keys
@pytest.mark.asyncio
async def test_sync_broadcasts_channel_updates(self, test_db, client):
secret = bytes.fromhex("0123456789abcdef0123456789abcdef")
mock_mc = MagicMock()
async def mock_get_channel(idx):
if idx == 0:
return _make_channel_info("#general", secret)
return _make_empty_channel()
mock_mc.commands.get_channel = AsyncMock(side_effect=mock_get_channel)
radio_manager._meshcore = mock_mc
with (
_patch_require_connected(mock_mc),
patch("app.routers.channels.radio_manager") as mock_ch_rm,
patch("app.routers.channels.broadcast_event") as mock_broadcast,
):
mock_ch_rm.radio_operation = lambda desc: _noop_radio_operation(mock_mc)
response = await client.post("/api/channels/sync?max_channels=3")
assert response.status_code == 200
mock_broadcast.assert_called_once()
assert mock_broadcast.call_args.args[0] == "channel"
assert mock_broadcast.call_args.args[1]["key"] == secret.hex().upper()
@pytest.mark.asyncio
async def test_sync_skips_empty_channels(self, test_db, client):
"""Empty channel slots are skipped during sync."""
@@ -278,6 +305,19 @@ class TestChannelFloodScopeOverride:
mock_broadcast.assert_called_once()
assert mock_broadcast.call_args.args[0] == "channel"
class TestCreateChannel:
@pytest.mark.asyncio
async def test_create_broadcasts_channel_update(self, test_db):
from app.routers.channels import CreateChannelRequest, create_channel
with patch("app.routers.channels.broadcast_event") as mock_broadcast:
result = await create_channel(CreateChannelRequest(name="#mychannel"))
mock_broadcast.assert_called_once()
assert mock_broadcast.call_args.args[0] == "channel"
assert mock_broadcast.call_args.args[1]["key"] == result.key
@pytest.mark.asyncio
async def test_existing_hash_is_not_doubled(self, test_db, client):
key = "CC" * 16

View File

@@ -108,10 +108,11 @@ class TestCreateContact:
@pytest.mark.asyncio
async def test_create_new_contact(self, test_db, client):
response = await client.post(
"/api/contacts",
json={"public_key": KEY_A, "name": "NewContact"},
)
with patch("app.websocket.broadcast_event") as mock_broadcast:
response = await client.post(
"/api/contacts",
json={"public_key": KEY_A, "name": "NewContact"},
)
assert response.status_code == 200
data = response.json()
@@ -124,6 +125,7 @@ class TestCreateContact:
assert contact is not None
assert contact.name == "NewContact"
assert data["last_seen"] == contact.last_seen
mock_broadcast.assert_called_once_with("contact", contact.model_dump())
@pytest.mark.asyncio
async def test_create_invalid_hex(self, test_db, client):
@@ -662,7 +664,10 @@ class TestSyncContacts:
mock_mc.commands.get_contacts = AsyncMock(return_value=mock_result)
radio_manager._meshcore = mock_mc
with _patch_require_connected(mock_mc):
with (
_patch_require_connected(mock_mc),
patch("app.websocket.broadcast_event") as mock_broadcast,
):
response = await client.post("/api/contacts/sync")
assert response.status_code == 200
@@ -672,6 +677,12 @@ class TestSyncContacts:
alice = await ContactRepository.get_by_key(KEY_A)
assert alice is not None
assert alice.name == "Alice"
assert mock_broadcast.call_count == 2
assert [call.args[0] for call in mock_broadcast.call_args_list] == ["contact", "contact"]
assert {call.args[1]["public_key"] for call in mock_broadcast.call_args_list} == {
KEY_A,
KEY_B,
}
@pytest.mark.asyncio
async def test_sync_requires_connection(self, test_db, client):