diff --git a/app/repository.py b/app/repository.py index 5cafc07..9f0ce6f 100644 --- a/app/repository.py +++ b/app/repository.py @@ -388,32 +388,29 @@ class MessageRepository: """ ts = received_at if received_at is not None else int(time.time()) - # Get current paths - cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,)) - row = await cursor.fetchone() - if not row: - return [] - - # Parse existing paths or start with empty list - existing_paths = [] - if row["paths"]: - try: - existing_paths = json.loads(row["paths"]) - except json.JSONDecodeError: - existing_paths = [] - - # Add new path - existing_paths.append({"path": path, "received_at": ts}) - - # Update database - paths_json = json.dumps(existing_paths) + # Atomic append: use json_insert to avoid read-modify-write race when + # multiple duplicate packets arrive concurrently for the same message. + new_entry = json.dumps({"path": path, "received_at": ts}) await db.conn.execute( - "UPDATE messages SET paths = ? WHERE id = ?", - (paths_json, message_id), + """UPDATE messages SET paths = json_insert( + COALESCE(paths, '[]'), '$[#]', json(?) + ) WHERE id = ?""", + (new_entry, message_id), ) await db.conn.commit() - return [MessagePath(**p) for p in existing_paths] + # Read back the full list for the return value + cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,)) + row = await cursor.fetchone() + if not row or not row["paths"]: + return [] + + try: + all_paths = json.loads(row["paths"]) + except json.JSONDecodeError: + return [] + + return [MessagePath(**p) for p in all_paths] @staticmethod async def claim_prefix_messages(full_key: str) -> int: diff --git a/tests/test_repository.py b/tests/test_repository.py index 97851a4..7a5cb7f 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -1,168 +1,127 @@ -"""Tests for repository layer, specifically the add_path method.""" +"""Tests for repository layer.""" import json from unittest.mock import AsyncMock, MagicMock, patch import pytest +from app.database import Database +from app.repository import MessageRepository + + +@pytest.fixture +async def test_db(): + """Create an in-memory test database with the module-level db swapped in.""" + import app.repository as repo_module + + db = Database(":memory:") + await db.connect() + + original_db = repo_module.db + repo_module.db = db + + try: + yield db + finally: + repo_module.db = original_db + await db.disconnect() + + +async def _create_message(test_db, **overrides) -> int: + """Helper to insert a message and return its id.""" + defaults = { + "msg_type": "CHAN", + "text": "Hello", + "conversation_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA0", + "sender_timestamp": 1700000000, + "received_at": 1700000000, + } + defaults.update(overrides) + msg_id = await MessageRepository.create(**defaults) + assert msg_id is not None + return msg_id + class TestMessageRepositoryAddPath: - """Test MessageRepository.add_path method.""" + """Test MessageRepository.add_path against a real SQLite database.""" @pytest.mark.asyncio - async def test_add_path_to_message_with_no_existing_paths(self): + async def test_add_path_to_message_with_no_existing_paths(self, test_db): """Adding a path to a message with no existing paths creates a new array.""" + msg_id = await _create_message(test_db) - # Mock the database connection - mock_conn = AsyncMock() - mock_cursor = AsyncMock() - mock_cursor.fetchone = AsyncMock(return_value={"paths": None}) - mock_conn.execute = AsyncMock(return_value=mock_cursor) - mock_conn.commit = AsyncMock() - - mock_db = MagicMock() - mock_db.conn = mock_conn - - with patch("app.repository.db", mock_db): - from app.repository import MessageRepository - - result = await MessageRepository.add_path( - message_id=42, path="1A2B", received_at=1700000000 - ) + result = await MessageRepository.add_path( + message_id=msg_id, path="1A2B", received_at=1700000000 + ) assert len(result) == 1 assert result[0].path == "1A2B" assert result[0].received_at == 1700000000 - # Verify the UPDATE was called with correct JSON - update_call = mock_conn.execute.call_args_list[-1] - assert update_call[0][0] == "UPDATE messages SET paths = ? WHERE id = ?" - paths_json = update_call[0][1][0] - parsed = json.loads(paths_json) - assert len(parsed) == 1 - assert parsed[0]["path"] == "1A2B" - @pytest.mark.asyncio - async def test_add_path_to_message_with_existing_paths(self): + async def test_add_path_to_message_with_existing_paths(self, test_db): """Adding a path to a message with existing paths appends to the array.""" - existing_paths = json.dumps([{"path": "1A", "received_at": 1699999999}]) + msg_id = await _create_message(test_db) - mock_conn = AsyncMock() - mock_cursor = AsyncMock() - mock_cursor.fetchone = AsyncMock(return_value={"paths": existing_paths}) - mock_conn.execute = AsyncMock(return_value=mock_cursor) - mock_conn.commit = AsyncMock() - - mock_db = MagicMock() - mock_db.conn = mock_conn - - with patch("app.repository.db", mock_db): - from app.repository import MessageRepository - - result = await MessageRepository.add_path( - message_id=42, path="2B3C", received_at=1700000000 - ) + await MessageRepository.add_path( + message_id=msg_id, path="1A", received_at=1699999999 + ) + result = await MessageRepository.add_path( + message_id=msg_id, path="2B3C", received_at=1700000000 + ) assert len(result) == 2 assert result[0].path == "1A" assert result[1].path == "2B3C" - # Verify the UPDATE contains both paths - update_call = mock_conn.execute.call_args_list[-1] - paths_json = update_call[0][1][0] - parsed = json.loads(paths_json) - assert len(parsed) == 2 - assert parsed[0]["path"] == "1A" - assert parsed[1]["path"] == "2B3C" - @pytest.mark.asyncio - async def test_add_path_to_nonexistent_message_returns_empty(self): + async def test_add_path_to_nonexistent_message_returns_empty(self, test_db): """Adding a path to a nonexistent message returns empty list.""" - mock_conn = AsyncMock() - mock_cursor = AsyncMock() - mock_cursor.fetchone = AsyncMock(return_value=None) - mock_conn.execute = AsyncMock(return_value=mock_cursor) - - mock_db = MagicMock() - mock_db.conn = mock_conn - - with patch("app.repository.db", mock_db): - from app.repository import MessageRepository - - result = await MessageRepository.add_path( - message_id=999, path="1A2B", received_at=1700000000 - ) + result = await MessageRepository.add_path( + message_id=999999, path="1A2B", received_at=1700000000 + ) assert result == [] @pytest.mark.asyncio - async def test_add_path_handles_corrupted_json(self): - """Adding a path handles corrupted JSON in existing paths gracefully.""" - mock_conn = AsyncMock() - mock_cursor = AsyncMock() - mock_cursor.fetchone = AsyncMock(return_value={"paths": "not valid json {"}) - mock_conn.execute = AsyncMock(return_value=mock_cursor) - mock_conn.commit = AsyncMock() - - mock_db = MagicMock() - mock_db.conn = mock_conn - - with patch("app.repository.db", mock_db): - from app.repository import MessageRepository - - result = await MessageRepository.add_path( - message_id=42, path="1A2B", received_at=1700000000 - ) - - # Should recover and create new array with just the new path - assert len(result) == 1 - assert result[0].path == "1A2B" - - @pytest.mark.asyncio - async def test_add_path_uses_current_time_if_not_provided(self): + async def test_add_path_uses_current_time_if_not_provided(self, test_db): """Adding a path without received_at uses current timestamp.""" - mock_conn = AsyncMock() - mock_cursor = AsyncMock() - mock_cursor.fetchone = AsyncMock(return_value={"paths": None}) - mock_conn.execute = AsyncMock(return_value=mock_cursor) - mock_conn.commit = AsyncMock() + msg_id = await _create_message(test_db) - mock_db = MagicMock() - mock_db.conn = mock_conn - - with patch("app.repository.db", mock_db), patch("app.repository.time") as mock_time: + with patch("app.repository.time") as mock_time: mock_time.time.return_value = 1700000500.5 - - from app.repository import MessageRepository - - result = await MessageRepository.add_path(message_id=42, path="1A2B") + result = await MessageRepository.add_path(message_id=msg_id, path="1A2B") assert len(result) == 1 assert result[0].received_at == 1700000500 @pytest.mark.asyncio - async def test_add_empty_path_for_direct_message(self): + async def test_add_empty_path_for_direct_message(self, test_db): """Adding an empty path (direct message) works correctly.""" - mock_conn = AsyncMock() - mock_cursor = AsyncMock() - mock_cursor.fetchone = AsyncMock(return_value={"paths": None}) - mock_conn.execute = AsyncMock(return_value=mock_cursor) - mock_conn.commit = AsyncMock() + msg_id = await _create_message(test_db) - mock_db = MagicMock() - mock_db.conn = mock_conn - - with patch("app.repository.db", mock_db): - from app.repository import MessageRepository - - result = await MessageRepository.add_path( - message_id=42, path="", received_at=1700000000 - ) + result = await MessageRepository.add_path( + message_id=msg_id, path="", received_at=1700000000 + ) assert len(result) == 1 assert result[0].path == "" # Empty path = direct assert result[0].received_at == 1700000000 + @pytest.mark.asyncio + async def test_add_multiple_paths_accumulate(self, test_db): + """Multiple add_path calls accumulate all paths.""" + msg_id = await _create_message(test_db) + + await MessageRepository.add_path(msg_id, "", received_at=1700000001) + await MessageRepository.add_path(msg_id, "1A", received_at=1700000002) + result = await MessageRepository.add_path(msg_id, "1A2B", received_at=1700000003) + + assert len(result) == 3 + assert result[0].path == "" + assert result[1].path == "1A" + assert result[2].path == "1A2B" + class TestMessageRepositoryGetByContent: """Test MessageRepository.get_by_content method."""