"""Tests for app.tcp_proxy.session — ProxySession command handlers."""

import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.tcp_proxy.protocol import (
    CMD_APP_START,
    CMD_DEVICE_QUERY,
    CMD_GET_BATT_AND_STORAGE,
    CMD_GET_CHANNEL,
    CMD_GET_CONTACT_BY_KEY,
    CMD_GET_CONTACTS,
    CMD_GET_DEVICE_TIME,
    CMD_HAS_CONNECTION,
    CMD_RESET_PATH,
    CMD_SEND_CHANNEL_TXT_MSG,
    CMD_SEND_TXT_MSG,
    CMD_SET_CHANNEL,
    CMD_SYNC_NEXT_MESSAGE,
    ERR_NOT_FOUND,
    PROXY_FW_VER,
    PUSH_MSG_WAITING,
    RESP_BATTERY,
    RESP_CHANNEL_MSG_RECV_V3,
    RESP_CONTACT_END,
    RESP_CONTACT_MSG_RECV_V3,
    RESP_CONTACT_START,
    RESP_CURRENT_TIME,
    RESP_DEVICE_INFO,
    RESP_ERR,
    RESP_MSG_SENT,
    RESP_NO_MORE_MSGS,
    RESP_OK,
    RESP_SELF_INFO,
    encode_path_byte,
)
from app.tcp_proxy.session import ProxySession

EXAMPLE_KEY = "ab" * 32


# ── Helpers ──────────────────────────────────────────────────────────


def _make_session() -> tuple[ProxySession, list[bytes]]:
    """Create a ProxySession with a capturing writer.

    Returns the session and a list that receives every ``bytes`` object
    passed to ``writer.write`` (one entry per write call).
    """
    reader = AsyncMock(spec=asyncio.StreamReader)
    writer = MagicMock(spec=asyncio.StreamWriter)
    writer.get_extra_info.return_value = ("127.0.0.1", 12345)
    sent: list[bytes] = []

    def capture_write(data: bytes):
        sent.append(data)

    writer.write = capture_write
    writer.drain = AsyncMock()
    session = ProxySession(reader, writer)
    return session, sent


def _extract_payloads(sent: list[bytes]) -> list[bytes]:
    """Extract payloads from framed response bytes.

    Each captured write is treated as one frame: a 0x3E marker byte, a
    2-byte little-endian payload length, then the payload itself.
    """
    payloads = []
    for frame in sent:
        assert frame[0] == 0x3E
        size = int.from_bytes(frame[1:3], "little")
        payloads.append(frame[3 : 3 + size])
    return payloads


def _make_contact(public_key: str = EXAMPLE_KEY, name: str = "Alice", **kw):
    """Return a mock contact whose ``model_dump()`` yields a favorite contact dict.

    Extra keyword arguments override/extend the dumped fields.
    """
    return MagicMock(
        model_dump=MagicMock(
            return_value={
                "public_key": public_key,
                "name": name,
                "type": 1,
                "favorite": True,
                "direct_path": None,
                "direct_path_len": -1,
                "direct_path_hash_mode": -1,
                "last_advert": 0,
                "lat": 0.0,
                "lon": 0.0,
                "first_seen": int(time.time()),
                **kw,
            }
        )
    )


def _make_channel(key: str = "cc" * 16, name: str = "test", favorite: bool = True):
    """Return a mock channel whose ``model_dump()`` yields key/name/favorite."""
    return MagicMock(
        model_dump=MagicMock(return_value={"key": key, "name": name, "favorite": favorite})
    )


def _make_settings(last_message_times=None):
    """Return mock app settings with the given ``last_message_times`` (default ``{}``)."""
    return MagicMock(last_message_times=last_message_times or {})


def _mock_radio_runtime(connected: bool = True, self_info: dict | None = None):
    """Build a fake radio runtime exposing ``is_connected`` and ``meshcore.self_info``.

    ``self_info`` defaults to a representative node description; any dict
    passed explicitly (including an empty one) is used as-is.
    """
    rt = MagicMock()
    rt.is_connected = connected
    mc = MagicMock()
    if self_info is None:
        self_info = {
            "public_key": EXAMPLE_KEY,
            "name": "TestNode",
            "tx_power": 20,
            "max_tx_power": 22,
            "adv_lat": 0.0,
            "adv_lon": 0.0,
            "radio_freq": 915.0,
            "radio_bw": 250.0,
            "radio_sf": 10,
            "radio_cr": 7,
        }
    mc.self_info = self_info
    rt.meshcore = mc
    return rt


