diff --git a/app/bot.py b/app/bot.py index a648cea5..913a5394 100644 --- a/app/bot.py +++ b/app/bot.py @@ -12,6 +12,7 @@ the security implications. import asyncio import logging +import time from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -28,6 +29,14 @@ _bot_executor = ThreadPoolExecutor(max_workers=3, thread_name_prefix="bot_") # Timeout for bot code execution (seconds) BOT_EXECUTION_TIMEOUT = 10 +# Minimum spacing between bot message sends (seconds) +# This ensures repeaters have time to return to listening mode +BOT_MESSAGE_SPACING = 2.0 + +# Global state for rate limiting bot sends +_bot_send_lock = asyncio.Lock() +_last_bot_send_time: float = 0.0 + def execute_bot_code( code: str, @@ -124,35 +133,57 @@ async def process_bot_response( For DMs, sends a direct message back to the sender. For channel messages, sends to the same channel. + Bot messages are rate-limited to ensure at least BOT_MESSAGE_SPACING seconds + between sends, giving repeaters time to return to listening mode. + Args: response: The response text to send is_dm: Whether the original message was a DM sender_key: Public key of the original sender (for DM replies) channel_key: Channel key for channel message replies """ + global _last_bot_send_time + from app.models import SendChannelMessageRequest, SendDirectMessageRequest from app.routers.messages import send_channel_message, send_direct_message from app.websocket import broadcast_event - try: - if is_dm: - logger.info("Bot sending DM reply to %s", sender_key[:12]) - request = SendDirectMessageRequest(destination=sender_key, text=response) - message = await send_direct_message(request) - # Broadcast to WebSocket (endpoint returns to HTTP caller, bot needs explicit broadcast) - broadcast_event("message", message.model_dump()) - elif channel_key: - logger.info("Bot sending channel reply to %s", channel_key[:8]) - request = SendChannelMessageRequest(channel_key=channel_key, text=response) - message = await send_channel_message(request) - # Broadcast to WebSocket - broadcast_event("message", message.model_dump()) - else: - logger.warning("Cannot send bot response: no destination") - except HTTPException as e: - logger.error("Bot failed to send response: %s", e.detail) - except Exception as e: - logger.error("Bot failed to send response: %s", e) + # Serialize bot sends and enforce minimum spacing + async with _bot_send_lock: + # Calculate how long since last bot send + now = time.monotonic() + time_since_last = now - _last_bot_send_time + + if _last_bot_send_time > 0 and time_since_last < BOT_MESSAGE_SPACING: + wait_time = BOT_MESSAGE_SPACING - time_since_last + logger.debug("Rate limiting bot send, waiting %.2fs", wait_time) + await asyncio.sleep(wait_time) + + try: + if is_dm: + logger.info("Bot sending DM reply to %s", sender_key[:12]) + request = SendDirectMessageRequest(destination=sender_key, text=response) + message = await send_direct_message(request) + # Broadcast to WebSocket (endpoint returns to HTTP caller, bot needs explicit broadcast) + broadcast_event("message", message.model_dump()) + elif channel_key: + logger.info("Bot sending channel reply to %s", channel_key[:8]) + request = SendChannelMessageRequest(channel_key=channel_key, text=response) + message = await send_channel_message(request) + # Broadcast to WebSocket + broadcast_event("message", message.model_dump()) + else: + logger.warning("Cannot send bot response: no destination") + return # Don't update timestamp if we didn't send + except HTTPException as e: + logger.error("Bot failed to send response: %s", e.detail) + return # Don't update timestamp on failure + except Exception as e: + logger.error("Bot failed to send response: %s", e) + return # Don't update timestamp on failure + + # Update last send time after successful send + _last_bot_send_time = time.monotonic() async def run_bot_for_message( diff --git a/tests/test_bot.py b/tests/test_bot.py index aaad7575..1f46b616 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -1,12 +1,16 @@ """Tests for the bot execution module.""" +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest +import app.bot as bot_module from app.bot import ( + BOT_MESSAGE_SPACING, _bot_semaphore, execute_bot_code, + process_bot_response, run_bot_for_message, ) @@ -381,3 +385,226 @@ class TestBotCodeValidation: # Should not raise validate_bot_code("") validate_bot_code(" ") + + +class TestBotMessageRateLimiting: + """Test bot message rate limiting for repeater compatibility.""" + + @pytest.fixture(autouse=True) + def reset_rate_limit_state(self): + """Reset rate limiting state between tests.""" + bot_module._last_bot_send_time = 0.0 + yield + bot_module._last_bot_send_time = 0.0 + + @pytest.mark.asyncio + async def test_first_send_does_not_wait(self): + """First bot send should not wait (no previous send).""" + with ( + patch("app.bot.time.monotonic", return_value=100.0), + patch("app.bot.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + patch("app.routers.messages.send_direct_message", new_callable=AsyncMock) as mock_send, + patch("app.websocket.broadcast_event"), + ): + mock_message = MagicMock() + mock_message.model_dump.return_value = {} + mock_send.return_value = mock_message + + await process_bot_response( + response="Hello!", + is_dm=True, + sender_key="abc123def456" * 4, # 64 chars + channel_key=None, + ) + + # Should not have slept (first send, _last_bot_send_time was 0) + mock_sleep.assert_not_called() + mock_send.assert_called_once() + + @pytest.mark.asyncio + async def test_rapid_second_send_waits(self): + """Second send within spacing window should wait.""" + # Previous send was at 100.0, current time is 100.5 (0.5 seconds later) + # So we need to wait 1.5 more seconds to reach 2.0 second spacing + bot_module._last_bot_send_time = 100.0 + + with ( + patch("app.bot.time.monotonic", return_value=100.5), + patch("app.bot.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + patch("app.routers.messages.send_direct_message", new_callable=AsyncMock) as mock_send, + patch("app.websocket.broadcast_event"), + ): + mock_message = MagicMock() + mock_message.model_dump.return_value = {} + mock_send.return_value = mock_message + + await process_bot_response( + response="Hello again!", + is_dm=True, + sender_key="abc123def456" * 4, + channel_key=None, + ) + + # Should have waited 1.5 seconds (2.0 - 0.5 elapsed) + mock_sleep.assert_called_once() + wait_time = mock_sleep.call_args[0][0] + assert abs(wait_time - 1.5) < 0.01 + + @pytest.mark.asyncio + async def test_send_after_spacing_does_not_wait(self): + """Send after spacing window should not wait.""" + # Simulate a previous send 3 seconds ago (> BOT_MESSAGE_SPACING) + bot_module._last_bot_send_time = 97.0 + + with ( + patch("app.bot.time.monotonic", return_value=100.0), + patch("app.bot.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + patch("app.routers.messages.send_direct_message", new_callable=AsyncMock) as mock_send, + patch("app.websocket.broadcast_event"), + ): + mock_message = MagicMock() + mock_message.model_dump.return_value = {} + mock_send.return_value = mock_message + + await process_bot_response( + response="Hello!", + is_dm=True, + sender_key="abc123def456" * 4, + channel_key=None, + ) + + # Should not have slept (3 seconds > 2 second spacing) + mock_sleep.assert_not_called() + + @pytest.mark.asyncio + async def test_timestamp_updated_after_successful_send(self): + """Last send timestamp should be updated after successful send.""" + with ( + patch("app.bot.time.monotonic", return_value=150.0), + patch("app.routers.messages.send_direct_message", new_callable=AsyncMock) as mock_send, + patch("app.websocket.broadcast_event"), + ): + mock_message = MagicMock() + mock_message.model_dump.return_value = {} + mock_send.return_value = mock_message + + await process_bot_response( + response="Hello!", + is_dm=True, + sender_key="abc123def456" * 4, + channel_key=None, + ) + + assert bot_module._last_bot_send_time == 150.0 + + @pytest.mark.asyncio + async def test_timestamp_not_updated_on_failure(self): + """Last send timestamp should NOT be updated if send fails.""" + from fastapi import HTTPException + + bot_module._last_bot_send_time = 50.0 # Previous timestamp + + with ( + patch("app.bot.time.monotonic", return_value=100.0), + patch( + "app.routers.messages.send_direct_message", + new_callable=AsyncMock, + side_effect=HTTPException(status_code=500, detail="Send failed"), + ), + ): + await process_bot_response( + response="Hello!", + is_dm=True, + sender_key="abc123def456" * 4, + channel_key=None, + ) + + # Timestamp should remain unchanged + assert bot_module._last_bot_send_time == 50.0 + + @pytest.mark.asyncio + async def test_timestamp_not_updated_on_no_destination(self): + """Last send timestamp should NOT be updated if no destination.""" + bot_module._last_bot_send_time = 50.0 + + with patch("app.bot.time.monotonic", return_value=100.0): + await process_bot_response( + response="Hello!", + is_dm=False, # Not a DM + sender_key="", + channel_key=None, # No channel either + ) + + # Timestamp should remain unchanged + assert bot_module._last_bot_send_time == 50.0 + + @pytest.mark.asyncio + async def test_concurrent_sends_are_serialized(self): + """Multiple concurrent sends should be serialized by the lock.""" + send_order = [] + send_times = [] + + async def mock_send(*args, **kwargs): + send_order.append(len(send_order)) + send_times.append(bot_module.time.monotonic()) + mock_message = MagicMock() + mock_message.model_dump.return_value = {} + return mock_message + + # Use a real monotonic-like counter for this test + time_counter = [100.0] + + def mock_monotonic(): + return time_counter[0] + + async def mock_sleep(duration): + time_counter[0] += duration + + with ( + patch("app.bot.time.monotonic", side_effect=mock_monotonic), + patch("app.bot.asyncio.sleep", side_effect=mock_sleep), + patch("app.routers.messages.send_direct_message", side_effect=mock_send), + patch("app.websocket.broadcast_event"), + ): + # Launch 3 concurrent sends + await asyncio.gather( + process_bot_response("Msg 1", True, "a" * 64, None), + process_bot_response("Msg 2", True, "b" * 64, None), + process_bot_response("Msg 3", True, "c" * 64, None), + ) + + # All 3 should have sent + assert len(send_order) == 3 + + # Each send should be at least BOT_MESSAGE_SPACING apart + # First send at 100, second at 102, third at 104 + assert send_times[1] >= send_times[0] + BOT_MESSAGE_SPACING - 0.01 + assert send_times[2] >= send_times[1] + BOT_MESSAGE_SPACING - 0.01 + + @pytest.mark.asyncio + async def test_channel_message_rate_limited(self): + """Channel message sends should also be rate limited.""" + bot_module._last_bot_send_time = 99.0 # 1 second ago + + with ( + patch("app.bot.time.monotonic", return_value=100.0), + patch("app.bot.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + patch("app.routers.messages.send_channel_message", new_callable=AsyncMock) as mock_send, + patch("app.websocket.broadcast_event"), + ): + mock_message = MagicMock() + mock_message.model_dump.return_value = {} + mock_send.return_value = mock_message + + await process_bot_response( + response="Channel hello!", + is_dm=False, + sender_key="", + channel_key="AABBCCDD" * 4, + ) + + # Should have waited 1 second (2.0 - 1.0 elapsed) + mock_sleep.assert_called_once() + wait_time = mock_sleep.call_args[0][0] + assert abs(wait_time - 1.0) < 0.01 + mock_send.assert_called_once()