Files
Remote-Terminal-for-MeshCore/tests/test_tcp_proxy_session.py
T

696 lines
22 KiB
Python

"""Tests for app.tcp_proxy.session — ProxySession command handlers."""
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.tcp_proxy.protocol import (
CMD_APP_START,
CMD_DEVICE_QUERY,
CMD_GET_BATT_AND_STORAGE,
CMD_GET_CHANNEL,
CMD_GET_CONTACT_BY_KEY,
CMD_GET_CONTACTS,
CMD_GET_DEVICE_TIME,
CMD_HAS_CONNECTION,
CMD_RESET_PATH,
CMD_SEND_CHANNEL_TXT_MSG,
CMD_SEND_TXT_MSG,
CMD_SET_CHANNEL,
CMD_SYNC_NEXT_MESSAGE,
ERR_NOT_FOUND,
PROXY_FW_VER,
PUSH_MSG_WAITING,
RESP_BATTERY,
RESP_CHANNEL_MSG_RECV_V3,
RESP_CONTACT_END,
RESP_CONTACT_MSG_RECV_V3,
RESP_CONTACT_START,
RESP_CURRENT_TIME,
RESP_DEVICE_INFO,
RESP_ERR,
RESP_MSG_SENT,
RESP_NO_MORE_MSGS,
RESP_OK,
RESP_SELF_INFO,
encode_path_byte,
)
from app.tcp_proxy.session import ProxySession
# 32-byte public key rendered as 64 lowercase hex characters ("abab...ab").
EXAMPLE_KEY = bytes([0xAB] * 32).hex()
# ── Helpers ──────────────────────────────────────────────────────────
def _make_session() -> tuple[ProxySession, list[bytes]]:
    """Build a ProxySession wired to mocked asyncio streams.

    Returns the session together with a list that receives, in order,
    every buffer the session passes to ``writer.write``.
    """
    captured: list[bytes] = []

    def _record(data: bytes):
        captured.append(data)

    mock_writer = MagicMock(spec=asyncio.StreamWriter)
    mock_writer.get_extra_info.return_value = ("127.0.0.1", 12345)
    # A plain function (not a mock) so the raw bytes are kept verbatim.
    mock_writer.write = _record
    mock_writer.drain = AsyncMock()

    mock_reader = AsyncMock(spec=asyncio.StreamReader)
    return ProxySession(mock_reader, mock_writer), captured
def _extract_payloads(sent: list[bytes]) -> list[bytes]:
"""Extract payloads from framed response bytes."""
payloads = []
for frame in sent:
assert frame[0] == 0x3E
size = int.from_bytes(frame[1:3], "little")
payloads.append(frame[3 : 3 + size])
return payloads
def _make_contact(public_key: str = EXAMPLE_KEY, name: str = "Alice", **kw):
    """Return a mock contact model whose ``model_dump()`` yields a contact dict.

    The dict describes a favorite type-1 contact with no direct path and
    ``first_seen`` set to the current time; extra keyword arguments are
    merged in last and override the defaults.
    """
    fields = dict(
        public_key=public_key,
        name=name,
        type=1,
        favorite=True,
        direct_path=None,
        direct_path_len=-1,
        direct_path_hash_mode=-1,
        last_advert=0,
        lat=0.0,
        lon=0.0,
        first_seen=int(time.time()),
    )
    fields.update(kw)

    contact = MagicMock()
    contact.model_dump = MagicMock(return_value=fields)
    return contact