# ── Tests ────────────────────────────────────────────────────────────


class TestAppStart:
    """CMD_APP_START: replies with SELF_INFO and loads contacts/channels."""

    @pytest.mark.asyncio
    async def test_sends_self_info(self):
        session, sent = _make_session()
        contacts = [_make_contact()]
        channels = [_make_channel()]
        settings = _make_settings()
        rt = _mock_radio_runtime()

        with (
            patch("app.repository.ContactRepository") as cr,
            patch("app.repository.ChannelRepository") as chr_,
            patch("app.repository.AppSettingsRepository") as sr,
            patch("app.services.radio_runtime.radio_runtime", rt),
        ):
            cr.get_favorites = AsyncMock(return_value=contacts)
            chr_.get_all = AsyncMock(return_value=channels)
            sr.get = AsyncMock(return_value=settings)
            await session._cmd_app_start(bytes([CMD_APP_START]))

        payloads = _extract_payloads(sent)
        assert len(payloads) == 1
        assert payloads[0][0] == RESP_SELF_INFO

    @pytest.mark.asyncio
    async def test_populates_contacts_and_channels(self):
        session, sent = _make_session()
        contacts = [_make_contact(), _make_contact(public_key="cd" * 32, name="Bob")]
        channels = [_make_channel(), _make_channel(key="dd" * 16, name="ch2")]
        settings = _make_settings()
        rt = _mock_radio_runtime()

        with (
            patch("app.repository.ContactRepository") as cr,
            patch("app.repository.ChannelRepository") as chr_,
            patch("app.repository.AppSettingsRepository") as sr,
            patch("app.services.radio_runtime.radio_runtime", rt),
        ):
            cr.get_favorites = AsyncMock(return_value=contacts)
            chr_.get_all = AsyncMock(return_value=channels)
            sr.get = AsyncMock(return_value=settings)
            await session._cmd_app_start(bytes([CMD_APP_START]))

        assert len(session.contacts) == 2
        # Only favorite channels are slotted; both fixtures here are favorites,
        # so both should land in a slot.
        assert len(session.channel_slots) == 2


class TestDeviceQuery:
    """CMD_DEVICE_QUERY: replies with DEVICE_INFO carrying the proxy FW version."""

    @pytest.mark.asyncio
    async def test_sends_device_info(self):
        session, sent = _make_session()
        rt = _mock_radio_runtime()
        with patch("app.services.radio_runtime.radio_runtime", rt):
            await session._cmd_device_query(bytes([CMD_DEVICE_QUERY, 0x03]))

        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_DEVICE_INFO
        assert payloads[0][1] == PROXY_FW_VER


class TestGetContacts:
    """CMD_GET_CONTACTS: CONTACT_START(count), one frame per contact, CONTACT_END."""

    @pytest.mark.asyncio
    async def test_sends_start_contacts_end(self):
        session, sent = _make_session()
        contacts = [_make_contact()]

        with patch("app.repository.ContactRepository") as cr:
            cr.get_favorites = AsyncMock(return_value=contacts)
            await session._cmd_get_contacts(bytes([CMD_GET_CONTACTS]))

        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_CONTACT_START
        count = int.from_bytes(payloads[0][1:5], "little")
        assert count == 1
        # Middle payload(s) are contacts: one RESP_CONTACT (0x03) frame per
        # contact, bracketed by the start and end frames.
        assert len(payloads) == count + 2
        assert all(p[0] == 0x03 for p in payloads[1:-1])
        assert payloads[-1][0] == RESP_CONTACT_END


