"""Tests for the packets router.

Covers the historical decryption endpoint (channel and contact key
validation), the historical channel decryption background task, the
undecrypted count endpoint, and the maintenance endpoint.
"""

import time
from types import SimpleNamespace
from unittest.mock import patch

import pytest

from app.repository import ChannelRepository, MessageRepository, RawPacketRepository


async def _insert_raw_packets(count: int, decrypted: bool = False, age_days: int = 0) -> list[int]:
    """Insert ``count`` raw packets and return their IDs.

    Args:
        count: Number of packets to insert.
        decrypted: When True, each packet is linked to a newly created
            CHAN message via ``RawPacketRepository.mark_decrypted``.
        age_days: Shift the packet timestamps this many days into the past.

    Returns:
        The raw packet IDs, in insertion order.

    Raises:
        AssertionError: If ``MessageRepository.create`` returns ``None`` for a
            packet that should be decrypted. This makes a broken fixture fail
            here instead of silently skewing counts in the calling test.
    """
    ids = []
    base_ts = int(time.time()) - (age_days * 86400)
    for i in range(count):
        # Payload embeds i/age/decrypted so packets within one call are distinct.
        packet_id, _ = await RawPacketRepository.create(
            f"packet_data_{i}_{age_days}_{decrypted}".encode(), base_ts + i
        )
        if decrypted:
            # Create a message and link it to the raw packet.
            msg_id = await MessageRepository.create(
                msg_type="CHAN",
                text=f"decrypted msg {i}",
                conversation_key="DEADBEEF" * 4,
                sender_timestamp=base_ts + i,
                received_at=base_ts + i,
            )
            assert msg_id is not None, f"failed to create linked message for packet {i}"
            await RawPacketRepository.mark_decrypted(packet_id, msg_id)
        ids.append(packet_id)
    return ids


class TestUndecryptedCount:
    """Test GET /api/packets/undecrypted/count."""

    @pytest.mark.asyncio
    async def test_returns_zero_when_empty(self, test_db, client):
        response = await client.get("/api/packets/undecrypted/count")
        assert response.status_code == 200
        assert response.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_counts_only_undecrypted(self, test_db, client):
        await _insert_raw_packets(3, decrypted=False)
        await _insert_raw_packets(2, decrypted=True)

        response = await client.get("/api/packets/undecrypted/count")
        assert response.status_code == 200
        assert response.json()["count"] == 3