def _make_channel(key: str = "cc" * 16, name: str = "test", favorite: bool = True):
return MagicMock(
model_dump=MagicMock(return_value={"key": key, "name": name, "favorite": favorite})
)
def _make_settings(last_message_times=None):
return MagicMock(last_message_times=last_message_times or {})
def _mock_radio_runtime(connected: bool = True, self_info: dict | None = None):
rt = MagicMock()
rt.is_connected = connected
mc = MagicMock()
mc.self_info = self_info or {
"public_key": EXAMPLE_KEY,
"name": "TestNode",
"tx_power": 20,
"max_tx_power": 22,
"adv_lat": 0.0,
"adv_lon": 0.0,
"radio_freq": 915.0,
"radio_bw": 250.0,
"radio_sf": 10,
"radio_cr": 7,
}
rt.meshcore = mc
return rt
# ── Tests ────────────────────────────────────────────────────────────
class TestAppStart:
    """APP_START loads favorites/channels/settings and replies with SELF_INFO."""

    @staticmethod
    async def _run_app_start(session, contacts, channels, settings=None, rt=None):
        """Run ``_cmd_app_start`` with repositories and radio runtime patched.

        ``settings`` and ``rt`` default to fresh ``_make_settings()`` /
        ``_mock_radio_runtime()`` objects when not supplied.
        """
        if settings is None:
            settings = _make_settings()
        if rt is None:
            rt = _mock_radio_runtime()
        with (
            patch("app.repository.ContactRepository") as cr,
            patch("app.repository.ChannelRepository") as chr_,
            patch("app.repository.AppSettingsRepository") as sr,
            patch("app.services.radio_runtime.radio_runtime", rt),
        ):
            cr.get_favorites = AsyncMock(return_value=contacts)
            chr_.get_all = AsyncMock(return_value=channels)
            sr.get = AsyncMock(return_value=settings)
            await session._cmd_app_start(bytes([CMD_APP_START]))

    @pytest.mark.asyncio
    async def test_sends_self_info(self):
        session, sent = _make_session()
        await self._run_app_start(session, [_make_contact()], [_make_channel()])
        payloads = _extract_payloads(sent)
        assert len(payloads) == 1
        assert payloads[0][0] == RESP_SELF_INFO

    @pytest.mark.asyncio
    async def test_populates_contacts_and_channels(self):
        session, _ = _make_session()
        contacts = [_make_contact(), _make_contact(public_key="cd" * 32, name="Bob")]
        non_favorite_key = "ee" * 16
        channels = [
            _make_channel(),
            _make_channel(key="dd" * 16, name="ch2"),
            # Non-favorite channel: exercises the "only favorites are slotted" rule.
            _make_channel(key=non_favorite_key, name="ch3", favorite=False),
        ]
        await self._run_app_start(session, contacts, channels)
        assert len(session.contacts) == 2
        # Only the two favorite channels receive slots.
        assert len(session.channel_slots) == 2
        assert non_favorite_key not in session.channel_slots.values()
class TestDeviceQuery:
    """DEVICE_QUERY answers with DEVICE_INFO carrying the proxy firmware version."""

    @pytest.mark.asyncio
    async def test_sends_device_info(self):
        session, sent = _make_session()
        # Second byte is the query argument a client sends (presumably its
        # protocol/app version — not inspected by this test).
        query = bytes([CMD_DEVICE_QUERY, 0x03])
        with patch("app.services.radio_runtime.radio_runtime", _mock_radio_runtime()):
            await session._cmd_device_query(query)
        info = _extract_payloads(sent)[0]
        assert info[0] == RESP_DEVICE_INFO
        assert info[1] == PROXY_FW_VER
class TestGetContacts:
    """GET_CONTACTS must stream CONTACT_START(count), one frame per contact, CONTACT_END."""

    @pytest.mark.asyncio
    async def test_sends_start_contacts_end(self):
        session, sent = _make_session()
        contacts = [_make_contact(), _make_contact(public_key="cd" * 32, name="Bob")]
        with patch("app.repository.ContactRepository") as cr:
            cr.get_favorites = AsyncMock(return_value=contacts)
            await session._cmd_get_contacts(bytes([CMD_GET_CONTACTS]))
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_CONTACT_START
        # START carries the contact count as a 4-byte little-endian integer.
        count = int.from_bytes(payloads[0][1:5], "little")
        assert count == len(contacts)
        # Exactly one RESP_CONTACT frame per contact between START and END.
        assert len(payloads) == count + 2
        assert all(p[0] == 0x03 for p in payloads[1:-1])  # RESP_CONTACT
        assert payloads[-1][0] == RESP_CONTACT_END