class TestGetContactByKey:
    """CMD_GET_CONTACT_BY_KEY: looks up the cached contacts by full public key."""

    @pytest.mark.asyncio
    async def test_found(self):
        session, sent = _make_session()
        session.contacts = [
            {
                "public_key": EXAMPLE_KEY,
                "type": 1,
                "name": "Alice",
                "favorite": True,
                "direct_path": None,
                "direct_path_len": -1,
                "direct_path_hash_mode": -1,
                "last_advert": 0,
                "lat": 0.0,
                "lon": 0.0,
                "first_seen": 0,
            }
        ]
        cmd = bytes([CMD_GET_CONTACT_BY_KEY]) + bytes.fromhex(EXAMPLE_KEY)
        await session._cmd_get_contact_by_key(cmd)

        payloads = _extract_payloads(sent)
        assert len(payloads) == 1
        assert payloads[0][0] == 0x03  # RESP_CONTACT

    @pytest.mark.asyncio
    async def test_not_found(self):
        session, sent = _make_session()
        session.contacts = []
        cmd = bytes([CMD_GET_CONTACT_BY_KEY]) + bytes.fromhex(EXAMPLE_KEY)
        await session._cmd_get_contact_by_key(cmd)

        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_ERR
        assert payloads[0][1] == ERR_NOT_FOUND


class TestGetChannel:
    """CMD_GET_CHANNEL: returns channel info for a mapped slot, ERR otherwise."""

    @pytest.mark.asyncio
    async def test_found(self):
        session, sent = _make_session()
        key = "cc" * 16
        session.channel_slots = {0: key}
        session.channels = [{"key": key, "name": "test"}]

        await session._cmd_get_channel(bytes([CMD_GET_CHANNEL, 0]))
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == 0x12  # RESP_CHANNEL_INFO

    @pytest.mark.asyncio
    async def test_empty_slot_returns_error(self):
        session, sent = _make_session()
        session.channel_slots = {}

        await session._cmd_get_channel(bytes([CMD_GET_CHANNEL, 5]))
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_ERR


class TestSetChannel:
    """CMD_SET_CHANNEL: cmd(1) + slot(1) + name(32) + secret(16) updates slot maps."""

    @pytest.mark.asyncio
    async def test_updates_slot_mapping(self):
        session, sent = _make_session()
        name = b"test" + b"\x00" * 28  # 32 bytes
        secret = b"\xaa" * 16
        cmd = bytes([CMD_SET_CHANNEL, 3]) + name + secret

        await session._cmd_set_channel(cmd)

        assert session.channel_slots[3] == "aa" * 16
        assert session.key_to_idx["aa" * 16] == 3
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_OK

    @pytest.mark.asyncio
    async def test_cleans_stale_mapping(self):
        session, sent = _make_session()
        # Pre-load slot 0 with key_a
        session.channel_slots[0] = "aa" * 16
        session.key_to_idx["aa" * 16] = 0

        # Overwrite slot 0 with key_b
        name = b"\x00" * 32
        secret_b = b"\xbb" * 16
        cmd = bytes([CMD_SET_CHANNEL, 0]) + name + secret_b
        await session._cmd_set_channel(cmd)

        assert session.channel_slots[0] == "bb" * 16
        assert "aa" * 16 not in session.key_to_idx


class TestSendDm:
    """CMD_SEND_TXT_MSG: resolves the 6-byte key prefix and acknowledges the send."""

    @pytest.mark.asyncio
    async def test_sends_msg_sent_and_ack(self):
        session, sent = _make_session()
        session.contacts = [{"public_key": EXAMPLE_KEY}]

        # CMD_SEND_TXT_MSG: cmd(1) + txt_type(1) + attempt(1) + ts(4) + prefix(6) + text
        prefix = bytes.fromhex(EXAMPLE_KEY[:12])
        cmd = (
            bytes([CMD_SEND_TXT_MSG, 0, 0])
            + int(time.time()).to_bytes(4, "little")
            + prefix
            + b"Hello"
        )

        with patch.object(session, "_do_send_dm", new_callable=AsyncMock):
            await session._cmd_send_dm(cmd)

        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_MSG_SENT
        assert payloads[1][0] == 0x82  # PUSH_ACK
        # ACK code should match
        ack_from_sent = payloads[0][2:6]
        ack_from_push = payloads[1][1:5]
        assert ack_from_sent == ack_from_push

    @pytest.mark.asyncio
    async def test_long_text_with_prefix(self):
        """6-byte prefix + long text (>26 chars) must resolve correctly."""
        session, sent = _make_session()
        session.contacts = [{"public_key": EXAMPLE_KEY}]

        prefix = bytes.fromhex(EXAMPLE_KEY[:12])
        long_text = b"A" * 50  # well over 26 chars
        cmd = (
            bytes([CMD_SEND_TXT_MSG, 0, 0])
            + int(time.time()).to_bytes(4, "little")
            + prefix
            + long_text
        )

        with patch.object(session, "_do_send_dm", new_callable=AsyncMock) as mock_send:
            await session._cmd_send_dm(cmd)

        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_MSG_SENT  # not ERR
        mock_send.assert_called_once()
        call_key, call_text = mock_send.call_args[0]
        assert call_key == EXAMPLE_KEY
        assert call_text == "A" * 50


