Fix path data update race condition

This commit is contained in:
Jack Kingsman
2026-02-12 00:15:21 -08:00
parent 6ac5a1e7db
commit 8be5a22730
2 changed files with 97 additions and 141 deletions

View File

@@ -388,32 +388,29 @@ class MessageRepository:
"""
ts = received_at if received_at is not None else int(time.time())
# Get current paths
cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
row = await cursor.fetchone()
if not row:
return []
# Parse existing paths or start with empty list
existing_paths = []
if row["paths"]:
try:
existing_paths = json.loads(row["paths"])
except json.JSONDecodeError:
existing_paths = []
# Add new path
existing_paths.append({"path": path, "received_at": ts})
# Update database
paths_json = json.dumps(existing_paths)
# Atomic append: use json_insert to avoid read-modify-write race when
# multiple duplicate packets arrive concurrently for the same message.
new_entry = json.dumps({"path": path, "received_at": ts})
await db.conn.execute(
"UPDATE messages SET paths = ? WHERE id = ?",
(paths_json, message_id),
"""UPDATE messages SET paths = json_insert(
COALESCE(paths, '[]'), '$[#]', json(?)
) WHERE id = ?""",
(new_entry, message_id),
)
await db.conn.commit()
return [MessagePath(**p) for p in existing_paths]
# Read back the full list for the return value
cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
row = await cursor.fetchone()
if not row or not row["paths"]:
return []
try:
all_paths = json.loads(row["paths"])
except json.JSONDecodeError:
return []
return [MessagePath(**p) for p in all_paths]
@staticmethod
async def claim_prefix_messages(full_key: str) -> int:

View File

@@ -1,168 +1,127 @@
"""Tests for repository layer, specifically the add_path method."""
"""Tests for repository layer."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.database import Database
from app.repository import MessageRepository
@pytest.fixture
async def test_db():
    """Yield a connected in-memory ``Database`` installed as ``app.repository.db``.

    While the fixture is active, the ``db`` attribute of ``app.repository``
    points at the temporary database. On teardown the previous ``db`` object
    is put back and the temporary database is disconnected.
    """
    # Function-scope import of the module whose ``db`` attribute is swapped.
    import app.repository as repository

    memory_db = Database(":memory:")
    await memory_db.connect()

    # Remember whatever was installed before so teardown can restore it.
    previous_db = repository.db
    repository.db = memory_db
    try:
        yield memory_db
    finally:
        # Runs whether or not the consuming test passed.
        repository.db = previous_db
        await memory_db.disconnect()
async def _create_message(test_db, **overrides) -> int:
    """Insert a message through ``MessageRepository.create`` and return its id.

    Args:
        test_db: Not read in the body; presumably accepted so callers must
            have the ``test_db`` fixture active — confirm against callers.
        **overrides: Field values that replace or extend the defaults below.

    Returns:
        The id returned by ``MessageRepository.create`` (asserted non-None).
    """
    fields = {
        "msg_type": "CHAN",
        "text": "Hello",
        "conversation_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA0",
        "sender_timestamp": 1700000000,
        "received_at": 1700000000,
        # Caller-supplied values win over the defaults above.
        **overrides,
    }
    created_id = await MessageRepository.create(**fields)
    assert created_id is not None
    return created_id