class TestGetContactByKey:
    """GET_CONTACT_BY_KEY searches the session's cached contacts by full key."""

    @pytest.mark.asyncio
    async def test_found(self):
        session, sent = _make_session()
        session.contacts = [
            dict(
                public_key=EXAMPLE_KEY,
                type=1,
                name="Alice",
                favorite=True,
                direct_path=None,
                direct_path_len=-1,
                direct_path_hash_mode=-1,
                last_advert=0,
                lat=0.0,
                lon=0.0,
                first_seen=0,
            )
        ]
        request = bytes([CMD_GET_CONTACT_BY_KEY]) + bytes.fromhex(EXAMPLE_KEY)
        await session._cmd_get_contact_by_key(request)
        replies = _extract_payloads(sent)
        assert len(replies) == 1
        assert replies[0][0] == 0x03  # RESP_CONTACT

    @pytest.mark.asyncio
    async def test_not_found(self):
        session, sent = _make_session()
        session.contacts = []
        await session._cmd_get_contact_by_key(
            bytes([CMD_GET_CONTACT_BY_KEY]) + bytes.fromhex(EXAMPLE_KEY)
        )
        error = _extract_payloads(sent)[0]
        assert error[0] == RESP_ERR
        assert error[1] == ERR_NOT_FOUND
class TestGetChannel:
    """GET_CHANNEL returns CHANNEL_INFO for a mapped slot and ERR otherwise."""

    @pytest.mark.asyncio
    async def test_found(self):
        session, sent = _make_session()
        channel_key = "cc" * 16
        session.channel_slots = {0: channel_key}
        session.channels = [{"key": channel_key, "name": "test"}]
        await session._cmd_get_channel(bytes([CMD_GET_CHANNEL, 0]))
        assert _extract_payloads(sent)[0][0] == 0x12  # RESP_CHANNEL_INFO

    @pytest.mark.asyncio
    async def test_empty_slot_returns_error(self):
        session, sent = _make_session()
        session.channel_slots = {}
        # Slot 5 has no mapping.
        await session._cmd_get_channel(bytes([CMD_GET_CHANNEL, 5]))
        assert _extract_payloads(sent)[0][0] == RESP_ERR
class TestSetChannel:
    """SET_CHANNEL frame: cmd(1) + slot(1) + name(32, NUL-padded) + secret(16)."""

    @pytest.mark.asyncio
    async def test_updates_slot_mapping(self):
        session, sent = _make_session()
        padded_name = b"test".ljust(32, b"\x00")
        secret = bytes([0xAA]) * 16
        await session._cmd_set_channel(bytes([CMD_SET_CHANNEL, 3]) + padded_name + secret)
        # The hex-encoded secret becomes the slot's channel key, mapped both ways.
        expected_key = secret.hex()
        assert session.channel_slots[3] == expected_key
        assert session.key_to_idx[expected_key] == 3
        assert _extract_payloads(sent)[0][0] == RESP_OK

    @pytest.mark.asyncio
    async def test_cleans_stale_mapping(self):
        session, _ = _make_session()
        old_key, new_key = "aa" * 16, "bb" * 16
        # Slot 0 already holds old_key.
        session.channel_slots[0] = old_key
        session.key_to_idx[old_key] = 0
        # Re-assign slot 0 to new_key with an all-NUL name.
        request = bytes([CMD_SET_CHANNEL, 0]) + bytes(32) + bytes.fromhex(new_key)
        await session._cmd_set_channel(request)
        assert session.channel_slots[0] == new_key
        assert old_key not in session.key_to_idx
