mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-07-01 15:31:50 +02:00
Add support for community MQTT ingest
This commit is contained in:
@@ -0,0 +1,463 @@
|
||||
"""Tests for community MQTT publisher."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import nacl.bindings
|
||||
import pytest
|
||||
|
||||
from app.community_mqtt import (
|
||||
_CLIENT_ID,
|
||||
_DEFAULT_BROKER,
|
||||
_DEFAULT_PORT,
|
||||
CommunityMqttPublisher,
|
||||
_base64url_encode,
|
||||
_calculate_packet_hash,
|
||||
_ed25519_sign_expanded,
|
||||
_format_raw_packet,
|
||||
_generate_jwt_token,
|
||||
_parse_broker_address,
|
||||
community_mqtt_broadcast,
|
||||
)
|
||||
from app.models import AppSettings
|
||||
|
||||
|
||||
def _make_test_keys() -> tuple[bytes, bytes]:
|
||||
"""Generate a test MeshCore-format key pair.
|
||||
|
||||
Returns (private_key_64_bytes, public_key_32_bytes).
|
||||
MeshCore format: scalar(32) || prefix(32), where scalar is already clamped.
|
||||
"""
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
seed = os.urandom(32)
|
||||
expanded = hashlib.sha512(seed).digest()
|
||||
scalar = bytearray(expanded[:32])
|
||||
# Clamp scalar (standard Ed25519 clamping)
|
||||
scalar[0] &= 248
|
||||
scalar[31] &= 127
|
||||
scalar[31] |= 64
|
||||
scalar = bytes(scalar)
|
||||
prefix = expanded[32:]
|
||||
|
||||
private_key = scalar + prefix
|
||||
public_key = nacl.bindings.crypto_scalarmult_ed25519_base_noclamp(scalar)
|
||||
return private_key, public_key
|
||||
|
||||
|
||||
class TestBase64UrlEncode:
|
||||
def test_encodes_without_padding(self):
|
||||
result = _base64url_encode(b"\x00\x01\x02")
|
||||
assert "=" not in result
|
||||
|
||||
def test_uses_url_safe_chars(self):
|
||||
# Bytes that would produce + and / in standard base64
|
||||
result = _base64url_encode(b"\xfb\xff\xfe")
|
||||
assert "+" not in result
|
||||
assert "/" not in result
|
||||
|
||||
|
||||
class TestJwtGeneration:
|
||||
def test_token_has_three_parts(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
token = _generate_jwt_token(private_key, public_key)
|
||||
parts = token.split(".")
|
||||
assert len(parts) == 3
|
||||
|
||||
def test_header_contains_ed25519_alg(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
token = _generate_jwt_token(private_key, public_key)
|
||||
header_b64 = token.split(".")[0]
|
||||
# Add padding for base64 decoding
|
||||
import base64
|
||||
|
||||
padded = header_b64 + "=" * (4 - len(header_b64) % 4)
|
||||
header = json.loads(base64.urlsafe_b64decode(padded))
|
||||
assert header["alg"] == "Ed25519"
|
||||
assert header["typ"] == "JWT"
|
||||
|
||||
def test_payload_contains_required_fields(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
token = _generate_jwt_token(private_key, public_key)
|
||||
payload_b64 = token.split(".")[1]
|
||||
import base64
|
||||
|
||||
padded = payload_b64 + "=" * (4 - len(payload_b64) % 4)
|
||||
payload = json.loads(base64.urlsafe_b64decode(padded))
|
||||
assert payload["publicKey"] == public_key.hex().upper()
|
||||
assert "iat" in payload
|
||||
assert "exp" in payload
|
||||
assert payload["exp"] - payload["iat"] == 86400
|
||||
assert payload["aud"] == _DEFAULT_BROKER
|
||||
assert payload["owner"] == public_key.hex().upper()
|
||||
assert payload["client"] == _CLIENT_ID
|
||||
assert "email" not in payload # omitted when empty
|
||||
|
||||
def test_payload_includes_email_when_provided(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
token = _generate_jwt_token(private_key, public_key, email="test@example.com")
|
||||
payload_b64 = token.split(".")[1]
|
||||
import base64
|
||||
|
||||
padded = payload_b64 + "=" * (4 - len(payload_b64) % 4)
|
||||
payload = json.loads(base64.urlsafe_b64decode(padded))
|
||||
assert payload["email"] == "test@example.com"
|
||||
|
||||
def test_payload_uses_custom_audience(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
token = _generate_jwt_token(private_key, public_key, audience="custom.broker.net")
|
||||
payload_b64 = token.split(".")[1]
|
||||
import base64
|
||||
|
||||
padded = payload_b64 + "=" * (4 - len(payload_b64) % 4)
|
||||
payload = json.loads(base64.urlsafe_b64decode(padded))
|
||||
assert payload["aud"] == "custom.broker.net"
|
||||
|
||||
def test_signature_is_valid_hex(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
token = _generate_jwt_token(private_key, public_key)
|
||||
sig_hex = token.split(".")[2]
|
||||
sig_bytes = bytes.fromhex(sig_hex)
|
||||
assert len(sig_bytes) == 64
|
||||
|
||||
def test_signature_verifies(self):
|
||||
"""Verify the JWT signature using nacl.bindings.crypto_sign_open."""
|
||||
private_key, public_key = _make_test_keys()
|
||||
token = _generate_jwt_token(private_key, public_key)
|
||||
parts = token.split(".")
|
||||
signing_input = f"{parts[0]}.{parts[1]}".encode()
|
||||
signature = bytes.fromhex(parts[2])
|
||||
|
||||
# crypto_sign_open expects signature + message concatenated
|
||||
signed_message = signature + signing_input
|
||||
# This will raise if the signature is invalid
|
||||
verified = nacl.bindings.crypto_sign_open(signed_message, public_key)
|
||||
assert verified == signing_input
|
||||
|
||||
|
||||
class TestEddsaSignExpanded:
|
||||
def test_produces_64_byte_signature(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
message = b"test message"
|
||||
sig = _ed25519_sign_expanded(message, private_key[:32], private_key[32:], public_key)
|
||||
assert len(sig) == 64
|
||||
|
||||
def test_signature_verifies_with_nacl(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
message = b"hello world"
|
||||
sig = _ed25519_sign_expanded(message, private_key[:32], private_key[32:], public_key)
|
||||
|
||||
signed_message = sig + message
|
||||
verified = nacl.bindings.crypto_sign_open(signed_message, public_key)
|
||||
assert verified == message
|
||||
|
||||
def test_different_messages_produce_different_signatures(self):
|
||||
private_key, public_key = _make_test_keys()
|
||||
sig1 = _ed25519_sign_expanded(b"msg1", private_key[:32], private_key[32:], public_key)
|
||||
sig2 = _ed25519_sign_expanded(b"msg2", private_key[:32], private_key[32:], public_key)
|
||||
assert sig1 != sig2
|
||||
|
||||
|
||||
class TestPacketFormatConversion:
|
||||
def test_basic_field_mapping(self):
|
||||
data = {
|
||||
"id": 1,
|
||||
"observation_id": 100,
|
||||
"timestamp": 1700000000,
|
||||
"data": "0a1b2c3d",
|
||||
"payload_type": "ADVERT",
|
||||
"snr": 5.5,
|
||||
"rssi": -90,
|
||||
"decrypted": False,
|
||||
"decrypted_info": None,
|
||||
}
|
||||
result = _format_raw_packet(data, "TestNode", "AABBCCDD" * 8)
|
||||
|
||||
assert result["origin"] == "TestNode"
|
||||
assert result["origin_id"] == "AABBCCDD" * 8
|
||||
assert result["raw"] == "0A1B2C3D"
|
||||
assert result["SNR"] == "5.5"
|
||||
assert result["RSSI"] == "-90"
|
||||
assert result["type"] == "PACKET"
|
||||
assert result["direction"] == "rx"
|
||||
assert result["len"] == "4"
|
||||
|
||||
def test_timestamp_is_iso8601(self):
|
||||
data = {"timestamp": 1700000000, "data": "00", "snr": None, "rssi": None}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["timestamp"]
|
||||
assert "T" in result["timestamp"]
|
||||
|
||||
def test_snr_rssi_unknown_when_none(self):
|
||||
data = {"timestamp": 0, "data": "00", "snr": None, "rssi": None}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["SNR"] == "Unknown"
|
||||
assert result["RSSI"] == "Unknown"
|
||||
|
||||
def test_packet_type_extraction(self):
|
||||
# Header 0x14 = type 5, route 0 (TRANSPORT_FLOOD): header + 4 transport + path_len.
|
||||
data = {"timestamp": 0, "data": "140102030400", "snr": None, "rssi": None}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["packet_type"] == "5"
|
||||
assert result["route"] == "F"
|
||||
|
||||
def test_route_mapping(self):
|
||||
# Test all 4 route types (matches meshcore-packet-capture)
|
||||
# TRANSPORT_FLOOD=0 -> "F", FLOOD=1 -> "F", DIRECT=2 -> "D", TRANSPORT_DIRECT=3 -> "T"
|
||||
samples = [
|
||||
("000102030400", "F"), # TRANSPORT_FLOOD: header + transport + path_len
|
||||
("0100", "F"), # FLOOD: header + path_len
|
||||
("0200", "D"), # DIRECT: header + path_len
|
||||
("030102030400", "T"), # TRANSPORT_DIRECT: header + transport + path_len
|
||||
]
|
||||
for raw_hex, expected in samples:
|
||||
data = {"timestamp": 0, "data": raw_hex, "snr": None, "rssi": None}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["route"] == expected
|
||||
|
||||
def test_hash_is_16_uppercase_hex_chars(self):
|
||||
data = {"timestamp": 0, "data": "aabb", "snr": None, "rssi": None}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert len(result["hash"]) == 16
|
||||
assert result["hash"] == result["hash"].upper()
|
||||
|
||||
def test_empty_data_handled(self):
|
||||
data = {"timestamp": 0, "data": "", "snr": None, "rssi": None}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["raw"] == ""
|
||||
assert result["len"] == "0"
|
||||
assert result["packet_type"] == "0"
|
||||
assert result["route"] == "U"
|
||||
|
||||
def test_includes_reference_time_fields(self):
|
||||
data = {"timestamp": 0, "data": "0100aabb", "snr": 1.0, "rssi": -70}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["time"]
|
||||
assert result["date"]
|
||||
assert result["payload_len"] == "2"
|
||||
|
||||
def test_adds_path_for_direct_route(self):
|
||||
# route=2 (DIRECT), path_len=2, path=aa bb, payload=cc
|
||||
data = {"timestamp": 0, "data": "0202AABBCC", "snr": 1.0, "rssi": -70}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["route"] == "D"
|
||||
assert result["path"] == "aa,bb"
|
||||
|
||||
def test_direct_route_includes_empty_path_field(self):
|
||||
data = {"timestamp": 0, "data": "0200", "snr": 1.0, "rssi": -70}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["route"] == "D"
|
||||
assert "path" in result
|
||||
assert result["path"] == ""
|
||||
|
||||
def test_unknown_version_uses_defaults(self):
|
||||
# version=1 in high bits, type=5, route=1
|
||||
header = (1 << 6) | (5 << 2) | 1
|
||||
data = {"timestamp": 0, "data": f"{header:02x}00", "snr": 1.0, "rssi": -70}
|
||||
result = _format_raw_packet(data, "Node", "AA" * 32)
|
||||
assert result["packet_type"] == "0"
|
||||
assert result["route"] == "U"
|
||||
assert result["payload_len"] == "0"
|
||||
|
||||
|
||||
class TestCalculatePacketHash:
|
||||
def test_empty_bytes_returns_zeroes(self):
|
||||
result = _calculate_packet_hash(b"")
|
||||
assert result == "0" * 16
|
||||
|
||||
def test_returns_16_uppercase_hex_chars(self):
|
||||
# Simple flood packet: header(1) + path_len(1) + payload
|
||||
raw = bytes([0x01, 0x00, 0xAA, 0xBB]) # FLOOD, no path, payload=0xAABB
|
||||
result = _calculate_packet_hash(raw)
|
||||
assert len(result) == 16
|
||||
assert result == result.upper()
|
||||
|
||||
def test_flood_packet_hash(self):
|
||||
"""FLOOD route (0x01): no transport codes, header + path_len + payload."""
|
||||
import hashlib
|
||||
|
||||
# Header 0x11 = route=FLOOD(1), payload_type=4(ADVERT): (4<<2)|1 = 0x11
|
||||
payload = b"\xde\xad"
|
||||
raw = bytes([0x11, 0x00]) + payload # header, path_len=0, payload
|
||||
result = _calculate_packet_hash(raw)
|
||||
|
||||
# Expected: sha256(payload_type_byte + payload_data)[:16].upper()
|
||||
expected = hashlib.sha256(bytes([4]) + payload).hexdigest()[:16].upper()
|
||||
assert result == expected
|
||||
|
||||
def test_transport_flood_skips_transport_codes(self):
|
||||
"""TRANSPORT_FLOOD (0x00): has 4 bytes of transport codes after header."""
|
||||
import hashlib
|
||||
|
||||
# Header 0x10 = route=TRANSPORT_FLOOD(0), payload_type=4: (4<<2)|0 = 0x10
|
||||
transport_codes = b"\x01\x02\x03\x04"
|
||||
payload = b"\xca\xfe"
|
||||
raw = bytes([0x10]) + transport_codes + bytes([0x00]) + payload
|
||||
result = _calculate_packet_hash(raw)
|
||||
|
||||
expected = hashlib.sha256(bytes([4]) + payload).hexdigest()[:16].upper()
|
||||
assert result == expected
|
||||
|
||||
def test_transport_direct_skips_transport_codes(self):
|
||||
"""TRANSPORT_DIRECT (0x03): also has 4 bytes of transport codes."""
|
||||
import hashlib
|
||||
|
||||
# Header 0x13 = route=TRANSPORT_DIRECT(3), payload_type=4: (4<<2)|3 = 0x13
|
||||
transport_codes = b"\x05\x06\x07\x08"
|
||||
payload = b"\xbe\xef"
|
||||
raw = bytes([0x13]) + transport_codes + bytes([0x00]) + payload
|
||||
result = _calculate_packet_hash(raw)
|
||||
|
||||
expected = hashlib.sha256(bytes([4]) + payload).hexdigest()[:16].upper()
|
||||
assert result == expected
|
||||
|
||||
def test_trace_packet_includes_path_len_in_hash(self):
|
||||
"""TRACE packets (type 9) include path_len as uint16_t LE in the hash."""
|
||||
import hashlib
|
||||
|
||||
# Header for TRACE with FLOOD route: (9<<2)|1 = 0x25
|
||||
path_len = 3
|
||||
path_data = b"\xaa\xbb\xcc"
|
||||
payload = b"\x01\x02"
|
||||
raw = bytes([0x25, path_len]) + path_data + payload
|
||||
result = _calculate_packet_hash(raw)
|
||||
|
||||
expected_hash = (
|
||||
hashlib.sha256(bytes([9]) + path_len.to_bytes(2, byteorder="little") + payload)
|
||||
.hexdigest()[:16]
|
||||
.upper()
|
||||
)
|
||||
assert result == expected_hash
|
||||
|
||||
def test_with_path_data(self):
|
||||
"""Packet with non-zero path_len should skip path bytes to reach payload."""
|
||||
import hashlib
|
||||
|
||||
# FLOOD route, payload_type=2 (TXT_MSG): (2<<2)|1 = 0x09
|
||||
path_data = b"\xaa\xbb" # 2 bytes of path
|
||||
payload = b"\x48\x65\x6c\x6c\x6f" # "Hello"
|
||||
raw = bytes([0x09, 0x02]) + path_data + payload
|
||||
result = _calculate_packet_hash(raw)
|
||||
|
||||
expected = hashlib.sha256(bytes([2]) + payload).hexdigest()[:16].upper()
|
||||
assert result == expected
|
||||
|
||||
def test_truncated_packet_returns_zeroes(self):
|
||||
# Header says TRANSPORT_FLOOD, but missing path_len at required offset.
|
||||
raw = bytes([0x10, 0x01, 0x02])
|
||||
assert _calculate_packet_hash(raw) == "0" * 16
|
||||
|
||||
|
||||
class TestCommunityMqttPublisher:
|
||||
def test_initial_state(self):
|
||||
pub = CommunityMqttPublisher()
|
||||
assert pub.connected is False
|
||||
assert pub._client is None
|
||||
assert pub._task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_drops_when_disconnected(self):
|
||||
pub = CommunityMqttPublisher()
|
||||
# Should not raise
|
||||
await pub.publish("topic", {"key": "value"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_resets_state(self):
|
||||
pub = CommunityMqttPublisher()
|
||||
pub.connected = True
|
||||
pub._client = MagicMock()
|
||||
await pub.stop()
|
||||
assert pub.connected is False
|
||||
assert pub._client is None
|
||||
|
||||
def test_is_configured_false_when_disabled(self):
|
||||
pub = CommunityMqttPublisher()
|
||||
pub._settings = AppSettings(community_mqtt_enabled=False)
|
||||
with patch("app.keystore.has_private_key", return_value=True):
|
||||
assert pub._is_configured() is False
|
||||
|
||||
def test_is_configured_false_when_no_private_key(self):
|
||||
pub = CommunityMqttPublisher()
|
||||
pub._settings = AppSettings(community_mqtt_enabled=True)
|
||||
with patch("app.keystore.has_private_key", return_value=False):
|
||||
assert pub._is_configured() is False
|
||||
|
||||
def test_is_configured_true_when_enabled_with_key(self):
|
||||
pub = CommunityMqttPublisher()
|
||||
pub._settings = AppSettings(community_mqtt_enabled=True)
|
||||
with patch("app.keystore.has_private_key", return_value=True):
|
||||
assert pub._is_configured() is True
|
||||
|
||||
|
||||
class TestCommunityMqttBroadcast:
|
||||
def test_filters_non_raw_packet(self):
|
||||
"""Non-raw_packet events should be ignored."""
|
||||
with patch("app.community_mqtt.community_publisher") as mock_pub:
|
||||
mock_pub.connected = True
|
||||
mock_pub._settings = AppSettings(community_mqtt_enabled=True)
|
||||
community_mqtt_broadcast("message", {"text": "hello"})
|
||||
# No asyncio.create_task should be called for non-raw_packet events
|
||||
# Since we're filtering, we just verify no exception
|
||||
|
||||
def test_skips_when_disconnected(self):
|
||||
"""Should not publish when disconnected."""
|
||||
with (
|
||||
patch("app.community_mqtt.community_publisher") as mock_pub,
|
||||
patch("app.community_mqtt.asyncio.create_task") as mock_task,
|
||||
):
|
||||
mock_pub.connected = False
|
||||
mock_pub._settings = AppSettings(community_mqtt_enabled=True)
|
||||
community_mqtt_broadcast("raw_packet", {"data": "00"})
|
||||
mock_task.assert_not_called()
|
||||
|
||||
def test_skips_when_settings_none(self):
|
||||
"""Should not publish when settings are None."""
|
||||
with (
|
||||
patch("app.community_mqtt.community_publisher") as mock_pub,
|
||||
patch("app.community_mqtt.asyncio.create_task") as mock_task,
|
||||
):
|
||||
mock_pub.connected = True
|
||||
mock_pub._settings = None
|
||||
community_mqtt_broadcast("raw_packet", {"data": "00"})
|
||||
mock_task.assert_not_called()
|
||||
|
||||
|
||||
class TestParseBrokerAddress:
|
||||
def test_hostname_only_uses_default_port(self):
|
||||
host, port = _parse_broker_address("mqtt-us-v1.letsmesh.net")
|
||||
assert host == "mqtt-us-v1.letsmesh.net"
|
||||
assert port == _DEFAULT_PORT
|
||||
|
||||
def test_hostname_with_port(self):
|
||||
host, port = _parse_broker_address("mqtt-us-v1.letsmesh.net:8883")
|
||||
assert host == "mqtt-us-v1.letsmesh.net"
|
||||
assert port == 8883
|
||||
|
||||
def test_hostname_with_port_443(self):
|
||||
host, port = _parse_broker_address("broker.example.com:443")
|
||||
assert host == "broker.example.com"
|
||||
assert port == 443
|
||||
|
||||
def test_invalid_port_uses_default(self):
|
||||
host, port = _parse_broker_address("broker.example.com:abc")
|
||||
assert host == "broker.example.com:abc"
|
||||
assert port == _DEFAULT_PORT
|
||||
|
||||
def test_empty_string(self):
|
||||
host, port = _parse_broker_address("")
|
||||
assert host == ""
|
||||
assert port == _DEFAULT_PORT
|
||||
|
||||
|
||||
class TestPublishFailureSetsDisconnected:
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_error_sets_connected_false(self):
|
||||
"""A publish error should set connected=False so the loop can detect it."""
|
||||
pub = CommunityMqttPublisher()
|
||||
pub.connected = True
|
||||
mock_client = MagicMock()
|
||||
mock_client.publish = MagicMock(side_effect=Exception("broker gone"))
|
||||
pub._client = mock_client
|
||||
await pub.publish("topic", {"data": "test"})
|
||||
assert pub.connected is False
|
||||
+77
-20
@@ -100,8 +100,8 @@ class TestMigration001:
|
||||
# Run migrations
|
||||
applied = await run_migrations(conn)
|
||||
|
||||
assert applied == 31 # All migrations run
|
||||
assert await get_version(conn) == 31
|
||||
assert applied == 32 # All migrations run
|
||||
assert await get_version(conn) == 32
|
||||
|
||||
# Verify columns exist by inserting and selecting
|
||||
await conn.execute(
|
||||
@@ -183,9 +183,9 @@ class TestMigration001:
|
||||
applied1 = await run_migrations(conn)
|
||||
applied2 = await run_migrations(conn)
|
||||
|
||||
assert applied1 == 31 # All migrations run
|
||||
assert applied1 == 32 # All migrations run
|
||||
assert applied2 == 0 # No migrations on second run
|
||||
assert await get_version(conn) == 31
|
||||
assert await get_version(conn) == 32
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@@ -246,8 +246,8 @@ class TestMigration001:
|
||||
applied = await run_migrations(conn)
|
||||
|
||||
# All migrations applied (version incremented) but no error
|
||||
assert applied == 31
|
||||
assert await get_version(conn) == 31
|
||||
assert applied == 32
|
||||
assert await get_version(conn) == 32
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@@ -374,10 +374,10 @@ class TestMigration013:
|
||||
)
|
||||
await conn.commit()
|
||||
|
||||
# Run migration 13 (plus 14-27 which also run)
|
||||
# Run migration 13 (plus 14-33 which also run)
|
||||
applied = await run_migrations(conn)
|
||||
assert applied == 19
|
||||
assert await get_version(conn) == 31
|
||||
assert applied == 20
|
||||
assert await get_version(conn) == 32
|
||||
|
||||
# Verify bots array was created with migrated data
|
||||
cursor = await conn.execute("SELECT bots FROM app_settings WHERE id = 1")
|
||||
@@ -497,7 +497,7 @@ class TestMigration018:
|
||||
assert await cursor.fetchone() is not None
|
||||
|
||||
await run_migrations(conn)
|
||||
assert await get_version(conn) == 31
|
||||
assert await get_version(conn) == 32
|
||||
|
||||
# Verify autoindex is gone
|
||||
cursor = await conn.execute(
|
||||
@@ -575,8 +575,8 @@ class TestMigration018:
|
||||
await conn.commit()
|
||||
|
||||
applied = await run_migrations(conn)
|
||||
assert applied == 14 # Migrations 18-31 run (18+19 skip internally)
|
||||
assert await get_version(conn) == 31
|
||||
assert applied == 15 # Migrations 18-32 run (18+19 skip internally)
|
||||
assert await get_version(conn) == 32
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@@ -648,7 +648,7 @@ class TestMigration019:
|
||||
assert await cursor.fetchone() is not None
|
||||
|
||||
await run_migrations(conn)
|
||||
assert await get_version(conn) == 31
|
||||
assert await get_version(conn) == 32
|
||||
|
||||
# Verify autoindex is gone
|
||||
cursor = await conn.execute(
|
||||
@@ -714,8 +714,8 @@ class TestMigration020:
|
||||
assert (await cursor.fetchone())[0] == "delete"
|
||||
|
||||
applied = await run_migrations(conn)
|
||||
assert applied == 12 # Migrations 20-31
|
||||
assert await get_version(conn) == 31
|
||||
assert applied == 13 # Migrations 20-32
|
||||
assert await get_version(conn) == 32
|
||||
|
||||
# Verify WAL mode
|
||||
cursor = await conn.execute("PRAGMA journal_mode")
|
||||
@@ -745,7 +745,7 @@ class TestMigration020:
|
||||
await set_version(conn, 20)
|
||||
|
||||
applied = await run_migrations(conn)
|
||||
assert applied == 11 # Migrations 21-31 still run
|
||||
assert applied == 12 # Migrations 21-32 still run
|
||||
|
||||
# Still WAL + INCREMENTAL
|
||||
cursor = await conn.execute("PRAGMA journal_mode")
|
||||
@@ -803,8 +803,8 @@ class TestMigration028:
|
||||
await conn.commit()
|
||||
|
||||
applied = await run_migrations(conn)
|
||||
assert applied == 4
|
||||
assert await get_version(conn) == 31
|
||||
assert applied == 5
|
||||
assert await get_version(conn) == 32
|
||||
|
||||
# Verify payload_hash column is now BLOB
|
||||
cursor = await conn.execute("PRAGMA table_info(raw_packets)")
|
||||
@@ -873,8 +873,8 @@ class TestMigration028:
|
||||
await conn.commit()
|
||||
|
||||
applied = await run_migrations(conn)
|
||||
assert applied == 4 # Version still bumped
|
||||
assert await get_version(conn) == 31
|
||||
assert applied == 5 # Version still bumped
|
||||
assert await get_version(conn) == 32
|
||||
|
||||
# Verify data unchanged
|
||||
cursor = await conn.execute("SELECT payload_hash FROM raw_packets")
|
||||
@@ -882,3 +882,60 @@ class TestMigration028:
|
||||
assert bytes(row["payload_hash"]) == b"\xab" * 32
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
class TestMigration032:
|
||||
"""Test migration 032: add community MQTT columns to app_settings."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_adds_all_community_mqtt_columns(self):
|
||||
"""Migration adds enabled, iata, broker, and email columns."""
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
try:
|
||||
await set_version(conn, 31)
|
||||
|
||||
# Create app_settings without community columns (pre-migration schema)
|
||||
await conn.execute("""
|
||||
CREATE TABLE app_settings (
|
||||
id INTEGER PRIMARY KEY,
|
||||
max_radio_contacts INTEGER DEFAULT 200,
|
||||
favorites TEXT DEFAULT '[]',
|
||||
auto_decrypt_dm_on_advert INTEGER DEFAULT 0,
|
||||
sidebar_sort_order TEXT DEFAULT 'recent',
|
||||
last_message_times TEXT DEFAULT '{}',
|
||||
preferences_migrated INTEGER DEFAULT 0,
|
||||
advert_interval INTEGER DEFAULT 0,
|
||||
last_advert_time INTEGER DEFAULT 0,
|
||||
bots TEXT DEFAULT '[]',
|
||||
mqtt_broker_host TEXT DEFAULT '',
|
||||
mqtt_broker_port INTEGER DEFAULT 1883,
|
||||
mqtt_username TEXT DEFAULT '',
|
||||
mqtt_password TEXT DEFAULT '',
|
||||
mqtt_use_tls INTEGER DEFAULT 0,
|
||||
mqtt_tls_insecure INTEGER DEFAULT 0,
|
||||
mqtt_topic_prefix TEXT DEFAULT 'meshcore',
|
||||
mqtt_publish_messages INTEGER DEFAULT 0,
|
||||
mqtt_publish_raw_packets INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
await conn.execute("INSERT INTO app_settings (id) VALUES (1)")
|
||||
await conn.commit()
|
||||
|
||||
applied = await run_migrations(conn)
|
||||
assert applied == 1
|
||||
assert await get_version(conn) == 32
|
||||
|
||||
# Verify all columns exist with correct defaults
|
||||
cursor = await conn.execute(
|
||||
"""SELECT community_mqtt_enabled, community_mqtt_iata,
|
||||
community_mqtt_broker, community_mqtt_email
|
||||
FROM app_settings WHERE id = 1"""
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row["community_mqtt_enabled"] == 0
|
||||
assert row["community_mqtt_iata"] == ""
|
||||
assert row["community_mqtt_broker"] == "mqtt-us-v1.letsmesh.net"
|
||||
assert row["community_mqtt_email"] == ""
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@@ -502,6 +502,10 @@ class TestAppSettingsRepository:
|
||||
"mqtt_topic_prefix": "meshcore",
|
||||
"mqtt_publish_messages": 0,
|
||||
"mqtt_publish_raw_packets": 0,
|
||||
"community_mqtt_enabled": 0,
|
||||
"community_mqtt_iata": "",
|
||||
"community_mqtt_broker": "mqtt-us-v1.letsmesh.net",
|
||||
"community_mqtt_email": "",
|
||||
}
|
||||
)
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
|
||||
@@ -117,6 +117,78 @@ class TestUpdateSettings:
|
||||
assert settings.mqtt_publish_messages is False
|
||||
assert settings.mqtt_publish_raw_packets is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_community_mqtt_fields_round_trip(self, test_db):
|
||||
"""Community MQTT settings should be saved and retrieved correctly."""
|
||||
mock_community = type("MockCommunity", (), {"restart": AsyncMock()})()
|
||||
with patch("app.community_mqtt.community_publisher", mock_community):
|
||||
result = await update_settings(
|
||||
AppSettingsUpdate(
|
||||
community_mqtt_enabled=True,
|
||||
community_mqtt_iata="DEN",
|
||||
community_mqtt_broker="custom-broker.example.com",
|
||||
community_mqtt_email="test@example.com",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.community_mqtt_enabled is True
|
||||
assert result.community_mqtt_iata == "DEN"
|
||||
assert result.community_mqtt_broker == "custom-broker.example.com"
|
||||
assert result.community_mqtt_email == "test@example.com"
|
||||
|
||||
# Verify persistence
|
||||
fresh = await AppSettingsRepository.get()
|
||||
assert fresh.community_mqtt_enabled is True
|
||||
assert fresh.community_mqtt_iata == "DEN"
|
||||
assert fresh.community_mqtt_broker == "custom-broker.example.com"
|
||||
assert fresh.community_mqtt_email == "test@example.com"
|
||||
|
||||
# Verify restart was called
|
||||
mock_community.restart.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_community_mqtt_iata_validation_rejects_invalid(self, test_db):
|
||||
"""Invalid IATA codes should be rejected."""
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await update_settings(AppSettingsUpdate(community_mqtt_iata="A"))
|
||||
assert exc.value.status_code == 400
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await update_settings(AppSettingsUpdate(community_mqtt_iata="ABCDE"))
|
||||
assert exc.value.status_code == 400
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await update_settings(AppSettingsUpdate(community_mqtt_iata="12"))
|
||||
assert exc.value.status_code == 400
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await update_settings(AppSettingsUpdate(community_mqtt_iata="ABCD"))
|
||||
assert exc.value.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_community_mqtt_enable_requires_iata(self, test_db):
|
||||
"""Enabling community MQTT without a valid IATA code should be rejected."""
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await update_settings(AppSettingsUpdate(community_mqtt_enabled=True))
|
||||
assert exc.value.status_code == 400
|
||||
assert "IATA" in exc.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_community_mqtt_iata_uppercased(self, test_db):
|
||||
"""IATA codes should be uppercased."""
|
||||
mock_community = type("MockCommunity", (), {"restart": AsyncMock()})()
|
||||
with patch("app.community_mqtt.community_publisher", mock_community):
|
||||
result = await update_settings(AppSettingsUpdate(community_mqtt_iata="den"))
|
||||
assert result.community_mqtt_iata == "DEN"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_community_mqtt_defaults_on_fresh_db(self, test_db):
|
||||
"""Community MQTT fields should have correct defaults on a fresh database."""
|
||||
settings = await AppSettingsRepository.get()
|
||||
assert settings.community_mqtt_enabled is False
|
||||
assert settings.community_mqtt_iata == ""
|
||||
assert settings.community_mqtt_email == ""
|
||||
|
||||
|
||||
class TestToggleFavorite:
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user