class TestMessageRepositoryAddPath:
"""Test MessageRepository.add_path method."""
"""Test MessageRepository.add_path against a real SQLite database."""
@pytest.mark.asyncio
async def test_add_path_to_message_with_no_existing_paths(self):
async def test_add_path_to_message_with_no_existing_paths(self, test_db):
"""Adding a path to a message with no existing paths creates a new array."""
msg_id = await _create_message(test_db)
# Mock the database connection
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(return_value={"paths": None})
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_conn.commit = AsyncMock()
mock_db = MagicMock()
mock_db.conn = mock_conn
with patch("app.repository.db", mock_db):
from app.repository import MessageRepository
result = await MessageRepository.add_path(
message_id=42, path="1A2B", received_at=1700000000
)
result = await MessageRepository.add_path(
message_id=msg_id, path="1A2B", received_at=1700000000
)
assert len(result) == 1
assert result[0].path == "1A2B"
assert result[0].received_at == 1700000000
# Verify the UPDATE was called with correct JSON
update_call = mock_conn.execute.call_args_list[-1]
assert update_call[0][0] == "UPDATE messages SET paths = ? WHERE id = ?"
paths_json = update_call[0][1][0]
parsed = json.loads(paths_json)
assert len(parsed) == 1
assert parsed[0]["path"] == "1A2B"
@pytest.mark.asyncio
async def test_add_path_to_message_with_existing_paths(self):
async def test_add_path_to_message_with_existing_paths(self, test_db):
"""Adding a path to a message with existing paths appends to the array."""
existing_paths = json.dumps([{"path": "1A", "received_at": 1699999999}])
msg_id = await _create_message(test_db)
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(return_value={"paths": existing_paths})
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_conn.commit = AsyncMock()
mock_db = MagicMock()
mock_db.conn = mock_conn
with patch("app.repository.db", mock_db):
from app.repository import MessageRepository
result = await MessageRepository.add_path(
message_id=42, path="2B3C", received_at=1700000000
)
await MessageRepository.add_path(
message_id=msg_id, path="1A", received_at=1699999999
)
result = await MessageRepository.add_path(
message_id=msg_id, path="2B3C", received_at=1700000000
)
assert len(result) == 2
assert result[0].path == "1A"
assert result[1].path == "2B3C"
# Verify the UPDATE contains both paths
update_call = mock_conn.execute.call_args_list[-1]
paths_json = update_call[0][1][0]
parsed = json.loads(paths_json)
assert len(parsed) == 2
assert parsed[0]["path"] == "1A"
assert parsed[1]["path"] == "2B3C"
@pytest.mark.asyncio
async def test_add_path_to_nonexistent_message_returns_empty(self):
async def test_add_path_to_nonexistent_message_returns_empty(self, test_db):
"""Adding a path to a nonexistent message returns empty list."""
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(return_value=None)
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_db = MagicMock()
mock_db.conn = mock_conn
with patch("app.repository.db", mock_db):
from app.repository import MessageRepository
result = await MessageRepository.add_path(
message_id=999, path="1A2B", received_at=1700000000
)
result = await MessageRepository.add_path(
message_id=999999, path="1A2B", received_at=1700000000
)
assert result == []
@pytest.mark.asyncio
async def test_add_path_handles_corrupted_json(self):
"""Adding a path handles corrupted JSON in existing paths gracefully."""
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(return_value={"paths": "not valid json {"})
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_conn.commit = AsyncMock()
mock_db = MagicMock()
mock_db.conn = mock_conn
with patch("app.repository.db", mock_db):
from app.repository import MessageRepository
result = await MessageRepository.add_path(
message_id=42, path="1A2B", received_at=1700000000
)
# Should recover and create new array with just the new path
assert len(result) == 1
assert result[0].path == "1A2B"
@pytest.mark.asyncio
async def test_add_path_uses_current_time_if_not_provided(self):
async def test_add_path_uses_current_time_if_not_provided(self, test_db):
"""Adding a path without received_at uses current timestamp."""
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(return_value={"paths": None})
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_conn.commit = AsyncMock()
msg_id = await _create_message(test_db)
mock_db = MagicMock()
mock_db.conn = mock_conn
with patch("app.repository.db", mock_db), patch("app.repository.time") as mock_time:
with patch("app.repository.time") as mock_time:
mock_time.time.return_value = 1700000500.5
from app.repository import MessageRepository
result = await MessageRepository.add_path(message_id=42, path="1A2B")
result = await MessageRepository.add_path(message_id=msg_id, path="1A2B")
assert len(result) == 1
assert result[0].received_at == 1700000500
@pytest.mark.asyncio
async def test_add_empty_path_for_direct_message(self):
async def test_add_empty_path_for_direct_message(self, test_db):
"""Adding an empty path (direct message) works correctly."""
mock_conn = AsyncMock()
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(return_value={"paths": None})
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_conn.commit = AsyncMock()
msg_id = await _create_message(test_db)
mock_db = MagicMock()
mock_db.conn = mock_conn
with patch("app.repository.db", mock_db):
from app.repository import MessageRepository
result = await MessageRepository.add_path(
message_id=42, path="", received_at=1700000000
)
result = await MessageRepository.add_path(
message_id=msg_id, path="", received_at=1700000000
)
assert len(result) == 1
assert result[0].path == "" # Empty path = direct
assert result[0].received_at == 1700000000
@pytest.mark.asyncio
async def test_add_multiple_paths_accumulate(self, test_db):
    """Successive add_path calls on one message keep every path, in call order."""
    message_id = await _create_message(test_db)

    # Appended one at a time; timestamps are 1700000001, 1700000002, 1700000003.
    hops = ["", "1A", "1A2B"]
    returned = []
    for offset, hop in enumerate(hops, start=1):
        returned = await MessageRepository.add_path(
            message_id, hop, received_at=1700000000 + offset
        )

    # Only the return value of the final call is inspected.
    assert len(returned) == 3
    assert [entry.path for entry in returned] == hops
class TestMessageRepositoryGetByContent:
"""Test MessageRepository.get_by_content method."""