class TestSendDm:
    """SEND_TXT_MSG resolves the recipient by key prefix and acknowledges the send."""

    @staticmethod
    def _dm_command(text: bytes) -> bytes:
        """Build CMD_SEND_TXT_MSG: cmd(1) + txt_type(1) + attempt(1) + ts(4) + prefix(6) + text."""
        header = bytes([CMD_SEND_TXT_MSG, 0, 0])
        timestamp = int(time.time()).to_bytes(4, "little")
        key_prefix = bytes.fromhex(EXAMPLE_KEY[:12])
        return header + timestamp + key_prefix + text

    @pytest.mark.asyncio
    async def test_sends_msg_sent_and_ack(self):
        session, sent = _make_session()
        session.contacts = [{"public_key": EXAMPLE_KEY}]
        with patch.object(session, "_do_send_dm", new_callable=AsyncMock):
            await session._cmd_send_dm(self._dm_command(b"Hello"))
        replies = _extract_payloads(sent)
        msg_sent, push_ack = replies[0], replies[1]
        assert msg_sent[0] == RESP_MSG_SENT
        assert push_ack[0] == 0x82  # PUSH_ACK
        # The ACK code announced in MSG_SENT must be the one later pushed.
        assert msg_sent[2:6] == push_ack[1:5]

    @pytest.mark.asyncio
    async def test_long_text_with_prefix(self):
        """6-byte prefix + long text (>26 chars) must resolve correctly."""
        session, sent = _make_session()
        session.contacts = [{"public_key": EXAMPLE_KEY}]
        long_text = b"A" * 50  # well over 26 chars
        with patch.object(session, "_do_send_dm", new_callable=AsyncMock) as send_mock:
            await session._cmd_send_dm(self._dm_command(long_text))
        assert _extract_payloads(sent)[0][0] == RESP_MSG_SENT  # not ERR
        send_mock.assert_called_once()
        key_arg, text_arg = send_mock.call_args.args
        assert key_arg == EXAMPLE_KEY
        assert text_arg == "A" * 50
class TestSendChannel:
    """SEND_CHANNEL_TXT_MSG on a mapped slot replies OK and dispatches the send."""

    @pytest.mark.asyncio
    async def test_sends_ok(self):
        session, sent = _make_session()
        key = "cc" * 16
        session.channel_slots = {0: key}
        session.channels = [{"key": key, "name": "test"}]
        # cmd(1) + txt_type(1) + channel slot(1) + ts(4) + text
        cmd = (
            bytes([CMD_SEND_CHANNEL_TXT_MSG, 0, 0])
            + int(time.time()).to_bytes(4, "little")
            + b"Hello"
        )
        # MagicMock(name=...) only sets the mock's repr name; `.name` would be a
        # child mock. Assign the attribute explicitly so it is the string "test".
        fake_channel = MagicMock()
        fake_channel.name = "test"
        with (
            patch(
                "app.repository.ChannelRepository.get_by_key",
                new_callable=AsyncMock,
                return_value=fake_channel,
            ),
            patch.object(session, "_do_send_channel", new_callable=AsyncMock) as mock_send,
        ):
            await session._cmd_send_channel(cmd)
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_OK
        mock_send.assert_called_once()
class TestSimpleCommands:
    """Single-reply commands: device time, battery, connection status, OK stubs."""

    @pytest.mark.asyncio
    async def test_get_time(self):
        session, sent = _make_session()
        await session._cmd_get_time(bytes([CMD_GET_DEVICE_TIME]))
        assert _extract_payloads(sent)[0][0] == RESP_CURRENT_TIME

    @pytest.mark.asyncio
    async def test_battery(self):
        session, sent = _make_session()
        await session._cmd_battery(bytes([CMD_GET_BATT_AND_STORAGE]))
        assert _extract_payloads(sent)[0][0] == RESP_BATTERY

    @pytest.mark.asyncio
    async def test_has_connection(self):
        session, sent = _make_session()
        runtime = _mock_radio_runtime(connected=True)
        with patch("app.services.radio_runtime.radio_runtime", runtime):
            await session._cmd_has_connection(bytes([CMD_HAS_CONNECTION]))
        reply = _extract_payloads(sent)[0]
        assert reply[0] == RESP_OK
        # OK carries a 4-byte little-endian status; 1 for a connected radio.
        assert int.from_bytes(reply[1:5], "little") == 1

    @pytest.mark.asyncio
    async def test_ok_stub(self):
        session, sent = _make_session()
        await session._cmd_ok_stub(bytes([CMD_RESET_PATH]))
        assert _extract_payloads(sent)[0][0] == RESP_OK
class TestSyncNext:
    """SYNC_NEXT_MESSAGE drains the session's queued message frames one at a time."""

    @pytest.mark.asyncio
    async def test_empty_queue(self):
        session, sent = _make_session()
        await session._cmd_sync_next(bytes([CMD_SYNC_NEXT_MESSAGE]))
        assert _extract_payloads(sent)[0][0] == RESP_NO_MORE_MSGS

    @pytest.mark.asyncio
    async def test_dequeues_message(self):
        session, sent = _make_session()
        # Opaque placeholder frame; it must be forwarded byte-for-byte.
        queued = bytes([0x10, 0x00, 0x00, 0x00]) + bytes([0xAA]) * 10
        session._msg_queue.append(queued)
        await session._cmd_sync_next(bytes([CMD_SYNC_NEXT_MESSAGE]))
        assert _extract_payloads(sent)[0] == queued
        assert not len(session._msg_queue)
