mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-03-28 17:43:05 +01:00
Fix path data update race condition
This commit is contained in:
@@ -388,32 +388,29 @@ class MessageRepository:
|
||||
"""
|
||||
ts = received_at if received_at is not None else int(time.time())
|
||||
|
||||
# Get current paths
|
||||
cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return []
|
||||
|
||||
# Parse existing paths or start with empty list
|
||||
existing_paths = []
|
||||
if row["paths"]:
|
||||
try:
|
||||
existing_paths = json.loads(row["paths"])
|
||||
except json.JSONDecodeError:
|
||||
existing_paths = []
|
||||
|
||||
# Add new path
|
||||
existing_paths.append({"path": path, "received_at": ts})
|
||||
|
||||
# Update database
|
||||
paths_json = json.dumps(existing_paths)
|
||||
# Atomic append: use json_insert to avoid read-modify-write race when
|
||||
# multiple duplicate packets arrive concurrently for the same message.
|
||||
new_entry = json.dumps({"path": path, "received_at": ts})
|
||||
await db.conn.execute(
|
||||
"UPDATE messages SET paths = ? WHERE id = ?",
|
||||
(paths_json, message_id),
|
||||
"""UPDATE messages SET paths = json_insert(
|
||||
COALESCE(paths, '[]'), '$[#]', json(?)
|
||||
) WHERE id = ?""",
|
||||
(new_entry, message_id),
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
return [MessagePath(**p) for p in existing_paths]
|
||||
# Read back the full list for the return value
|
||||
cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
|
||||
row = await cursor.fetchone()
|
||||
if not row or not row["paths"]:
|
||||
return []
|
||||
|
||||
try:
|
||||
all_paths = json.loads(row["paths"])
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
return [MessagePath(**p) for p in all_paths]
|
||||
|
||||
@staticmethod
|
||||
async def claim_prefix_messages(full_key: str) -> int:
|
||||
|
||||
@@ -1,168 +1,127 @@
|
||||
"""Tests for repository layer, specifically the add_path method."""
|
||||
"""Tests for repository layer."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.database import Database
|
||||
from app.repository import MessageRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db():
|
||||
"""Create an in-memory test database with the module-level db swapped in."""
|
||||
import app.repository as repo_module
|
||||
|
||||
db = Database(":memory:")
|
||||
await db.connect()
|
||||
|
||||
original_db = repo_module.db
|
||||
repo_module.db = db
|
||||
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
repo_module.db = original_db
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
async def _create_message(test_db, **overrides) -> int:
|
||||
"""Helper to insert a message and return its id."""
|
||||
defaults = {
|
||||
"msg_type": "CHAN",
|
||||
"text": "Hello",
|
||||
"conversation_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA0",
|
||||
"sender_timestamp": 1700000000,
|
||||
"received_at": 1700000000,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
msg_id = await MessageRepository.create(**defaults)
|
||||
assert msg_id is not None
|
||||
return msg_id
|
||||
|
||||
|
||||
class TestMessageRepositoryAddPath:
|
||||
"""Test MessageRepository.add_path method."""
|
||||
"""Test MessageRepository.add_path against a real SQLite database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_path_to_message_with_no_existing_paths(self):
|
||||
async def test_add_path_to_message_with_no_existing_paths(self, test_db):
|
||||
"""Adding a path to a message with no existing paths creates a new array."""
|
||||
msg_id = await _create_message(test_db)
|
||||
|
||||
# Mock the database connection
|
||||
mock_conn = AsyncMock()
|
||||
mock_cursor = AsyncMock()
|
||||
mock_cursor.fetchone = AsyncMock(return_value={"paths": None})
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
mock_conn.commit = AsyncMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.conn = mock_conn
|
||||
|
||||
with patch("app.repository.db", mock_db):
|
||||
from app.repository import MessageRepository
|
||||
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=42, path="1A2B", received_at=1700000000
|
||||
)
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=msg_id, path="1A2B", received_at=1700000000
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].path == "1A2B"
|
||||
assert result[0].received_at == 1700000000
|
||||
|
||||
# Verify the UPDATE was called with correct JSON
|
||||
update_call = mock_conn.execute.call_args_list[-1]
|
||||
assert update_call[0][0] == "UPDATE messages SET paths = ? WHERE id = ?"
|
||||
paths_json = update_call[0][1][0]
|
||||
parsed = json.loads(paths_json)
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0]["path"] == "1A2B"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_path_to_message_with_existing_paths(self):
|
||||
async def test_add_path_to_message_with_existing_paths(self, test_db):
|
||||
"""Adding a path to a message with existing paths appends to the array."""
|
||||
existing_paths = json.dumps([{"path": "1A", "received_at": 1699999999}])
|
||||
msg_id = await _create_message(test_db)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_cursor = AsyncMock()
|
||||
mock_cursor.fetchone = AsyncMock(return_value={"paths": existing_paths})
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
mock_conn.commit = AsyncMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.conn = mock_conn
|
||||
|
||||
with patch("app.repository.db", mock_db):
|
||||
from app.repository import MessageRepository
|
||||
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=42, path="2B3C", received_at=1700000000
|
||||
)
|
||||
await MessageRepository.add_path(
|
||||
message_id=msg_id, path="1A", received_at=1699999999
|
||||
)
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=msg_id, path="2B3C", received_at=1700000000
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].path == "1A"
|
||||
assert result[1].path == "2B3C"
|
||||
|
||||
# Verify the UPDATE contains both paths
|
||||
update_call = mock_conn.execute.call_args_list[-1]
|
||||
paths_json = update_call[0][1][0]
|
||||
parsed = json.loads(paths_json)
|
||||
assert len(parsed) == 2
|
||||
assert parsed[0]["path"] == "1A"
|
||||
assert parsed[1]["path"] == "2B3C"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_path_to_nonexistent_message_returns_empty(self):
|
||||
async def test_add_path_to_nonexistent_message_returns_empty(self, test_db):
|
||||
"""Adding a path to a nonexistent message returns empty list."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_cursor = AsyncMock()
|
||||
mock_cursor.fetchone = AsyncMock(return_value=None)
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.conn = mock_conn
|
||||
|
||||
with patch("app.repository.db", mock_db):
|
||||
from app.repository import MessageRepository
|
||||
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=999, path="1A2B", received_at=1700000000
|
||||
)
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=999999, path="1A2B", received_at=1700000000
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_path_handles_corrupted_json(self):
|
||||
"""Adding a path handles corrupted JSON in existing paths gracefully."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_cursor = AsyncMock()
|
||||
mock_cursor.fetchone = AsyncMock(return_value={"paths": "not valid json {"})
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
mock_conn.commit = AsyncMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.conn = mock_conn
|
||||
|
||||
with patch("app.repository.db", mock_db):
|
||||
from app.repository import MessageRepository
|
||||
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=42, path="1A2B", received_at=1700000000
|
||||
)
|
||||
|
||||
# Should recover and create new array with just the new path
|
||||
assert len(result) == 1
|
||||
assert result[0].path == "1A2B"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_path_uses_current_time_if_not_provided(self):
|
||||
async def test_add_path_uses_current_time_if_not_provided(self, test_db):
|
||||
"""Adding a path without received_at uses current timestamp."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_cursor = AsyncMock()
|
||||
mock_cursor.fetchone = AsyncMock(return_value={"paths": None})
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
mock_conn.commit = AsyncMock()
|
||||
msg_id = await _create_message(test_db)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.conn = mock_conn
|
||||
|
||||
with patch("app.repository.db", mock_db), patch("app.repository.time") as mock_time:
|
||||
with patch("app.repository.time") as mock_time:
|
||||
mock_time.time.return_value = 1700000500.5
|
||||
|
||||
from app.repository import MessageRepository
|
||||
|
||||
result = await MessageRepository.add_path(message_id=42, path="1A2B")
|
||||
result = await MessageRepository.add_path(message_id=msg_id, path="1A2B")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].received_at == 1700000500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_empty_path_for_direct_message(self):
|
||||
async def test_add_empty_path_for_direct_message(self, test_db):
|
||||
"""Adding an empty path (direct message) works correctly."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_cursor = AsyncMock()
|
||||
mock_cursor.fetchone = AsyncMock(return_value={"paths": None})
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
mock_conn.commit = AsyncMock()
|
||||
msg_id = await _create_message(test_db)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.conn = mock_conn
|
||||
|
||||
with patch("app.repository.db", mock_db):
|
||||
from app.repository import MessageRepository
|
||||
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=42, path="", received_at=1700000000
|
||||
)
|
||||
result = await MessageRepository.add_path(
|
||||
message_id=msg_id, path="", received_at=1700000000
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].path == "" # Empty path = direct
|
||||
assert result[0].received_at == 1700000000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_multiple_paths_accumulate(self, test_db):
|
||||
"""Multiple add_path calls accumulate all paths."""
|
||||
msg_id = await _create_message(test_db)
|
||||
|
||||
await MessageRepository.add_path(msg_id, "", received_at=1700000001)
|
||||
await MessageRepository.add_path(msg_id, "1A", received_at=1700000002)
|
||||
result = await MessageRepository.add_path(msg_id, "1A2B", received_at=1700000003)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0].path == ""
|
||||
assert result[1].path == "1A"
|
||||
assert result[2].path == "1A2B"
|
||||
|
||||
|
||||
class TestMessageRepositoryGetByContent:
|
||||
"""Test MessageRepository.get_by_content method."""
|
||||
|
||||
Reference in New Issue
Block a user