mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-03-28 17:43:05 +01:00
480 lines
16 KiB
Python
480 lines
16 KiB
Python
"""Tests for the packets router.

Covers the historical channel decryption endpoint, background task,
undecrypted count endpoint, and the maintenance endpoint.
"""
|
|
|
|
import time
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from app.repository import ChannelRepository, MessageRepository, RawPacketRepository
|
|
|
|
|
|
async def _insert_raw_packets(count: int, decrypted: bool = False, age_days: int = 0) -> list[int]:
    """Create ``count`` raw packets and return their IDs in insertion order.

    Timestamps start ``age_days`` days before the current time and advance by
    one second per packet. When ``decrypted`` is true, a "CHAN" message is
    created for every packet and, if a message ID comes back, the packet is
    marked as decrypted against that message.
    """
    start_ts = int(time.time()) - (age_days * 86400)
    packet_ids = []

    for index in range(count):
        stamp = start_ts + index
        # Payload embeds the arguments so separate helper calls yield distinct bytes.
        payload = f"packet_data_{index}_{age_days}_{decrypted}".encode()
        packet_id, _ = await RawPacketRepository.create(payload, stamp)
        packet_ids.append(packet_id)

        if not decrypted:
            continue

        # Link the packet to a freshly stored message.
        message_id = await MessageRepository.create(
            msg_type="CHAN",
            text=f"decrypted msg {index}",
            conversation_key="DEADBEEF" * 4,
            sender_timestamp=stamp,
            received_at=stamp,
        )
        if message_id is None:
            # No message ID returned: the packet is left unlinked.
            continue
        await RawPacketRepository.mark_decrypted(packet_id, message_id)

    return packet_ids
|
|
|
|
|
|
class TestUndecryptedCount:
    """Exercise GET /api/packets/undecrypted/count."""

    _URL = "/api/packets/undecrypted/count"

    @pytest.mark.asyncio
    async def test_returns_zero_when_empty(self, test_db, client):
        """With no packets inserted by the test, the reported count is 0."""
        resp = await client.get(self._URL)

        assert resp.status_code == 200
        assert resp.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_counts_only_undecrypted(self, test_db, client):
        """Packets linked to a message are excluded from the count."""
        await _insert_raw_packets(3, decrypted=False)
        await _insert_raw_packets(2, decrypted=True)

        resp = await client.get(self._URL)

        assert resp.status_code == 200
        assert resp.json()["count"] == 3