class TestExtractPathMeta:
    """Tests for the _extract_path_meta static helper.

    It returns ``(snr_byte, path_byte)``. path_byte packs the hash mode above
    bit 6 and the hop count in the low six bits.
    """

    @staticmethod
    def _single(path, path_len, snr=None, rssi=None):
        """Run _extract_path_meta on an event carrying exactly one path entry."""
        entry = {"path": path, "path_len": path_len, "snr": snr, "rssi": rssi}
        return ProxySession._extract_path_meta({"paths": [entry]})

    def test_no_paths(self):
        snr, path_byte = ProxySession._extract_path_meta({"paths": None})
        assert (snr, path_byte) == (0, 0)  # 0 hops, mode 0

    def test_empty_paths_list(self):
        snr, path_byte = ProxySession._extract_path_meta({"paths": []})
        assert (snr, path_byte) == (0, 0)

    def test_one_byte_hops(self):
        """2 hops at 1-byte hash mode → path_byte = (0 << 6) | 2 = 0x02."""
        _, path_byte = self._single("aabb", 2)
        assert path_byte == encode_path_byte(2, 0)
        assert path_byte == 0x02

    def test_two_byte_hops(self):
        """3 hops at 2-byte hash mode → path_byte = (1 << 6) | 3 = 0x43."""
        _, path_byte = self._single("aabbccddee11", 3)
        assert path_byte == encode_path_byte(3, 1)
        assert path_byte == 0x43

    def test_three_byte_hops(self):
        """1 hop at 3-byte hash mode → path_byte = (2 << 6) | 1 = 0x81."""
        _, path_byte = self._single("aabbcc", 1)
        assert path_byte == encode_path_byte(1, 2)
        assert path_byte == 0x81

    def test_snr_encoded(self):
        """SNR is encoded as int8(snr * 4)."""
        snr, _ = self._single("aa", 1, snr=-5.25, rssi=-100)
        # -5.25 * 4 = -21, stored as an unsigned byte.
        assert snr == (-21) & 0xFF

    def test_zero_hops_empty_path(self):
        """0 hops, empty path → path_byte 0."""
        _, path_byte = self._single("", 0)
        assert path_byte == 0

    def test_legacy_no_path_len(self):
        """path_len=None falls back to inferring from hex length (1-byte hops)."""
        _, path_byte = self._single("aabb", None)
        # Inferred: 2 hops over a 2-byte path → 1-byte hash → mode 0.
        assert path_byte == encode_path_byte(2, 0)