class TestSendChannel:
    """CMD_SEND_CHANNEL_TXT_MSG: sends to the channel mapped at the given slot."""

    @pytest.mark.asyncio
    async def test_sends_ok(self):
        session, sent = _make_session()
        key = "cc" * 16
        session.channel_slots = {0: key}
        session.channels = [{"key": key, "name": "test"}]

        cmd = (
            bytes([CMD_SEND_CHANNEL_TXT_MSG, 0, 0])
            + int(time.time()).to_bytes(4, "little")
            + b"Hello"
        )

        # MagicMock(name=...) only sets the mock's repr name, not a ``.name``
        # attribute, so the attribute has to be assigned explicitly.
        fake_channel = MagicMock()
        fake_channel.name = "test"
        with (
            patch(
                "app.repository.ChannelRepository.get_by_key",
                new_callable=AsyncMock,
                return_value=fake_channel,
            ),
            patch.object(session, "_do_send_channel", new_callable=AsyncMock),
        ):
            await session._cmd_send_channel(cmd)

        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_OK


class TestSimpleCommands:
    """Commands answered directly without repository access."""

    @pytest.mark.asyncio
    async def test_get_time(self):
        session, sent = _make_session()
        await session._cmd_get_time(bytes([CMD_GET_DEVICE_TIME]))
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_CURRENT_TIME

    @pytest.mark.asyncio
    async def test_battery(self):
        session, sent = _make_session()
        await session._cmd_battery(bytes([CMD_GET_BATT_AND_STORAGE]))
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_BATTERY

    @pytest.mark.asyncio
    async def test_has_connection(self):
        session, sent = _make_session()
        rt = _mock_radio_runtime(connected=True)
        with patch("app.services.radio_runtime.radio_runtime", rt):
            await session._cmd_has_connection(bytes([CMD_HAS_CONNECTION]))
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_OK
        val = int.from_bytes(payloads[0][1:5], "little")
        assert val == 1

    @pytest.mark.asyncio
    async def test_ok_stub(self):
        session, sent = _make_session()
        await session._cmd_ok_stub(bytes([CMD_RESET_PATH]))
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_OK


class TestSyncNext:
    """CMD_SYNC_NEXT_MESSAGE: pops one queued frame or reports NO_MORE_MSGS."""

    @pytest.mark.asyncio
    async def test_empty_queue(self):
        session, sent = _make_session()
        await session._cmd_sync_next(bytes([CMD_SYNC_NEXT_MESSAGE]))
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == RESP_NO_MORE_MSGS

    @pytest.mark.asyncio
    async def test_dequeues_message(self):
        session, sent = _make_session()
        fake_msg = bytes([0x10, 0x00, 0x00, 0x00]) + b"\xaa" * 10
        session._msg_queue.append(fake_msg)

        await session._cmd_sync_next(bytes([CMD_SYNC_NEXT_MESSAGE]))
        payloads = _extract_payloads(sent)
        assert payloads[0] == fake_msg
        assert len(session._msg_queue) == 0


