Add support for community MQTT ingest

This commit is contained in:
Jack Kingsman
2026-03-01 09:55:11 -08:00
parent 2496d70c4b
commit 00ca4afa8d
17 changed files with 1495 additions and 26 deletions
+463
View File
@@ -0,0 +1,463 @@
"""Tests for community MQTT publisher."""
import json
from unittest.mock import MagicMock, patch
import nacl.bindings
import pytest
from app.community_mqtt import (
_CLIENT_ID,
_DEFAULT_BROKER,
_DEFAULT_PORT,
CommunityMqttPublisher,
_base64url_encode,
_calculate_packet_hash,
_ed25519_sign_expanded,
_format_raw_packet,
_generate_jwt_token,
_parse_broker_address,
community_mqtt_broadcast,
)
from app.models import AppSettings
def _make_test_keys() -> tuple[bytes, bytes]:
"""Generate a test MeshCore-format key pair.
Returns (private_key_64_bytes, public_key_32_bytes).
MeshCore format: scalar(32) || prefix(32), where scalar is already clamped.
"""
import hashlib
import os
seed = os.urandom(32)
expanded = hashlib.sha512(seed).digest()
scalar = bytearray(expanded[:32])
# Clamp scalar (standard Ed25519 clamping)
scalar[0] &= 248
scalar[31] &= 127
scalar[31] |= 64
scalar = bytes(scalar)
prefix = expanded[32:]
private_key = scalar + prefix
public_key = nacl.bindings.crypto_scalarmult_ed25519_base_noclamp(scalar)
return private_key, public_key
class TestBase64UrlEncode:
def test_encodes_without_padding(self):
result = _base64url_encode(b"\x00\x01\x02")
assert "=" not in result
def test_uses_url_safe_chars(self):
# Bytes that would produce + and / in standard base64
result = _base64url_encode(b"\xfb\xff\xfe")
assert "+" not in result
assert "/" not in result
class TestJwtGeneration:
def test_token_has_three_parts(self):
private_key, public_key = _make_test_keys()
token = _generate_jwt_token(private_key, public_key)
parts = token.split(".")
assert len(parts) == 3
def test_header_contains_ed25519_alg(self):
private_key, public_key = _make_test_keys()
token = _generate_jwt_token(private_key, public_key)
header_b64 = token.split(".")[0]
# Add padding for base64 decoding
import base64
padded = header_b64 + "=" * (4 - len(header_b64) % 4)
header = json.loads(base64.urlsafe_b64decode(padded))
assert header["alg"] == "Ed25519"
assert header["typ"] == "JWT"
def test_payload_contains_required_fields(self):
private_key, public_key = _make_test_keys()
token = _generate_jwt_token(private_key, public_key)
payload_b64 = token.split(".")[1]
import base64
padded = payload_b64 + "=" * (4 - len(payload_b64) % 4)
payload = json.loads(base64.urlsafe_b64decode(padded))
assert payload["publicKey"] == public_key.hex().upper()
assert "iat" in payload
assert "exp" in payload
assert payload["exp"] - payload["iat"] == 86400
assert payload["aud"] == _DEFAULT_BROKER
assert payload["owner"] == public_key.hex().upper()
assert payload["client"] == _CLIENT_ID
assert "email" not in payload # omitted when empty
def test_payload_includes_email_when_provided(self):
private_key, public_key = _make_test_keys()
token = _generate_jwt_token(private_key, public_key, email="test@example.com")
payload_b64 = token.split(".")[1]
import base64
padded = payload_b64 + "=" * (4 - len(payload_b64) % 4)
payload = json.loads(base64.urlsafe_b64decode(padded))
assert payload["email"] == "test@example.com"
def test_payload_uses_custom_audience(self):
private_key, public_key = _make_test_keys()
token = _generate_jwt_token(private_key, public_key, audience="custom.broker.net")
payload_b64 = token.split(".")[1]
import base64
padded = payload_b64 + "=" * (4 - len(payload_b64) % 4)
payload = json.loads(base64.urlsafe_b64decode(padded))
assert payload["aud"] == "custom.broker.net"
def test_signature_is_valid_hex(self):
private_key, public_key = _make_test_keys()
token = _generate_jwt_token(private_key, public_key)
sig_hex = token.split(".")[2]
sig_bytes = bytes.fromhex(sig_hex)
assert len(sig_bytes) == 64
def test_signature_verifies(self):
"""Verify the JWT signature using nacl.bindings.crypto_sign_open."""
private_key, public_key = _make_test_keys()
token = _generate_jwt_token(private_key, public_key)
parts = token.split(".")
signing_input = f"{parts[0]}.{parts[1]}".encode()
signature = bytes.fromhex(parts[2])
# crypto_sign_open expects signature + message concatenated
signed_message = signature + signing_input
# This will raise if the signature is invalid
verified = nacl.bindings.crypto_sign_open(signed_message, public_key)
assert verified == signing_input
class TestEddsaSignExpanded:
def test_produces_64_byte_signature(self):
private_key, public_key = _make_test_keys()
message = b"test message"
sig = _ed25519_sign_expanded(message, private_key[:32], private_key[32:], public_key)
assert len(sig) == 64
def test_signature_verifies_with_nacl(self):
private_key, public_key = _make_test_keys()
message = b"hello world"
sig = _ed25519_sign_expanded(message, private_key[:32], private_key[32:], public_key)
signed_message = sig + message
verified = nacl.bindings.crypto_sign_open(signed_message, public_key)
assert verified == message
def test_different_messages_produce_different_signatures(self):
private_key, public_key = _make_test_keys()
sig1 = _ed25519_sign_expanded(b"msg1", private_key[:32], private_key[32:], public_key)
sig2 = _ed25519_sign_expanded(b"msg2", private_key[:32], private_key[32:], public_key)
assert sig1 != sig2
class TestPacketFormatConversion:
def test_basic_field_mapping(self):
data = {
"id": 1,
"observation_id": 100,
"timestamp": 1700000000,
"data": "0a1b2c3d",
"payload_type": "ADVERT",
"snr": 5.5,
"rssi": -90,
"decrypted": False,
"decrypted_info": None,
}
result = _format_raw_packet(data, "TestNode", "AABBCCDD" * 8)
assert result["origin"] == "TestNode"
assert result["origin_id"] == "AABBCCDD" * 8
assert result["raw"] == "0A1B2C3D"
assert result["SNR"] == "5.5"
assert result["RSSI"] == "-90"
assert result["type"] == "PACKET"
assert result["direction"] == "rx"
assert result["len"] == "4"
def test_timestamp_is_iso8601(self):
data = {"timestamp": 1700000000, "data": "00", "snr": None, "rssi": None}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["timestamp"]
assert "T" in result["timestamp"]
def test_snr_rssi_unknown_when_none(self):
data = {"timestamp": 0, "data": "00", "snr": None, "rssi": None}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["SNR"] == "Unknown"
assert result["RSSI"] == "Unknown"
def test_packet_type_extraction(self):
# Header 0x14 = type 5, route 0 (TRANSPORT_FLOOD): header + 4 transport + path_len.
data = {"timestamp": 0, "data": "140102030400", "snr": None, "rssi": None}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["packet_type"] == "5"
assert result["route"] == "F"
def test_route_mapping(self):
# Test all 4 route types (matches meshcore-packet-capture)
# TRANSPORT_FLOOD=0 -> "F", FLOOD=1 -> "F", DIRECT=2 -> "D", TRANSPORT_DIRECT=3 -> "T"
samples = [
("000102030400", "F"), # TRANSPORT_FLOOD: header + transport + path_len
("0100", "F"), # FLOOD: header + path_len
("0200", "D"), # DIRECT: header + path_len
("030102030400", "T"), # TRANSPORT_DIRECT: header + transport + path_len
]
for raw_hex, expected in samples:
data = {"timestamp": 0, "data": raw_hex, "snr": None, "rssi": None}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["route"] == expected
def test_hash_is_16_uppercase_hex_chars(self):
data = {"timestamp": 0, "data": "aabb", "snr": None, "rssi": None}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert len(result["hash"]) == 16
assert result["hash"] == result["hash"].upper()
def test_empty_data_handled(self):
data = {"timestamp": 0, "data": "", "snr": None, "rssi": None}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["raw"] == ""
assert result["len"] == "0"
assert result["packet_type"] == "0"
assert result["route"] == "U"
def test_includes_reference_time_fields(self):
data = {"timestamp": 0, "data": "0100aabb", "snr": 1.0, "rssi": -70}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["time"]
assert result["date"]
assert result["payload_len"] == "2"
def test_adds_path_for_direct_route(self):
# route=2 (DIRECT), path_len=2, path=aa bb, payload=cc
data = {"timestamp": 0, "data": "0202AABBCC", "snr": 1.0, "rssi": -70}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["route"] == "D"
assert result["path"] == "aa,bb"
def test_direct_route_includes_empty_path_field(self):
data = {"timestamp": 0, "data": "0200", "snr": 1.0, "rssi": -70}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["route"] == "D"
assert "path" in result
assert result["path"] == ""
def test_unknown_version_uses_defaults(self):
# version=1 in high bits, type=5, route=1
header = (1 << 6) | (5 << 2) | 1
data = {"timestamp": 0, "data": f"{header:02x}00", "snr": 1.0, "rssi": -70}
result = _format_raw_packet(data, "Node", "AA" * 32)
assert result["packet_type"] == "0"
assert result["route"] == "U"
assert result["payload_len"] == "0"
class TestCalculatePacketHash:
def test_empty_bytes_returns_zeroes(self):
result = _calculate_packet_hash(b"")
assert result == "0" * 16
def test_returns_16_uppercase_hex_chars(self):
# Simple flood packet: header(1) + path_len(1) + payload
raw = bytes([0x01, 0x00, 0xAA, 0xBB]) # FLOOD, no path, payload=0xAABB
result = _calculate_packet_hash(raw)
assert len(result) == 16
assert result == result.upper()
def test_flood_packet_hash(self):
"""FLOOD route (0x01): no transport codes, header + path_len + payload."""
import hashlib
# Header 0x11 = route=FLOOD(1), payload_type=4(ADVERT): (4<<2)|1 = 0x11
payload = b"\xde\xad"
raw = bytes([0x11, 0x00]) + payload # header, path_len=0, payload
result = _calculate_packet_hash(raw)
# Expected: sha256(payload_type_byte + payload_data)[:16].upper()
expected = hashlib.sha256(bytes([4]) + payload).hexdigest()[:16].upper()
assert result == expected
def test_transport_flood_skips_transport_codes(self):
"""TRANSPORT_FLOOD (0x00): has 4 bytes of transport codes after header."""
import hashlib
# Header 0x10 = route=TRANSPORT_FLOOD(0), payload_type=4: (4<<2)|0 = 0x10
transport_codes = b"\x01\x02\x03\x04"
payload = b"\xca\xfe"
raw = bytes([0x10]) + transport_codes + bytes([0x00]) + payload
result = _calculate_packet_hash(raw)
expected = hashlib.sha256(bytes([4]) + payload).hexdigest()[:16].upper()
assert result == expected
def test_transport_direct_skips_transport_codes(self):
"""TRANSPORT_DIRECT (0x03): also has 4 bytes of transport codes."""
import hashlib
# Header 0x13 = route=TRANSPORT_DIRECT(3), payload_type=4: (4<<2)|3 = 0x13
transport_codes = b"\x05\x06\x07\x08"
payload = b"\xbe\xef"
raw = bytes([0x13]) + transport_codes + bytes([0x00]) + payload
result = _calculate_packet_hash(raw)
expected = hashlib.sha256(bytes([4]) + payload).hexdigest()[:16].upper()
assert result == expected
def test_trace_packet_includes_path_len_in_hash(self):
"""TRACE packets (type 9) include path_len as uint16_t LE in the hash."""
import hashlib
# Header for TRACE with FLOOD route: (9<<2)|1 = 0x25
path_len = 3
path_data = b"\xaa\xbb\xcc"
payload = b"\x01\x02"
raw = bytes([0x25, path_len]) + path_data + payload
result = _calculate_packet_hash(raw)
expected_hash = (
hashlib.sha256(bytes([9]) + path_len.to_bytes(2, byteorder="little") + payload)
.hexdigest()[:16]
.upper()
)
assert result == expected_hash
def test_with_path_data(self):
"""Packet with non-zero path_len should skip path bytes to reach payload."""
import hashlib
# FLOOD route, payload_type=2 (TXT_MSG): (2<<2)|1 = 0x09
path_data = b"\xaa\xbb" # 2 bytes of path
payload = b"\x48\x65\x6c\x6c\x6f" # "Hello"
raw = bytes([0x09, 0x02]) + path_data + payload
result = _calculate_packet_hash(raw)
expected = hashlib.sha256(bytes([2]) + payload).hexdigest()[:16].upper()
assert result == expected
def test_truncated_packet_returns_zeroes(self):
# Header says TRANSPORT_FLOOD, but missing path_len at required offset.
raw = bytes([0x10, 0x01, 0x02])
assert _calculate_packet_hash(raw) == "0" * 16
class TestCommunityMqttPublisher:
def test_initial_state(self):
pub = CommunityMqttPublisher()
assert pub.connected is False
assert pub._client is None
assert pub._task is None
@pytest.mark.asyncio
async def test_publish_drops_when_disconnected(self):
pub = CommunityMqttPublisher()
# Should not raise
await pub.publish("topic", {"key": "value"})
@pytest.mark.asyncio
async def test_stop_resets_state(self):
pub = CommunityMqttPublisher()
pub.connected = True
pub._client = MagicMock()
await pub.stop()
assert pub.connected is False
assert pub._client is None
def test_is_configured_false_when_disabled(self):
pub = CommunityMqttPublisher()
pub._settings = AppSettings(community_mqtt_enabled=False)
with patch("app.keystore.has_private_key", return_value=True):
assert pub._is_configured() is False
def test_is_configured_false_when_no_private_key(self):
pub = CommunityMqttPublisher()
pub._settings = AppSettings(community_mqtt_enabled=True)
with patch("app.keystore.has_private_key", return_value=False):
assert pub._is_configured() is False
def test_is_configured_true_when_enabled_with_key(self):
pub = CommunityMqttPublisher()
pub._settings = AppSettings(community_mqtt_enabled=True)
with patch("app.keystore.has_private_key", return_value=True):
assert pub._is_configured() is True
class TestCommunityMqttBroadcast:
def test_filters_non_raw_packet(self):
"""Non-raw_packet events should be ignored."""
with patch("app.community_mqtt.community_publisher") as mock_pub:
mock_pub.connected = True
mock_pub._settings = AppSettings(community_mqtt_enabled=True)
community_mqtt_broadcast("message", {"text": "hello"})
# No asyncio.create_task should be called for non-raw_packet events
# Since we're filtering, we just verify no exception
def test_skips_when_disconnected(self):
"""Should not publish when disconnected."""
with (
patch("app.community_mqtt.community_publisher") as mock_pub,
patch("app.community_mqtt.asyncio.create_task") as mock_task,
):
mock_pub.connected = False
mock_pub._settings = AppSettings(community_mqtt_enabled=True)
community_mqtt_broadcast("raw_packet", {"data": "00"})
mock_task.assert_not_called()
def test_skips_when_settings_none(self):
"""Should not publish when settings are None."""
with (
patch("app.community_mqtt.community_publisher") as mock_pub,
patch("app.community_mqtt.asyncio.create_task") as mock_task,
):
mock_pub.connected = True
mock_pub._settings = None
community_mqtt_broadcast("raw_packet", {"data": "00"})
mock_task.assert_not_called()
class TestParseBrokerAddress:
def test_hostname_only_uses_default_port(self):
host, port = _parse_broker_address("mqtt-us-v1.letsmesh.net")
assert host == "mqtt-us-v1.letsmesh.net"
assert port == _DEFAULT_PORT
def test_hostname_with_port(self):
host, port = _parse_broker_address("mqtt-us-v1.letsmesh.net:8883")
assert host == "mqtt-us-v1.letsmesh.net"
assert port == 8883
def test_hostname_with_port_443(self):
host, port = _parse_broker_address("broker.example.com:443")
assert host == "broker.example.com"
assert port == 443
def test_invalid_port_uses_default(self):
host, port = _parse_broker_address("broker.example.com:abc")
assert host == "broker.example.com:abc"
assert port == _DEFAULT_PORT
def test_empty_string(self):
host, port = _parse_broker_address("")
assert host == ""
assert port == _DEFAULT_PORT
class TestPublishFailureSetsDisconnected:
@pytest.mark.asyncio
async def test_publish_error_sets_connected_false(self):
"""A publish error should set connected=False so the loop can detect it."""
pub = CommunityMqttPublisher()
pub.connected = True
mock_client = MagicMock()
mock_client.publish = MagicMock(side_effect=Exception("broker gone"))
pub._client = mock_client
await pub.publish("topic", {"data": "test"})
assert pub.connected is False
+77 -20
View File
@@ -100,8 +100,8 @@ class TestMigration001:
# Run migrations
applied = await run_migrations(conn)
assert applied == 31 # All migrations run
assert await get_version(conn) == 31
assert applied == 32 # All migrations run
assert await get_version(conn) == 32
# Verify columns exist by inserting and selecting
await conn.execute(
@@ -183,9 +183,9 @@ class TestMigration001:
applied1 = await run_migrations(conn)
applied2 = await run_migrations(conn)
assert applied1 == 31 # All migrations run
assert applied1 == 32 # All migrations run
assert applied2 == 0 # No migrations on second run
assert await get_version(conn) == 31
assert await get_version(conn) == 32
finally:
await conn.close()
@@ -246,8 +246,8 @@ class TestMigration001:
applied = await run_migrations(conn)
# All migrations applied (version incremented) but no error
assert applied == 31
assert await get_version(conn) == 31
assert applied == 32
assert await get_version(conn) == 32
finally:
await conn.close()
@@ -374,10 +374,10 @@ class TestMigration013:
)
await conn.commit()
# Run migration 13 (plus 14-27 which also run)
# Run migration 13 (plus 14-33 which also run)
applied = await run_migrations(conn)
assert applied == 19
assert await get_version(conn) == 31
assert applied == 20
assert await get_version(conn) == 32
# Verify bots array was created with migrated data
cursor = await conn.execute("SELECT bots FROM app_settings WHERE id = 1")
@@ -497,7 +497,7 @@ class TestMigration018:
assert await cursor.fetchone() is not None
await run_migrations(conn)
assert await get_version(conn) == 31
assert await get_version(conn) == 32
# Verify autoindex is gone
cursor = await conn.execute(
@@ -575,8 +575,8 @@ class TestMigration018:
await conn.commit()
applied = await run_migrations(conn)
assert applied == 14 # Migrations 18-31 run (18+19 skip internally)
assert await get_version(conn) == 31
assert applied == 15 # Migrations 18-32 run (18+19 skip internally)
assert await get_version(conn) == 32
finally:
await conn.close()
@@ -648,7 +648,7 @@ class TestMigration019:
assert await cursor.fetchone() is not None
await run_migrations(conn)
assert await get_version(conn) == 31
assert await get_version(conn) == 32
# Verify autoindex is gone
cursor = await conn.execute(
@@ -714,8 +714,8 @@ class TestMigration020:
assert (await cursor.fetchone())[0] == "delete"
applied = await run_migrations(conn)
assert applied == 12 # Migrations 20-31
assert await get_version(conn) == 31
assert applied == 13 # Migrations 20-32
assert await get_version(conn) == 32
# Verify WAL mode
cursor = await conn.execute("PRAGMA journal_mode")
@@ -745,7 +745,7 @@ class TestMigration020:
await set_version(conn, 20)
applied = await run_migrations(conn)
assert applied == 11 # Migrations 21-31 still run
assert applied == 12 # Migrations 21-32 still run
# Still WAL + INCREMENTAL
cursor = await conn.execute("PRAGMA journal_mode")
@@ -803,8 +803,8 @@ class TestMigration028:
await conn.commit()
applied = await run_migrations(conn)
assert applied == 4
assert await get_version(conn) == 31
assert applied == 5
assert await get_version(conn) == 32
# Verify payload_hash column is now BLOB
cursor = await conn.execute("PRAGMA table_info(raw_packets)")
@@ -873,8 +873,8 @@ class TestMigration028:
await conn.commit()
applied = await run_migrations(conn)
assert applied == 4 # Version still bumped
assert await get_version(conn) == 31
assert applied == 5 # Version still bumped
assert await get_version(conn) == 32
# Verify data unchanged
cursor = await conn.execute("SELECT payload_hash FROM raw_packets")
@@ -882,3 +882,60 @@ class TestMigration028:
assert bytes(row["payload_hash"]) == b"\xab" * 32
finally:
await conn.close()
class TestMigration032:
"""Test migration 032: add community MQTT columns to app_settings."""
@pytest.mark.asyncio
async def test_migration_adds_all_community_mqtt_columns(self):
"""Migration adds enabled, iata, broker, and email columns."""
conn = await aiosqlite.connect(":memory:")
conn.row_factory = aiosqlite.Row
try:
await set_version(conn, 31)
# Create app_settings without community columns (pre-migration schema)
await conn.execute("""
CREATE TABLE app_settings (
id INTEGER PRIMARY KEY,
max_radio_contacts INTEGER DEFAULT 200,
favorites TEXT DEFAULT '[]',
auto_decrypt_dm_on_advert INTEGER DEFAULT 0,
sidebar_sort_order TEXT DEFAULT 'recent',
last_message_times TEXT DEFAULT '{}',
preferences_migrated INTEGER DEFAULT 0,
advert_interval INTEGER DEFAULT 0,
last_advert_time INTEGER DEFAULT 0,
bots TEXT DEFAULT '[]',
mqtt_broker_host TEXT DEFAULT '',
mqtt_broker_port INTEGER DEFAULT 1883,
mqtt_username TEXT DEFAULT '',
mqtt_password TEXT DEFAULT '',
mqtt_use_tls INTEGER DEFAULT 0,
mqtt_tls_insecure INTEGER DEFAULT 0,
mqtt_topic_prefix TEXT DEFAULT 'meshcore',
mqtt_publish_messages INTEGER DEFAULT 0,
mqtt_publish_raw_packets INTEGER DEFAULT 0
)
""")
await conn.execute("INSERT INTO app_settings (id) VALUES (1)")
await conn.commit()
applied = await run_migrations(conn)
assert applied == 1
assert await get_version(conn) == 32
# Verify all columns exist with correct defaults
cursor = await conn.execute(
"""SELECT community_mqtt_enabled, community_mqtt_iata,
community_mqtt_broker, community_mqtt_email
FROM app_settings WHERE id = 1"""
)
row = await cursor.fetchone()
assert row["community_mqtt_enabled"] == 0
assert row["community_mqtt_iata"] == ""
assert row["community_mqtt_broker"] == "mqtt-us-v1.letsmesh.net"
assert row["community_mqtt_email"] == ""
finally:
await conn.close()
+4
View File
@@ -502,6 +502,10 @@ class TestAppSettingsRepository:
"mqtt_topic_prefix": "meshcore",
"mqtt_publish_messages": 0,
"mqtt_publish_raw_packets": 0,
"community_mqtt_enabled": 0,
"community_mqtt_iata": "",
"community_mqtt_broker": "mqtt-us-v1.letsmesh.net",
"community_mqtt_email": "",
}
)
mock_conn.execute = AsyncMock(return_value=mock_cursor)
+72
View File
@@ -117,6 +117,78 @@ class TestUpdateSettings:
assert settings.mqtt_publish_messages is False
assert settings.mqtt_publish_raw_packets is False
@pytest.mark.asyncio
async def test_community_mqtt_fields_round_trip(self, test_db):
"""Community MQTT settings should be saved and retrieved correctly."""
mock_community = type("MockCommunity", (), {"restart": AsyncMock()})()
with patch("app.community_mqtt.community_publisher", mock_community):
result = await update_settings(
AppSettingsUpdate(
community_mqtt_enabled=True,
community_mqtt_iata="DEN",
community_mqtt_broker="custom-broker.example.com",
community_mqtt_email="test@example.com",
)
)
assert result.community_mqtt_enabled is True
assert result.community_mqtt_iata == "DEN"
assert result.community_mqtt_broker == "custom-broker.example.com"
assert result.community_mqtt_email == "test@example.com"
# Verify persistence
fresh = await AppSettingsRepository.get()
assert fresh.community_mqtt_enabled is True
assert fresh.community_mqtt_iata == "DEN"
assert fresh.community_mqtt_broker == "custom-broker.example.com"
assert fresh.community_mqtt_email == "test@example.com"
# Verify restart was called
mock_community.restart.assert_called_once()
@pytest.mark.asyncio
async def test_community_mqtt_iata_validation_rejects_invalid(self, test_db):
"""Invalid IATA codes should be rejected."""
with pytest.raises(HTTPException) as exc:
await update_settings(AppSettingsUpdate(community_mqtt_iata="A"))
assert exc.value.status_code == 400
with pytest.raises(HTTPException) as exc:
await update_settings(AppSettingsUpdate(community_mqtt_iata="ABCDE"))
assert exc.value.status_code == 400
with pytest.raises(HTTPException) as exc:
await update_settings(AppSettingsUpdate(community_mqtt_iata="12"))
assert exc.value.status_code == 400
with pytest.raises(HTTPException) as exc:
await update_settings(AppSettingsUpdate(community_mqtt_iata="ABCD"))
assert exc.value.status_code == 400
@pytest.mark.asyncio
async def test_community_mqtt_enable_requires_iata(self, test_db):
"""Enabling community MQTT without a valid IATA code should be rejected."""
with pytest.raises(HTTPException) as exc:
await update_settings(AppSettingsUpdate(community_mqtt_enabled=True))
assert exc.value.status_code == 400
assert "IATA" in exc.value.detail
@pytest.mark.asyncio
async def test_community_mqtt_iata_uppercased(self, test_db):
"""IATA codes should be uppercased."""
mock_community = type("MockCommunity", (), {"restart": AsyncMock()})()
with patch("app.community_mqtt.community_publisher", mock_community):
result = await update_settings(AppSettingsUpdate(community_mqtt_iata="den"))
assert result.community_mqtt_iata == "DEN"
@pytest.mark.asyncio
async def test_community_mqtt_defaults_on_fresh_db(self, test_db):
"""Community MQTT fields should have correct defaults on a fresh database."""
settings = await AppSettingsRepository.get()
assert settings.community_mqtt_enabled is False
assert settings.community_mqtt_iata == ""
assert settings.community_mqtt_email == ""
class TestToggleFavorite:
@pytest.mark.asyncio