class TestDecryptHistoricalPackets:
    """Test POST /api/packets/decrypt/historical."""

    @pytest.mark.asyncio
    async def test_channel_decrypt_with_hex_key(self, test_db, client):
        """Channel decryption with a valid hex key starts background task."""
        await _insert_raw_packets(5)

        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "channel",
                "channel_key": "0123456789abcdef0123456789abcdef",
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is True
        assert data["total_packets"] == 5
        assert "background" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_channel_decrypt_with_hashtag_name(self, test_db, client):
        """Channel decryption with a channel name derives key from hash."""
        await _insert_raw_packets(3)

        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "channel",
                "channel_name": "#general",
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is True
        assert data["total_packets"] == 3

    @pytest.mark.asyncio
    async def test_channel_decrypt_invalid_hex(self, test_db, client):
        """Invalid hex string for channel key returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "channel",
                "channel_key": "not_valid_hex",
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "invalid" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_channel_decrypt_wrong_key_length(self, test_db, client):
        """Channel key with wrong length returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "channel",
                "channel_key": "aabbccdd",  # Only 4 bytes, need 16
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "16 bytes" in data["message"]

    @pytest.mark.asyncio
    async def test_channel_decrypt_no_key_or_name(self, test_db, client):
        """Channel decryption without key or name returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={"key_type": "channel"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "must provide" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_channel_decrypt_no_undecrypted_packets(self, test_db, client):
        """Channel decryption with no undecrypted packets returns not started."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "channel",
                "channel_key": "0123456789abcdef0123456789abcdef",
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert data["total_packets"] == 0

    @pytest.mark.asyncio
    async def test_channel_decrypt_resolves_channel_name(self, test_db, client):
        """Channel decryption finds display name from DB when channel exists."""
        key_hex = "0123456789ABCDEF0123456789ABCDEF"
        await ChannelRepository.upsert(key=key_hex, name="#test-channel", is_hashtag=True)
        await _insert_raw_packets(1)

        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "channel",
                "channel_key": key_hex.lower(),
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is True
        assert data["total_packets"] == 1

    @pytest.mark.asyncio
    async def test_contact_decrypt_missing_private_key(self, test_db, client):
        """Contact decryption without private key returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "contact",
                "contact_public_key": "aa" * 32,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "private_key" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_contact_decrypt_missing_contact_key(self, test_db, client):
        """Contact decryption without contact public key returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "contact",
                "private_key": "aa" * 64,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "contact_public_key" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_contact_decrypt_wrong_private_key_length(self, test_db, client):
        """Private key with wrong length returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "contact",
                "private_key": "aa" * 32,  # 32 bytes, need 64
                "contact_public_key": "bb" * 32,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "64 bytes" in data["message"]

    @pytest.mark.asyncio
    async def test_contact_decrypt_wrong_public_key_length(self, test_db, client):
        """Contact public key with wrong length returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "contact",
                "private_key": "aa" * 64,
                "contact_public_key": "bb" * 16,  # 16 bytes, need 32
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "32 bytes" in data["message"]

    @pytest.mark.asyncio
    async def test_contact_decrypt_invalid_hex(self, test_db, client):
        """Invalid hex for private key returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={
                "key_type": "contact",
                "private_key": "zz" * 64,
                "contact_public_key": "bb" * 32,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "invalid" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_invalid_key_type(self, test_db, client):
        """Invalid key_type returns error."""
        response = await client.post(
            "/api/packets/decrypt/historical",
            json={"key_type": "invalid"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["started"] is False
        assert "key_type" in data["message"].lower()


class TestRunHistoricalChannelDecryption:
    """Test the _run_historical_channel_decryption background task."""

    @pytest.mark.asyncio
    async def test_decrypts_matching_packets(self, test_db):
        """Background task decrypts packets that match the channel key."""
        from app.routers.packets import _run_historical_channel_decryption

        # Insert undecrypted packets
        await _insert_raw_packets(3)

        channel_key_hex = "AABBCCDDAABBCCDDAABBCCDDAABBCCDD"
        channel_key_bytes = bytes.fromhex(channel_key_hex)

        # Each packet must have unique content to avoid message deduplication
        call_count = 0

        def make_unique_result(*_args, **_kwargs):
            nonlocal call_count
            call_count += 1
            return SimpleNamespace(
                sender=f"User{call_count}",
                message=f"Hello {call_count}",
                timestamp=1700000000 + call_count,
            )

        with (
            patch(
                "app.routers.packets.try_decrypt_packet_with_channel_key",
                side_effect=make_unique_result,
            ),
            patch(
                "app.routers.packets.parse_packet",
                return_value=None,
            ),
            patch("app.routers.packets.broadcast_success") as mock_success,
        ):
            await _run_historical_channel_decryption(channel_key_bytes, channel_key_hex, "#test")

        mock_success.assert_called_once()
        assert "3" in mock_success.call_args[0][1]  # "Decrypted 3 messages"

    @pytest.mark.asyncio
    async def test_skips_non_matching_packets(self, test_db):
        """Background task skips packets that don't match the channel key."""
        from app.routers.packets import _run_historical_channel_decryption

        await _insert_raw_packets(2)

        channel_key_hex = "AABBCCDDAABBCCDDAABBCCDDAABBCCDD"
        channel_key_bytes = bytes.fromhex(channel_key_hex)

        with (
            patch(
                "app.routers.packets.try_decrypt_packet_with_channel_key",
                return_value=None,  # No match
            ),
            patch("app.routers.packets.broadcast_success") as mock_success,
        ):
            await _run_historical_channel_decryption(channel_key_bytes, channel_key_hex, "#test")

        # No success broadcast when nothing was decrypted
        mock_success.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_packets_returns_early(self, test_db):
        """Background task returns early when no undecrypted packets exist."""
        from app.routers.packets import _run_historical_channel_decryption

        channel_key_hex = "AABBCCDDAABBCCDDAABBCCDDAABBCCDD"
        channel_key_bytes = bytes.fromhex(channel_key_hex)

        with patch("app.routers.packets.broadcast_success") as mock_success:
            await _run_historical_channel_decryption(channel_key_bytes, channel_key_hex)

        mock_success.assert_not_called()

    @pytest.mark.asyncio
    async def test_display_name_fallback(self, test_db):
        """Uses channel key prefix when no display name is provided."""
        from app.routers.packets import _run_historical_channel_decryption

        await _insert_raw_packets(1)

        channel_key_hex = "AABBCCDDAABBCCDDAABBCCDDAABBCCDD"
        channel_key_bytes = bytes.fromhex(channel_key_hex)

        mock_result = SimpleNamespace(
            sender="User",
            message="msg",
            timestamp=1700000000,
        )

        with (
            patch(
                "app.routers.packets.try_decrypt_packet_with_channel_key",
                return_value=mock_result,
            ),
            patch("app.routers.packets.parse_packet", return_value=None),
            patch("app.routers.packets.broadcast_success") as mock_success,
        ):
            await _run_historical_channel_decryption(
                channel_key_bytes,
                channel_key_hex,
                None,  # No display name
            )

        # Fail clearly if no broadcast happened (call_args would be None).
        mock_success.assert_called_once()
        # Should use key prefix as display name
        call_msg = mock_success.call_args[0][0]
        assert channel_key_hex[:12] in call_msg