class TestExtractPathMeta:
    """Tests for _extract_path_meta static helper."""

    def test_no_paths(self):
        snr, path_byte = ProxySession._extract_path_meta({"paths": None})
        assert snr == 0
        assert path_byte == 0  # 0 hops, mode 0

    def test_empty_paths_list(self):
        snr, path_byte = ProxySession._extract_path_meta({"paths": []})
        assert snr == 0
        assert path_byte == 0

    def test_one_byte_hops(self):
        """2 hops at 1-byte hash mode → path_byte = (0 << 6) | 2 = 0x02."""
        snr, path_byte = ProxySession._extract_path_meta(
            {
                "paths": [{"path": "aabb", "path_len": 2, "snr": None, "rssi": None}],
            }
        )
        assert path_byte == encode_path_byte(2, 0)
        assert path_byte == 0x02

    def test_two_byte_hops(self):
        """3 hops at 2-byte hash mode → path_byte = (1 << 6) | 3 = 0x43."""
        snr, path_byte = ProxySession._extract_path_meta(
            {
                "paths": [{"path": "aabbccddee11", "path_len": 3, "snr": None, "rssi": None}],
            }
        )
        assert path_byte == encode_path_byte(3, 1)
        assert path_byte == 0x43

    def test_three_byte_hops(self):
        """1 hop at 3-byte hash mode → path_byte = (2 << 6) | 1 = 0x81."""
        snr, path_byte = ProxySession._extract_path_meta(
            {
                "paths": [{"path": "aabbcc", "path_len": 1, "snr": None, "rssi": None}],
            }
        )
        assert path_byte == encode_path_byte(1, 2)
        assert path_byte == 0x81

    def test_snr_encoded(self):
        """SNR is encoded as int8(snr * 4)."""
        snr, _ = ProxySession._extract_path_meta(
            {
                "paths": [{"path": "aa", "path_len": 1, "snr": -5.25, "rssi": -100}],
            }
        )
        assert snr == (-21) & 0xFF  # -5.25 * 4 = -21 → unsigned byte

    def test_zero_hops_empty_path(self):
        """0 hops, empty path → path_byte 0."""
        snr, path_byte = ProxySession._extract_path_meta(
            {
                "paths": [{"path": "", "path_len": 0, "snr": None, "rssi": None}],
            }
        )
        assert path_byte == 0

    def test_legacy_no_path_len(self):
        """path_len=None falls back to inferring from hex length (1-byte hops)."""
        snr, path_byte = ProxySession._extract_path_meta(
            {
                "paths": [{"path": "aabb", "path_len": None, "snr": None, "rssi": None}],
            }
        )
        # Inferred: 2 hops, path is 2 bytes → 1-byte hash → mode 0
        assert path_byte == encode_path_byte(2, 0)