class TestEventHandlers:
    """Event hooks: message events are framed and queued; contact events
    refresh the favorites cache."""

    @staticmethod
    def _incoming(msg_type, conversation_key, text="hello", sender_timestamp=1700000000, **extra):
        """Build a non-outgoing message event; ``extra`` adds keys such as ``paths``."""
        return {
            "type": msg_type,
            "outgoing": False,
            "conversation_key": conversation_key,
            "text": text,
            "sender_timestamp": sender_timestamp,
            **extra,
        }

    @pytest.mark.asyncio
    async def test_priv_message_queued(self):
        session, sent = _make_session()
        await session.on_event_message(self._incoming("PRIV", EXAMPLE_KEY))
        assert len(session._msg_queue) == 1
        # The client is told a message is waiting.
        assert _extract_payloads(sent)[0][0] == PUSH_MSG_WAITING

    @pytest.mark.asyncio
    async def test_priv_message_path_encoding(self):
        """DM frame encodes path_len byte from message path data."""
        session, _ = _make_session()
        event = self._incoming(
            "PRIV",
            EXAMPLE_KEY,
            text="hi",
            paths=[{"path": "aabb", "path_len": 2, "snr": 3.0, "rssi": -80}],
        )
        await session.on_event_message(event)
        frame = session._msg_queue[0]
        assert frame[0] == RESP_CONTACT_MSG_RECV_V3
        assert frame[1] == 12  # SNR byte: 3.0 * 4
        # path_len byte sits at offset 10 (type, snr, 2 reserved, 6-byte prefix).
        assert frame[10] == encode_path_byte(2, 0)  # 2 hops, 1-byte hash

    @pytest.mark.asyncio
    async def test_chan_message_queued(self):
        session, _ = _make_session()
        key = "cc" * 16
        session.key_to_idx = {key: 0}
        # Upper-case key checks that the lookup normalizes case.
        await session.on_event_message(self._incoming("CHAN", key.upper()))
        assert len(session._msg_queue) == 1

    @pytest.mark.asyncio
    async def test_chan_message_path_encoding(self):
        """Channel frame encodes path_len byte correctly instead of 0xFF."""
        session, _ = _make_session()
        key = "cc" * 16
        session.key_to_idx = {key: 0}
        event = self._incoming(
            "CHAN",
            key,
            paths=[{"path": "aabbccdd", "path_len": 2, "snr": -2.5, "rssi": -90}],
        )
        await session.on_event_message(event)
        frame = session._msg_queue[0]
        assert frame[0] == RESP_CHANNEL_MSG_RECV_V3
        assert frame[1] == (-10) & 0xFF  # SNR byte: int8(-2.5 * 4)
        # path_len byte sits at offset 5 (type, snr, 2 reserved, channel_idx).
        path_byte = frame[5]
        assert path_byte == encode_path_byte(2, 1)  # 2 hops, 2-byte hash
        assert path_byte != 0xFF  # Must NOT be the old wrong value

    @pytest.mark.asyncio
    async def test_chan_message_no_paths_defaults_zero(self):
        """Channel message with no path data uses 0 (not 0xFF)."""
        session, _ = _make_session()
        key = "cc" * 16
        session.key_to_idx = {key: 0}
        await session.on_event_message(self._incoming("CHAN", key))
        assert session._msg_queue[0][5] == 0  # 0 hops, not 0xFF

    @pytest.mark.asyncio
    async def test_outgoing_message_ignored(self):
        session, sent = _make_session()
        event = {"type": "PRIV", "outgoing": True, "conversation_key": EXAMPLE_KEY}
        await session.on_event_message(event)
        assert not len(session._msg_queue)
        assert not len(sent)

    @pytest.mark.asyncio
    async def test_chan_unmapped_dropped(self):
        session, _ = _make_session()
        session.key_to_idx = {}
        await session.on_event_message(
            self._incoming("CHAN", "ff" * 16, sender_timestamp=0)
        )
        assert not len(session._msg_queue)

    @pytest.mark.asyncio
    async def test_contact_event_updates_existing_cache(self):
        session, sent = _make_session()
        # The contact must already be in the favorites cache to receive pushes.
        cached = dict(
            public_key=EXAMPLE_KEY,
            name="Old",
            type=1,
            favorite=True,
            direct_path=None,
            direct_path_len=-1,
            direct_path_hash_mode=-1,
            last_advert=0,
            lat=0.0,
            lon=0.0,
            first_seen=0,
        )
        session.contacts = [cached]
        update = dict(
            public_key=EXAMPLE_KEY,
            type=1,
            name="Updated",
            favorite=True,
            direct_path=None,
            direct_path_len=-1,
            direct_path_hash_mode=-1,
            last_advert=100,
            lat=0.0,
            lon=0.0,
            first_seen=0,
        )
        await session.on_event_contact(update)
        assert len(session.contacts) == 1
        assert session.contacts[0]["name"] == "Updated"
        assert _extract_payloads(sent)[0][0] == 0x8A  # PUSH_NEW_ADVERT

    @pytest.mark.asyncio
    async def test_contact_event_ignored_for_non_favorites(self):
        session, sent = _make_session()
        session.contacts = []
        stranger = {
            "public_key": EXAMPLE_KEY,
            "type": 1,
            "name": "Stranger",
            "favorite": False,
        }
        await session.on_event_contact(stranger)
        assert not len(session.contacts)
        assert not len(sent)