class TestMaintenanceEndpoint:
    """Test POST /api/packets/maintenance."""

    @pytest.mark.asyncio
    async def test_prune_old_undecrypted(self, test_db, client):
        """Prune deletes undecrypted packets older than threshold."""
        await _insert_raw_packets(3, decrypted=False, age_days=30)
        await _insert_raw_packets(2, decrypted=False, age_days=0)

        response = await client.post(
            "/api/packets/maintenance",
            json={"prune_undecrypted_days": 7},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["packets_deleted"] == 3

        # Verify only recent packets remain
        remaining = await RawPacketRepository.get_undecrypted_count()
        assert remaining == 2

    @pytest.mark.asyncio
    async def test_purge_linked_raw_packets(self, test_db, client):
        """Purge deletes raw packets that are linked to stored messages."""
        await _insert_raw_packets(3, decrypted=True)
        await _insert_raw_packets(2, decrypted=False)

        response = await client.post(
            "/api/packets/maintenance",
            json={"purge_linked_raw_packets": True},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["packets_deleted"] == 3

        # Undecrypted packets should remain
        remaining = await RawPacketRepository.get_undecrypted_count()
        assert remaining == 2

    @pytest.mark.asyncio
    async def test_both_prune_and_purge(self, test_db, client):
        """Both prune and purge can run in a single request."""
        await _insert_raw_packets(2, decrypted=True)
        await _insert_raw_packets(3, decrypted=False, age_days=30)
        await _insert_raw_packets(1, decrypted=False, age_days=0)

        response = await client.post(
            "/api/packets/maintenance",
            json={
                "prune_undecrypted_days": 7,
                "purge_linked_raw_packets": True,
            },
        )

        assert response.status_code == 200
        data = response.json()
        # 2 linked + 3 old undecrypted = 5 deleted
        assert data["packets_deleted"] == 5

    @pytest.mark.asyncio
    async def test_no_options_deletes_nothing(self, test_db, client):
        """No options specified means no deletions (only vacuum)."""
        await _insert_raw_packets(5)

        response = await client.post(
            "/api/packets/maintenance",
            json={},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["packets_deleted"] == 0

    @pytest.mark.asyncio
    async def test_vacuum_reports_status(self, test_db, client):
        """Maintenance endpoint reports vacuum status."""
        response = await client.post(
            "/api/packets/maintenance",
            json={},
        )

        assert response.status_code == 200
        data = response.json()
        # vacuumed is a boolean (may be True or False depending on DB state)
        assert isinstance(data["vacuumed"], bool)

    @pytest.mark.asyncio
    async def test_prune_days_validation(self, test_db, client):
        """prune_undecrypted_days must be >= 1."""
        response = await client.post(
            "/api/packets/maintenance",
            json={"prune_undecrypted_days": 0},
        )

        assert response.status_code == 422