class TestEventHandlers:
    """Message/contact push events translated into queued frames and pushes."""

    @pytest.mark.asyncio
    async def test_priv_message_queued(self):
        session, sent = _make_session()
        data = {
            "type": "PRIV",
            "outgoing": False,
            "conversation_key": EXAMPLE_KEY,
            "text": "hello",
            "sender_timestamp": 1700000000,
        }
        await session.on_event_message(data)

        assert len(session._msg_queue) == 1
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == PUSH_MSG_WAITING

    @pytest.mark.asyncio
    async def test_priv_message_path_encoding(self):
        """DM frame encodes path_len byte from message path data."""
        session, sent = _make_session()
        data = {
            "type": "PRIV",
            "outgoing": False,
            "conversation_key": EXAMPLE_KEY,
            "text": "hi",
            "sender_timestamp": 1700000000,
            "paths": [{"path": "aabb", "path_len": 2, "snr": 3.0, "rssi": -80}],
        }
        await session.on_event_message(data)

        frame = session._msg_queue[0]
        assert frame[0] == RESP_CONTACT_MSG_RECV_V3
        snr_byte = frame[1]
        assert snr_byte == 12  # 3.0 * 4
        # path_len byte is at offset 10 (after: type, snr, 2 reserved, 6 prefix)
        path_byte = frame[10]
        assert path_byte == encode_path_byte(2, 0)  # 2 hops, 1-byte hash

    @pytest.mark.asyncio
    async def test_chan_message_queued(self):
        session, sent = _make_session()
        key = "cc" * 16
        session.key_to_idx = {key: 0}
        data = {
            "type": "CHAN",
            "outgoing": False,
            "conversation_key": key.upper(),  # test case normalization
            "text": "hello",
            "sender_timestamp": 1700000000,
        }
        await session.on_event_message(data)
        assert len(session._msg_queue) == 1

    @pytest.mark.asyncio
    async def test_chan_message_path_encoding(self):
        """Channel frame encodes path_len byte correctly instead of 0xFF."""
        session, sent = _make_session()
        key = "cc" * 16
        session.key_to_idx = {key: 0}
        data = {
            "type": "CHAN",
            "outgoing": False,
            "conversation_key": key,
            "text": "hello",
            "sender_timestamp": 1700000000,
            "paths": [{"path": "aabbccdd", "path_len": 2, "snr": -2.5, "rssi": -90}],
        }
        await session.on_event_message(data)

        frame = session._msg_queue[0]
        assert frame[0] == RESP_CHANNEL_MSG_RECV_V3
        snr_byte = frame[1]
        assert snr_byte == (-10) & 0xFF  # -2.5 * 4
        # path_len byte is at offset 5 (after: type, snr, 2 reserved, channel_idx)
        path_byte = frame[5]
        assert path_byte == encode_path_byte(2, 1)  # 2 hops, 2-byte hash
        assert path_byte != 0xFF  # Must NOT be the old wrong value

    @pytest.mark.asyncio
    async def test_chan_message_no_paths_defaults_zero(self):
        """Channel message with no path data uses 0 (not 0xFF)."""
        session, sent = _make_session()
        key = "cc" * 16
        session.key_to_idx = {key: 0}
        data = {
            "type": "CHAN",
            "outgoing": False,
            "conversation_key": key,
            "text": "hello",
            "sender_timestamp": 1700000000,
        }
        await session.on_event_message(data)

        frame = session._msg_queue[0]
        path_byte = frame[5]
        assert path_byte == 0  # 0 hops, not 0xFF

    @pytest.mark.asyncio
    async def test_outgoing_message_ignored(self):
        session, sent = _make_session()
        data = {"type": "PRIV", "outgoing": True, "conversation_key": EXAMPLE_KEY}
        await session.on_event_message(data)
        assert len(session._msg_queue) == 0
        assert len(sent) == 0

    @pytest.mark.asyncio
    async def test_chan_unmapped_dropped(self):
        session, sent = _make_session()
        session.key_to_idx = {}
        data = {
            "type": "CHAN",
            "outgoing": False,
            "conversation_key": "ff" * 16,
            "text": "hello",
            "sender_timestamp": 0,
        }
        await session.on_event_message(data)
        assert len(session._msg_queue) == 0

    @pytest.mark.asyncio
    async def test_contact_event_updates_existing_cache(self):
        session, sent = _make_session()
        # Contact must already be in favorites cache to receive pushes
        session.contacts = [
            {
                "public_key": EXAMPLE_KEY,
                "name": "Old",
                "type": 1,
                "favorite": True,
                "direct_path": None,
                "direct_path_len": -1,
                "direct_path_hash_mode": -1,
                "last_advert": 0,
                "lat": 0.0,
                "lon": 0.0,
                "first_seen": 0,
            }
        ]
        data = {
            "public_key": EXAMPLE_KEY,
            "type": 1,
            "name": "Updated",
            "favorite": True,
            "direct_path": None,
            "direct_path_len": -1,
            "direct_path_hash_mode": -1,
            "last_advert": 100,
            "lat": 0.0,
            "lon": 0.0,
            "first_seen": 0,
        }
        await session.on_event_contact(data)

        assert len(session.contacts) == 1
        assert session.contacts[0]["name"] == "Updated"
        # Should have sent a PUSH_NEW_ADVERT
        payloads = _extract_payloads(sent)
        assert payloads[0][0] == 0x8A  # PUSH_NEW_ADVERT

    @pytest.mark.asyncio
    async def test_contact_event_ignored_for_non_favorites(self):
        session, sent = _make_session()
        session.contacts = []
        data = {
            "public_key": EXAMPLE_KEY,
            "type": 1,
            "name": "Stranger",
            "favorite": False,
        }
        await session.on_event_contact(data)

        assert len(session.contacts) == 0
        assert len(sent) == 0