|
|
|
|
|
|
class TestDecryptHistoricalPackets:
    """Exercise POST /api/packets/decrypt/historical."""

    _URL = "/api/packets/decrypt/historical"

    async def _submit(self, client, payload):
        """POST ``payload`` as the JSON body of a historical decrypt request."""
        return await client.post(self._URL, json=payload)

    # ------------------------------------------------------------------
    # Channel key requests
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_channel_decrypt_with_hex_key(self, test_db, client):
        """A 32-hex-character channel key is accepted and reports a background start."""
        await _insert_raw_packets(5)

        resp = await self._submit(
            client,
            {
                "key_type": "channel",
                "channel_key": "0123456789abcdef0123456789abcdef",
            },
        )

        assert resp.status_code == 202
        body = resp.json()
        assert body["started"] is True
        assert body["total_packets"] == 5
        assert "background" in body["message"].lower()

    @pytest.mark.asyncio
    async def test_channel_decrypt_with_hashtag_name(self, test_db, client):
        """Supplying only a channel name (no key) is also accepted."""
        await _insert_raw_packets(3)

        resp = await self._submit(
            client,
            {
                "key_type": "channel",
                "channel_name": "#general",
            },
        )

        assert resp.status_code == 202
        body = resp.json()
        assert body["started"] is True
        assert body["total_packets"] == 3

    @pytest.mark.asyncio
    async def test_channel_decrypt_invalid_hex(self, test_db, client):
        """A channel key that is not hex is rejected with a 400."""
        resp = await self._submit(
            client,
            {
                "key_type": "channel",
                "channel_key": "not_valid_hex",
            },
        )

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "invalid" in detail.lower()

    @pytest.mark.asyncio
    async def test_channel_decrypt_wrong_key_length(self, test_db, client):
        """A channel key of the wrong size is rejected with a 400."""
        resp = await self._submit(
            client,
            {
                "key_type": "channel",
                "channel_key": "aabbccdd",  # 4 bytes supplied, 16 required
            },
        )

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "16 bytes" in detail

    @pytest.mark.asyncio
    async def test_channel_decrypt_no_key_or_name(self, test_db, client):
        """A channel request carrying neither key nor name is rejected."""
        resp = await self._submit(client, {"key_type": "channel"})

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "must provide" in detail.lower()

    @pytest.mark.asyncio
    async def test_channel_decrypt_no_undecrypted_packets(self, test_db, client):
        """With nothing to decrypt the endpoint answers 200 and does not start."""
        resp = await self._submit(
            client,
            {
                "key_type": "channel",
                "channel_key": "0123456789abcdef0123456789abcdef",
            },
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["started"] is False
        assert body["total_packets"] == 0

    @pytest.mark.asyncio
    async def test_channel_decrypt_resolves_channel_name(self, test_db, client):
        """A lowercase key still starts a run when the channel is stored with an uppercase key."""
        stored_key = "0123456789ABCDEF0123456789ABCDEF"
        await ChannelRepository.upsert(key=stored_key, name="#test-channel", is_hashtag=True)
        await _insert_raw_packets(1)

        resp = await self._submit(
            client,
            {
                "key_type": "channel",
                "channel_key": stored_key.lower(),
            },
        )

        assert resp.status_code == 202
        assert resp.json()["started"] is True

    # ------------------------------------------------------------------
    # Contact key requests
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_contact_decrypt_missing_private_key(self, test_db, client):
        """A contact request without ``private_key`` is rejected."""
        resp = await self._submit(
            client,
            {
                "key_type": "contact",
                "contact_public_key": "aa" * 32,
            },
        )

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "private_key" in detail.lower()

    @pytest.mark.asyncio
    async def test_contact_decrypt_missing_contact_key(self, test_db, client):
        """A contact request without ``contact_public_key`` is rejected."""
        resp = await self._submit(
            client,
            {
                "key_type": "contact",
                "private_key": "aa" * 64,
            },
        )

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "contact_public_key" in detail.lower()

    @pytest.mark.asyncio
    async def test_contact_decrypt_wrong_private_key_length(self, test_db, client):
        """A private key of the wrong size is rejected."""
        resp = await self._submit(
            client,
            {
                "key_type": "contact",
                "private_key": "aa" * 32,  # 32 bytes supplied, 64 required
                "contact_public_key": "bb" * 32,
            },
        )

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "64 bytes" in detail

    @pytest.mark.asyncio
    async def test_contact_decrypt_wrong_public_key_length(self, test_db, client):
        """A contact public key of the wrong size is rejected."""
        resp = await self._submit(
            client,
            {
                "key_type": "contact",
                "private_key": "aa" * 64,
                "contact_public_key": "bb" * 16,  # 16 bytes supplied, 32 required
            },
        )

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "32 bytes" in detail

    @pytest.mark.asyncio
    async def test_contact_decrypt_invalid_hex(self, test_db, client):
        """A private key that is not hex is rejected."""
        resp = await self._submit(
            client,
            {
                "key_type": "contact",
                "private_key": "zz" * 64,
                "contact_public_key": "bb" * 32,
            },
        )

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "invalid" in detail.lower()

    # ------------------------------------------------------------------
    # Unknown key type
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_invalid_key_type(self, test_db, client):
        """An unrecognised ``key_type`` value is rejected."""
        resp = await self._submit(client, {"key_type": "invalid"})

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "key_type" in detail.lower()
|
|
|
|
|
|
class TestRunHistoricalChannelDecryption:
    """Test the _run_historical_channel_decryption background task.

    The task is invoked directly (not through the HTTP endpoint) with the
    decrypt function, packet parser and success broadcast patched on
    ``app.routers.packets``.
    """

    @staticmethod
    def _fake_decrypt_result(sender, message, timestamp):
        """Build a stand-in decrypt result exposing sender/message/timestamp attributes."""
        return type(
            "DecryptResult",
            (),
            {
                "sender": sender,
                "message": message,
                "timestamp": timestamp,
            },
        )()

    @pytest.mark.asyncio
    async def test_decrypts_matching_packets(self, test_db):
        """Background task decrypts packets that match the channel key."""
        from app.routers.packets import _run_historical_channel_decryption

        # Insert undecrypted packets
        await _insert_raw_packets(3)
        channel_key_hex = "AABBCCDDAABBCCDDAABBCCDDAABBCCDD"
        channel_key_bytes = bytes.fromhex(channel_key_hex)

        # Each packet must have unique content to avoid message deduplication
        call_count = 0

        def make_unique_result(*_args, **_kwargs):
            nonlocal call_count
            call_count += 1
            return self._fake_decrypt_result(
                f"User{call_count}",
                f"Hello {call_count}",
                1700000000 + call_count,
            )

        # NOTE(review): parse_packet is stubbed to return None; presumably the
        # synthetic payloads are not parseable packets — confirm against the router.
        with (
            patch(
                "app.routers.packets.try_decrypt_packet_with_channel_key",
                side_effect=make_unique_result,
            ),
            patch(
                "app.routers.packets.parse_packet",
                return_value=None,
            ),
            patch("app.routers.packets.broadcast_success") as mock_success,
        ):
            await _run_historical_channel_decryption(channel_key_bytes, channel_key_hex, "#test")

        mock_success.assert_called_once()
        assert "3" in mock_success.call_args[0][1]  # "Decrypted 3 messages"

    @pytest.mark.asyncio
    async def test_skips_non_matching_packets(self, test_db):
        """Background task skips packets that don't match the channel key."""
        from app.routers.packets import _run_historical_channel_decryption

        await _insert_raw_packets(2)
        channel_key_hex = "AABBCCDDAABBCCDDAABBCCDDAABBCCDD"
        channel_key_bytes = bytes.fromhex(channel_key_hex)

        with (
            patch(
                "app.routers.packets.try_decrypt_packet_with_channel_key",
                return_value=None,  # No match
            ),
            patch("app.routers.packets.broadcast_success") as mock_success,
        ):
            await _run_historical_channel_decryption(channel_key_bytes, channel_key_hex, "#test")

        # No success broadcast when nothing was decrypted
        mock_success.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_packets_returns_early(self, test_db):
        """Background task returns early when no undecrypted packets exist."""
        from app.routers.packets import _run_historical_channel_decryption

        channel_key_hex = "AABBCCDDAABBCCDDAABBCCDDAABBCCDD"
        channel_key_bytes = bytes.fromhex(channel_key_hex)

        with patch("app.routers.packets.broadcast_success") as mock_success:
            await _run_historical_channel_decryption(channel_key_bytes, channel_key_hex)

        mock_success.assert_not_called()

    @pytest.mark.asyncio
    async def test_display_name_fallback(self, test_db):
        """Uses channel key prefix when no display name is provided."""
        from app.routers.packets import _run_historical_channel_decryption

        await _insert_raw_packets(1)
        channel_key_hex = "AABBCCDDAABBCCDDAABBCCDDAABBCCDD"
        channel_key_bytes = bytes.fromhex(channel_key_hex)

        mock_result = self._fake_decrypt_result("User", "msg", 1700000000)

        with (
            patch(
                "app.routers.packets.try_decrypt_packet_with_channel_key",
                return_value=mock_result,
            ),
            patch("app.routers.packets.parse_packet", return_value=None),
            patch("app.routers.packets.broadcast_success") as mock_success,
        ):
            await _run_historical_channel_decryption(
                channel_key_bytes,
                channel_key_hex,
                None,  # No display name
            )

        # Guard before subscripting: call_args is None on an uncalled mock, which
        # would surface as a TypeError instead of a readable assertion failure.
        mock_success.assert_called_once()

        # Should use key prefix as display name
        call_msg = mock_success.call_args[0][0]
        assert channel_key_hex[:12] in call_msg
|
|
|
|
|
|
class TestMaintenanceEndpoint:
    """Exercise POST /api/packets/maintenance."""

    _URL = "/api/packets/maintenance"

    @pytest.mark.asyncio
    async def test_prune_old_undecrypted(self, test_db, client):
        """Undecrypted packets older than the requested number of days are removed."""
        await _insert_raw_packets(3, decrypted=False, age_days=30)
        await _insert_raw_packets(2, decrypted=False, age_days=0)

        options = {"prune_undecrypted_days": 7}
        resp = await client.post(self._URL, json=options)

        assert resp.status_code == 200
        result = resp.json()
        assert result["packets_deleted"] == 3

        # Only the two fresh packets should still be counted.
        still_undecrypted = await RawPacketRepository.get_undecrypted_count()
        assert still_undecrypted == 2

    @pytest.mark.asyncio
    async def test_purge_linked_raw_packets(self, test_db, client):
        """Raw packets already linked to stored messages are removed on purge."""
        await _insert_raw_packets(3, decrypted=True)
        await _insert_raw_packets(2, decrypted=False)

        options = {"purge_linked_raw_packets": True}
        resp = await client.post(self._URL, json=options)

        assert resp.status_code == 200
        result = resp.json()
        assert result["packets_deleted"] == 3

        # The unlinked packets must survive the purge.
        still_undecrypted = await RawPacketRepository.get_undecrypted_count()
        assert still_undecrypted == 2

    @pytest.mark.asyncio
    async def test_both_prune_and_purge(self, test_db, client):
        """Prune and purge may be combined in one request."""
        await _insert_raw_packets(2, decrypted=True)
        await _insert_raw_packets(3, decrypted=False, age_days=30)
        await _insert_raw_packets(1, decrypted=False, age_days=0)

        options = {
            "prune_undecrypted_days": 7,
            "purge_linked_raw_packets": True,
        }
        resp = await client.post(self._URL, json=options)

        assert resp.status_code == 200
        result = resp.json()
        # Expected total: 2 linked packets plus 3 stale undecrypted ones.
        assert result["packets_deleted"] == 5

    @pytest.mark.asyncio
    async def test_no_options_deletes_nothing(self, test_db, client):
        """An empty request body deletes no packets (only vacuum)."""
        await _insert_raw_packets(5)

        resp = await client.post(self._URL, json={})

        assert resp.status_code == 200
        result = resp.json()
        assert result["packets_deleted"] == 0

    @pytest.mark.asyncio
    async def test_vacuum_reports_status(self, test_db, client):
        """The response carries a boolean ``vacuumed`` field."""
        resp = await client.post(self._URL, json={})

        assert resp.status_code == 200
        result = resp.json()
        # Either True or False is acceptable; only the type is checked.
        assert isinstance(result["vacuumed"], bool)

    @pytest.mark.asyncio
    async def test_prune_days_validation(self, test_db, client):
        """A ``prune_undecrypted_days`` of 0 fails request validation (must be >= 1)."""
        options = {"prune_undecrypted_days": 0}
        resp = await client.post(self._URL, json=options)

        assert resp.status_code == 422
|