diff --git a/app/event_handlers.py b/app/event_handlers.py index 527a842..dc56b54 100644 --- a/app/event_handlers.py +++ b/app/event_handlers.py @@ -1,3 +1,4 @@ +import asyncio import logging import time from typing import TYPE_CHECKING @@ -140,16 +141,18 @@ async def on_contact_message(event: "Event") -> None: # Run bot if enabled from app.bot import run_bot_for_message - await run_bot_for_message( - sender_name=contact.name if contact else None, - sender_key=sender_pubkey, - message_text=payload.get("text", ""), - is_dm=True, - channel_key=None, - channel_name=None, - sender_timestamp=payload.get("sender_timestamp"), - path=payload.get("path"), - is_outgoing=False, + asyncio.create_task( + run_bot_for_message( + sender_name=contact.name if contact else None, + sender_key=sender_pubkey, + message_text=payload.get("text", ""), + is_dm=True, + channel_key=None, + channel_name=None, + sender_timestamp=payload.get("sender_timestamp"), + path=payload.get("path"), + is_outgoing=False, + ) ) diff --git a/app/packet_processor.py b/app/packet_processor.py index 7377441..fec80ed 100644 --- a/app/packet_processor.py +++ b/app/packet_processor.py @@ -192,16 +192,18 @@ async def create_message_from_decrypted( if trigger_bot: from app.bot import run_bot_for_message - await run_bot_for_message( - sender_name=sender, - sender_key=None, # Channel messages don't have a sender public key - message_text=message_text, - is_dm=False, - channel_key=channel_key_normalized, - channel_name=channel_name, - sender_timestamp=timestamp, - path=path, - is_outgoing=False, + asyncio.create_task( + run_bot_for_message( + sender_name=sender, + sender_key=None, # Channel messages don't have a sender public key + message_text=message_text, + is_dm=False, + channel_key=channel_key_normalized, + channel_name=channel_name, + sender_timestamp=timestamp, + path=path, + is_outgoing=False, + ) ) return msg_id @@ -316,16 +318,18 @@ async def create_dm_message_from_decrypted( contact = await ContactRepository.get_by_key(their_public_key) sender_name = contact.name if contact else None - await run_bot_for_message( - sender_name=sender_name, - sender_key=their_public_key, - message_text=decrypted.message, - is_dm=True, - channel_key=None, - channel_name=None, - sender_timestamp=decrypted.timestamp, - path=path, - is_outgoing=outgoing, + asyncio.create_task( + run_bot_for_message( + sender_name=sender_name, + sender_key=their_public_key, + message_text=decrypted.message, + is_dm=True, + channel_key=None, + channel_name=None, + sender_timestamp=decrypted.timestamp, + path=path, + is_outgoing=outgoing, + ) ) return msg_id diff --git a/tests/test_event_handlers.py b/tests/test_event_handlers.py index 12ec830..c35f7f4 100644 --- a/tests/test_event_handlers.py +++ b/tests/test_event_handlers.py @@ -183,6 +183,38 @@ class TestContactMessageCLIFiltering: # Should NOT update contact last_contacted mock_contact_repo.update_last_contacted.assert_not_called() + @pytest.mark.asyncio + async def test_normal_message_schedules_bot_in_background(self): + """Normal messages should schedule bot execution without blocking.""" + from app.event_handlers import on_contact_message + + def _capture_task(coro): + coro.close() + return MagicMock() + + with ( + patch("app.event_handlers.MessageRepository") as mock_repo, + patch("app.event_handlers.ContactRepository") as mock_contact_repo, + patch("app.event_handlers.broadcast_event"), + patch("app.event_handlers.asyncio.create_task", side_effect=_capture_task) as mock_task, + patch("app.bot.run_bot_for_message", new_callable=AsyncMock) as mock_bot, + ): + mock_repo.create = AsyncMock(return_value=42) + mock_contact_repo.get_by_key_or_prefix = AsyncMock(return_value=None) + + class MockEvent: + payload = { + "pubkey_prefix": "abc123def456", + "text": "Hello, bot", + "txt_type": 0, + "sender_timestamp": 1700000000, + } + + await on_contact_message(MockEvent()) + + mock_task.assert_called_once() + mock_bot.assert_called_once() + @pytest.mark.asyncio async def test_normal_message_still_processed(self): """Normal messages (txt_type=0) are still processed normally.""" diff --git a/tests/test_packet_pipeline.py b/tests/test_packet_pipeline.py index 498f45f..382d3b2 100644 --- a/tests/test_packet_pipeline.py +++ b/tests/test_packet_pipeline.py @@ -495,6 +495,40 @@ class TestAckPipeline: class TestCreateMessageFromDecrypted: """Test the shared message creation function used by both real-time and historical decryption.""" + @pytest.mark.asyncio + async def test_schedules_bot_in_background(self, test_db, captured_broadcasts): + """Bot execution is scheduled and does not block channel message persistence.""" + from app.packet_processor import create_message_from_decrypted + + packet_id, _ = await RawPacketRepository.create(b"test_packet_bot_channel", 1700000000) + broadcasts, mock_broadcast = captured_broadcasts + + def _capture_task(coro): + coro.close() + return MagicMock() + + with ( + patch("app.packet_processor.broadcast_event", mock_broadcast), + patch( + "app.packet_processor.asyncio.create_task", side_effect=_capture_task + ) as mock_task, + patch("app.bot.run_bot_for_message", new_callable=AsyncMock) as mock_bot, + ): + msg_id = await create_message_from_decrypted( + packet_id=packet_id, + channel_key="ABC123DEF456", + sender="BotTrigger", + message_text="Hello from channel", + timestamp=1700000000, + received_at=1700000001, + trigger_bot=True, + ) + + assert msg_id is not None + mock_task.assert_called_once() + mock_bot.assert_called_once() + assert mock_bot.await_count == 0 + @pytest.mark.asyncio async def test_creates_message_and_broadcasts(self, test_db, captured_broadcasts): """create_message_from_decrypted creates message and broadcasts correctly.""" @@ -712,6 +746,48 @@ class TestCreateDMMessageFromDecrypted: FACE12_PUB = "FACE123334789E2B81519AFDBC39A3C9EB7EA3457AD367D3243597A484847E46" A1B2C3_PUB = "a1b2c3d3ba9f5fa8705b9845fe11cc6f01d1d49caaf4d122ac7121663c5beec7" + @pytest.mark.asyncio + async def test_schedules_bot_in_background(self, test_db, captured_broadcasts): + """Bot execution is scheduled and does not block DM persistence.""" + from app.decoder import DecryptedDirectMessage + from app.packet_processor import create_dm_message_from_decrypted + + packet_id, _ = await RawPacketRepository.create(b"test_packet_bot_dm", 1700000000) + decrypted = DecryptedDirectMessage( + timestamp=1700000000, + flags=0, + message="Hello from DM", + dest_hash="fa", + src_hash="a1", + ) + broadcasts, mock_broadcast = captured_broadcasts + + def _capture_task(coro): + coro.close() + return MagicMock() + + with ( + patch("app.packet_processor.broadcast_event", mock_broadcast), + patch( + "app.packet_processor.asyncio.create_task", side_effect=_capture_task + ) as mock_task, + patch("app.bot.run_bot_for_message", new_callable=AsyncMock) as mock_bot, + ): + msg_id = await create_dm_message_from_decrypted( + packet_id=packet_id, + decrypted=decrypted, + their_public_key=self.A1B2C3_PUB, + our_public_key=self.FACE12_PUB, + received_at=1700000001, + outgoing=False, + trigger_bot=True, + ) + + assert msg_id is not None + mock_task.assert_called_once() + mock_bot.assert_called_once() + assert mock_bot.await_count == 0 + @pytest.mark.asyncio async def test_creates_dm_message_and_broadcasts(self, test_db, captured_broadcasts): """create_dm_message_from_decrypted creates message and broadcasts correctly."""