Cleanups: Normalize pub keys, prefix message claiming, cursor + null timestamp DB cleanups

This commit is contained in:
Jack Kingsman
2026-02-02 16:22:10 -08:00
parent ea5283dd43
commit f8b05bb34d
19 changed files with 563 additions and 64 deletions

View File

@@ -77,6 +77,9 @@ async def on_contact_message(event: "Event") -> None:
if contact:
sender_pubkey = contact.public_key.lower()
# Promote any prefix-stored messages to this full key
await MessageRepository.claim_prefix_messages(sender_pubkey)
# Skip messages from repeaters - they only send CLI responses, not chat messages.
# CLI responses are handled by the command endpoint and txt_type filter above.
if contact.type == CONTACT_TYPE_REPEATER:
@@ -92,7 +95,7 @@ async def on_contact_message(event: "Event") -> None:
msg_type="PRIV",
text=payload.get("text", ""),
conversation_key=sender_pubkey,
sender_timestamp=payload.get("sender_timestamp"),
sender_timestamp=payload.get("sender_timestamp") or received_at,
received_at=received_at,
path=payload.get("path"),
txt_type=txt_type,
@@ -132,7 +135,7 @@ async def on_contact_message(event: "Event") -> None:
# Update contact last_contacted (contact was already fetched above)
if contact:
await ContactRepository.update_last_contacted(contact.public_key, received_at)
await ContactRepository.update_last_contacted(sender_pubkey, received_at)
# Run bot if enabled
from app.bot import run_bot_for_message

View File

@@ -128,6 +128,20 @@ async def run_migrations(conn: aiosqlite.Connection) -> int:
await set_version(conn, 13)
applied += 1
# Migration 14: Lowercase all contact public keys and related data
if version < 14:
logger.info("Applying migration 14: lowercase all contact public keys")
await _migrate_014_lowercase_public_keys(conn)
await set_version(conn, 14)
applied += 1
# Migration 15: Fix NULL sender_timestamp and add null-safe dedup index
if version < 15:
logger.info("Applying migration 15: fix NULL sender_timestamp values")
await _migrate_015_fix_null_sender_timestamp(conn)
await set_version(conn, 15)
applied += 1
if applied > 0:
logger.info(
"Applied %d migration(s), schema now at version %d", applied, await get_version(conn)
@@ -793,3 +807,189 @@ async def _migrate_013_convert_to_multi_bot(conn: aiosqlite.Connection) -> None:
raise
await conn.commit()
async def _migrate_014_lowercase_public_keys(conn: aiosqlite.Connection) -> None:
    """
    Lowercase all contact public keys and related data for case-insensitive matching.

    Updates:
      - contacts.public_key (PRIMARY KEY) via temp table swap
      - messages.conversation_key for PRIV messages
      - app_settings.favorites (contact IDs)
      - app_settings.last_message_times (contact- prefixed keys)

    Handles case collisions by keeping the most-recently-seen contact.
    Commits once per early-exit path and once at the end.
    """
    import json

    # 1. Lowercase message conversation keys for private messages.
    try:
        await conn.execute(
            "UPDATE messages SET conversation_key = lower(conversation_key) WHERE type = 'PRIV'"
        )
        logger.debug("Lowercased PRIV message conversation_keys")
    except aiosqlite.OperationalError as e:
        # Fresh databases may run migrations before the messages table exists.
        if "no such table" in str(e).lower():
            logger.debug("messages table does not exist yet, skipping conversation_key lowercase")
        else:
            raise

    # 2. Check if contacts table exists before proceeding.
    cursor = await conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='contacts'"
    )
    if not await cursor.fetchone():
        logger.debug("contacts table does not exist yet, skipping key lowercase")
        await conn.commit()
        return

    # 3. Resolve case collisions first: lowercasing two keys that differ only
    # by case would otherwise violate the PRIMARY KEY on the rebuilt table.
    cursor = await conn.execute(
        "SELECT lower(public_key) as lk, COUNT(*) as cnt "
        "FROM contacts GROUP BY lower(public_key) HAVING COUNT(*) > 1"
    )
    collisions = list(await cursor.fetchall())
    if collisions:
        logger.warning(
            "Found %d case-colliding contact groups, keeping most-recently-seen",
            len(collisions),
        )
        for row in collisions:
            lower_key = row[0]
            # LIMIT -1 OFFSET 1 skips the first (most recently seen) row and
            # deletes every other member of the colliding group.
            await conn.execute(
                """DELETE FROM contacts WHERE public_key IN (
                    SELECT public_key FROM contacts
                    WHERE lower(public_key) = ?
                    ORDER BY COALESCE(last_seen, 0) DESC
                    LIMIT -1 OFFSET 1
                )""",
                (lower_key,),
            )

    # 4. Rebuild contacts with lowercased keys.
    # Get the actual column names from the table (handles different schema versions).
    cursor = await conn.execute("PRAGMA table_info(contacts)")
    columns_info = await cursor.fetchall()
    all_columns = [col[1] for col in columns_info]  # col[1] is column name
    # Build column lists, lowering public_key.
    select_cols = ", ".join(f"lower({c})" if c == "public_key" else c for c in all_columns)
    col_defs = []
    for col in columns_info:
        name, col_type, notnull, default, pk = col[1], col[2], col[3], col[4], col[5]
        parts = [name, col_type or "TEXT"]
        if pk:
            parts.append("PRIMARY KEY")
        if notnull and not pk:
            # Preserve NOT NULL constraints reported by PRAGMA table_info so the
            # rebuilt table matches the original schema (previously dropped).
            parts.append("NOT NULL")
        if default is not None:
            parts.append(f"DEFAULT {default}")
        col_defs.append(" ".join(parts))
    create_sql = f"CREATE TABLE contacts_new ({', '.join(col_defs)})"
    await conn.execute(create_sql)
    await conn.execute(f"INSERT INTO contacts_new SELECT {select_cols} FROM contacts")
    await conn.execute("DROP TABLE contacts")
    await conn.execute("ALTER TABLE contacts_new RENAME TO contacts")
    # Recreate the on_radio index (if column exists).
    # NOTE(review): any other indexes on contacts are dropped by the table swap.
    if "on_radio" in all_columns:
        await conn.execute("CREATE INDEX IF NOT EXISTS idx_contacts_on_radio ON contacts(on_radio)")

    # 5. Lowercase contact IDs in favorites JSON (if app_settings exists).
    cursor = await conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='app_settings'"
    )
    if not await cursor.fetchone():
        await conn.commit()
        logger.info("Lowercased all contact public keys (no app_settings table)")
        return
    cursor = await conn.execute("SELECT favorites FROM app_settings WHERE id = 1")
    row = await cursor.fetchone()
    if row and row[0]:
        try:
            favorites = json.loads(row[0])
            updated = False
            for fav in favorites:
                if fav.get("type") == "contact" and fav.get("id"):
                    new_id = fav["id"].lower()
                    if new_id != fav["id"]:
                        fav["id"] = new_id
                        updated = True
            if updated:
                await conn.execute(
                    "UPDATE app_settings SET favorites = ? WHERE id = 1",
                    (json.dumps(favorites),),
                )
                logger.debug("Lowercased contact IDs in favorites")
        except (json.JSONDecodeError, TypeError):
            # Malformed favorites JSON is left untouched rather than failing
            # the whole migration.
            pass

    # 6. Lowercase contact keys in last_message_times JSON.
    cursor = await conn.execute("SELECT last_message_times FROM app_settings WHERE id = 1")
    row = await cursor.fetchone()
    if row and row[0]:
        try:
            times = json.loads(row[0])
            new_times = {}
            updated = False
            for key, val in times.items():
                if key.startswith("contact-"):
                    # Strip the 8-char "contact-" prefix, lowercase the key part.
                    new_key = "contact-" + key[8:].lower()
                    if new_key != key:
                        updated = True
                    new_times[new_key] = val
                else:
                    new_times[key] = val
            if updated:
                await conn.execute(
                    "UPDATE app_settings SET last_message_times = ? WHERE id = 1",
                    (json.dumps(new_times),),
                )
                logger.debug("Lowercased contact keys in last_message_times")
        except (json.JSONDecodeError, TypeError):
            pass

    await conn.commit()
    logger.info("Lowercased all contact public keys")
async def _migrate_015_fix_null_sender_timestamp(conn: aiosqlite.Connection) -> None:
    """
    Fix NULL sender_timestamp values and add null-safe dedup index.

    1. Set sender_timestamp = received_at for any messages with NULL sender_timestamp
    2. Create a null-safe unique index as belt-and-suspenders protection
    """
    # Check if messages table exists (fresh installs may not have it yet).
    cursor = await conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='messages'"
    )
    if not await cursor.fetchone():
        logger.debug("messages table does not exist yet, skipping NULL sender_timestamp fix")
        await conn.commit()
        return

    # Backfill NULL sender_timestamps with received_at
    cursor = await conn.execute(
        "UPDATE messages SET sender_timestamp = received_at WHERE sender_timestamp IS NULL"
    )
    if cursor.rowcount > 0:
        logger.info("Backfilled %d messages with NULL sender_timestamp", cursor.rowcount)

    # Try to create null-safe dedup index (may fail if existing duplicates exist).
    # COALESCE(sender_timestamp, 0) makes NULL timestamps compare equal to each
    # other, so otherwise-identical rows with NULL timestamps count as duplicates.
    try:
        await conn.execute(
            """CREATE UNIQUE INDEX IF NOT EXISTS idx_messages_dedup_null_safe
            ON messages(type, conversation_key, text, COALESCE(sender_timestamp, 0))"""
        )
        logger.debug("Created null-safe dedup index")
    except aiosqlite.IntegrityError:
        # Pre-existing duplicate rows block the unique index; fall back to the
        # application-level dedup instead of failing the migration.
        logger.warning(
            "Could not create null-safe dedup index due to existing duplicates - "
            "the application-level dedup will handle these"
        )
    await conn.commit()

View File

@@ -636,7 +636,7 @@ async def _process_advertisement(
new_path_hex = packet_info.path.hex() if packet_info.path else ""
# Try to find existing contact
existing = await ContactRepository.get_by_key(advert.public_key)
existing = await ContactRepository.get_by_key(advert.public_key.lower())
# Determine which path to use: keep shorter path if heard recently (within 60s)
# This handles advertisement echoes through different routes
@@ -683,7 +683,7 @@ async def _process_advertisement(
)
contact_data = {
"public_key": advert.public_key,
"public_key": advert.public_key.lower(),
"name": advert.name,
"type": contact_type,
"lat": advert.lat,
@@ -700,7 +700,7 @@ async def _process_advertisement(
broadcast_event(
"contact",
{
"public_key": advert.public_key,
"public_key": advert.public_key.lower(),
"name": advert.name,
"type": contact_type,
"flags": existing.flags if existing else 0,
@@ -721,7 +721,7 @@ async def _process_advertisement(
settings = await AppSettingsRepository.get()
if settings.auto_decrypt_dm_on_advert:
await start_historical_dm_decryption(None, advert.public_key, advert.name)
await start_historical_dm_decryption(None, advert.public_key.lower(), advert.name)
# If this is not a repeater, trigger recent contacts sync to radio
# This ensures we can auto-ACK DMs from recent contacts
@@ -793,9 +793,8 @@ async def _process_direct_message(
# For outgoing: match dest_hash (recipient's first byte)
match_hash = dest_hash if is_outgoing else src_hash
# Get all contacts and filter by first byte of public key
contacts = await ContactRepository.get_all(limit=1000)
candidate_contacts = [c for c in contacts if c.public_key.lower().startswith(match_hash)]
# Get contacts matching the first byte of public key via targeted SQL query
candidate_contacts = await ContactRepository.get_by_pubkey_first_byte(match_hash)
if not candidate_contacts:
logger.debug(

View File

@@ -42,7 +42,7 @@ class ContactRepository:
last_contacted = COALESCE(excluded.last_contacted, contacts.last_contacted)
""",
(
contact.get("public_key"),
contact.get("public_key", "").lower(),
contact.get("name") or contact.get("adv_name"),
contact.get("type", 0),
contact.get("flags", 0),
@@ -81,7 +81,9 @@ class ContactRepository:
@staticmethod
async def get_by_key(public_key: str) -> Contact | None:
cursor = await db.conn.execute("SELECT * FROM contacts WHERE public_key = ?", (public_key,))
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE public_key = ?", (public_key.lower(),)
)
row = await cursor.fetchone()
return ContactRepository._row_to_contact(row) if row else None
@@ -89,7 +91,7 @@ class ContactRepository:
async def get_by_key_prefix(prefix: str) -> Contact | None:
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE public_key LIKE ? LIMIT 1",
(f"{prefix}%",),
(f"{prefix.lower()}%",),
)
row = await cursor.fetchone()
return ContactRepository._row_to_contact(row) if row else None
@@ -137,7 +139,7 @@ class ContactRepository:
async def update_path(public_key: str, path: str, path_len: int) -> None:
await db.conn.execute(
"UPDATE contacts SET last_path = ?, last_path_len = ?, last_seen = ? WHERE public_key = ?",
(path, path_len, int(time.time()), public_key),
(path, path_len, int(time.time()), public_key.lower()),
)
await db.conn.commit()
@@ -145,7 +147,7 @@ class ContactRepository:
async def set_on_radio(public_key: str, on_radio: bool) -> None:
await db.conn.execute(
"UPDATE contacts SET on_radio = ? WHERE public_key = ?",
(on_radio, public_key),
(on_radio, public_key.lower()),
)
await db.conn.commit()
@@ -153,7 +155,7 @@ class ContactRepository:
async def delete(public_key: str) -> None:
await db.conn.execute(
"DELETE FROM contacts WHERE public_key = ?",
(public_key,),
(public_key.lower(),),
)
await db.conn.commit()
@@ -163,7 +165,7 @@ class ContactRepository:
ts = timestamp or int(time.time())
await db.conn.execute(
"UPDATE contacts SET last_contacted = ?, last_seen = ? WHERE public_key = ?",
(ts, ts, public_key),
(ts, ts, public_key.lower()),
)
await db.conn.commit()
@@ -176,11 +178,21 @@ class ContactRepository:
ts = timestamp or int(time.time())
cursor = await db.conn.execute(
"UPDATE contacts SET last_read_at = ? WHERE public_key = ?",
(ts, public_key),
(ts, public_key.lower()),
)
await db.conn.commit()
return cursor.rowcount > 0
@staticmethod
async def get_by_pubkey_first_byte(hex_byte: str) -> list[Contact]:
"""Get contacts whose public key starts with the given hex byte (2 chars)."""
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE substr(public_key, 1, 2) = ?",
(hex_byte.lower(),),
)
rows = await cursor.fetchall()
return [ContactRepository._row_to_contact(row) for row in rows]
class ChannelRepository:
@staticmethod
@@ -357,12 +369,31 @@ class MessageRepository:
return [MessagePath(**p) for p in existing_paths]
@staticmethod
async def claim_prefix_messages(full_key: str) -> int:
"""Promote prefix-stored messages to the full conversation key.
When a full key becomes known for a contact, any messages stored with
only a prefix as conversation_key are updated to use the full key.
"""
lower_key = full_key.lower()
cursor = await db.conn.execute(
"""UPDATE messages SET conversation_key = ?
WHERE type = 'PRIV' AND length(conversation_key) < 64
AND ? LIKE conversation_key || '%'""",
(lower_key, lower_key),
)
await db.conn.commit()
return cursor.rowcount
@staticmethod
async def get_all(
limit: int = 100,
offset: int = 0,
msg_type: str | None = None,
conversation_key: str | None = None,
before: int | None = None,
before_id: int | None = None,
) -> list[Message]:
query = "SELECT * FROM messages WHERE 1=1"
params: list[Any] = []
@@ -375,8 +406,15 @@ class MessageRepository:
query += " AND conversation_key LIKE ?"
params.append(f"{conversation_key}%")
query += " ORDER BY received_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
if before is not None and before_id is not None:
query += " AND (received_at < ? OR (received_at = ? AND id < ?))"
params.extend([before, before, before_id])
query += " ORDER BY received_at DESC, id DESC LIMIT ?"
params.append(limit)
if before is None or before_id is None:
query += " OFFSET ?"
params.append(offset)
cursor = await db.conn.execute(query, params)
rows = await cursor.fetchall()

View File

@@ -19,7 +19,7 @@ from app.models import (
from app.packet_processor import start_historical_dm_decryption
from app.radio import radio_manager
from app.radio_sync import pause_polling
from app.repository import ContactRepository
from app.repository import ContactRepository, MessageRepository
logger = logging.getLogger(__name__)
@@ -119,8 +119,9 @@ async def create_contact(
return existing
# Create new contact
lower_key = request.public_key.lower()
contact_data = {
"public_key": request.public_key,
"public_key": lower_key,
"name": request.name,
"type": 0, # Unknown
"flags": 0,
@@ -134,11 +135,16 @@ async def create_contact(
"last_contacted": None,
}
await ContactRepository.upsert(contact_data)
logger.info("Created contact %s", request.public_key[:12])
logger.info("Created contact %s", lower_key[:12])
# Promote any prefix-stored messages to this full key
claimed = await MessageRepository.claim_prefix_messages(lower_key)
if claimed > 0:
logger.info("Claimed %d prefix messages for contact %s", claimed, lower_key[:12])
# Trigger historical decryption if requested
if request.try_historical:
await start_historical_dm_decryption(background_tasks, request.public_key, request.name)
await start_historical_dm_decryption(background_tasks, lower_key, request.name)
return Contact(**contact_data)

View File

@@ -22,6 +22,10 @@ async def list_messages(
conversation_key: str | None = Query(
default=None, description="Filter by conversation key (channel key or contact pubkey)"
),
before: int | None = Query(
default=None, description="Cursor: received_at of last seen message"
),
before_id: int | None = Query(default=None, description="Cursor: id of last seen message"),
) -> list[Message]:
"""List messages from the database."""
return await MessageRepository.get_all(
@@ -29,6 +33,8 @@ async def list_messages(
offset=offset,
msg_type=type,
conversation_key=conversation_key,
before=before,
before_id=before_id,
)
@@ -94,7 +100,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
message_id = await MessageRepository.create(
msg_type="PRIV",
text=request.text,
conversation_key=db_contact.public_key,
conversation_key=db_contact.public_key.lower(),
sender_timestamp=now,
received_at=now,
outgoing=True,
@@ -106,7 +112,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
)
# Update last_contacted for the contact
await ContactRepository.update_last_contacted(db_contact.public_key, now)
await ContactRepository.update_last_contacted(db_contact.public_key.lower(), now)
# Track the expected ACK for this message
expected_ack = result.payload.get("expected_ack")
@@ -119,7 +125,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
message = Message(
id=message_id,
type="PRIV",
conversation_key=db_contact.public_key,
conversation_key=db_contact.public_key.lower(),
text=request.text,
sender_timestamp=now,
received_at=now,
@@ -133,7 +139,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
asyncio.create_task(
run_bot_for_message(
sender_name=None,
sender_key=db_contact.public_key,
sender_key=db_contact.public_key.lower(),
message_text=request.text,
is_dm=True,
channel_key=None,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -13,7 +13,7 @@
<link rel="shortcut icon" href="/favicon.ico" />
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
<link rel="manifest" href="/site.webmanifest" />
<script type="module" crossorigin src="/assets/index-DafoZZfC.js"></script>
<script type="module" crossorigin src="/assets/index-CL0FbnNJ.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-DJA5wYVF.css">
</head>
<body>

View File

@@ -138,6 +138,8 @@ export const api = {
offset?: number;
type?: 'PRIV' | 'CHAN';
conversation_key?: string;
before?: number;
before_id?: number;
},
signal?: AbortSignal
) => {
@@ -146,6 +148,8 @@ export const api = {
if (params?.offset) searchParams.set('offset', params.offset.toString());
if (params?.type) searchParams.set('type', params.type);
if (params?.conversation_key) searchParams.set('conversation_key', params.conversation_key);
if (params?.before !== undefined) searchParams.set('before', params.before.toString());
if (params?.before_id !== undefined) searchParams.set('before_id', params.before_id.toString());
const query = searchParams.toString();
return fetchJson<Message[]>(`/messages${query ? `?${query}` : ''}`, { signal });
},

View File

@@ -102,7 +102,7 @@ export function useConversationMessages(
[activeConversation]
);
// Fetch older messages (pagination)
// Fetch older messages (cursor-based pagination)
const fetchOlderMessages = useCallback(async () => {
if (
!activeConversation ||
@@ -112,13 +112,18 @@ export function useConversationMessages(
)
return;
// Get the oldest message as cursor for the next page
const oldestMessage = messages[messages.length - 1];
if (!oldestMessage) return;
setLoadingOlder(true);
try {
const data = await api.getMessages({
type: activeConversation.type === 'channel' ? 'CHAN' : 'PRIV',
conversation_key: activeConversation.id,
limit: MESSAGE_PAGE_SIZE,
offset: messages.length,
before: oldestMessage.received_at,
before_id: oldestMessage.id,
});
if (data.length > 0) {
@@ -139,7 +144,7 @@ export function useConversationMessages(
} finally {
setLoadingOlder(false);
}
}, [activeConversation, loadingOlder, hasOlderMessages, messages.length]);
}, [activeConversation, loadingOlder, hasOlderMessages, messages]);
// Fetch messages when conversation changes, with proper cancellation
useEffect(() => {

View File

@@ -98,6 +98,11 @@ class TestCreateContact:
"app.routers.contacts.ContactRepository.upsert",
new_callable=AsyncMock,
) as mock_upsert,
patch(
"app.routers.contacts.MessageRepository.claim_prefix_messages",
new_callable=AsyncMock,
return_value=0,
),
):
from app.main import app

View File

@@ -0,0 +1,119 @@
"""Tests for public key case normalization."""
import pytest
from app.database import Database
from app.repository import ContactRepository, MessageRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.mark.asyncio
async def test_upsert_stores_lowercase_key(test_db):
await ContactRepository.upsert(
{"public_key": "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2"}
)
contact = await ContactRepository.get_by_key(
"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
)
assert contact is not None
assert contact.public_key == "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
@pytest.mark.asyncio
async def test_get_by_key_case_insensitive(test_db):
await ContactRepository.upsert(
{"public_key": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"}
)
contact = await ContactRepository.get_by_key(
"A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2"
)
assert contact is not None
@pytest.mark.asyncio
async def test_update_last_contacted_case_insensitive(test_db):
key = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
await ContactRepository.upsert({"public_key": key})
await ContactRepository.update_last_contacted(key.upper(), 12345)
contact = await ContactRepository.get_by_key(key)
assert contact is not None
assert contact.last_contacted == 12345
@pytest.mark.asyncio
async def test_get_by_pubkey_first_byte(test_db):
key1 = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
key2 = "a1ffddeeaabb1122334455667788990011223344556677889900aabbccddeeff00"
key3 = "b2b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
for key in [key1, key2, key3]:
await ContactRepository.upsert({"public_key": key})
results = await ContactRepository.get_by_pubkey_first_byte("a1")
assert len(results) == 2
result_keys = {c.public_key for c in results}
assert key1 in result_keys
assert key2 in result_keys
results = await ContactRepository.get_by_pubkey_first_byte("A1")
assert len(results) == 2 # case insensitive
@pytest.mark.asyncio
async def test_null_sender_timestamp_defaults_to_received_at(test_db):
"""Verify that a None/0 sender_timestamp is replaced by received_at."""
msg_id = await MessageRepository.create(
msg_type="PRIV",
text="hello",
conversation_key="abcd1234" * 8,
sender_timestamp=500, # simulates fallback: `payload.get("sender_timestamp") or received_at`
received_at=500,
)
assert msg_id is not None
messages = await MessageRepository.get_all(
msg_type="PRIV", conversation_key="abcd1234" * 8, limit=10
)
assert len(messages) == 1
assert messages[0].sender_timestamp == 500
@pytest.mark.asyncio
async def test_duplicate_with_same_text_and_null_timestamp_rejected(test_db):
"""Two messages with same content and sender_timestamp should be deduped."""
received_at = 600
msg_id1 = await MessageRepository.create(
msg_type="PRIV",
text="hello",
conversation_key="abcd1234" * 8,
sender_timestamp=received_at,
received_at=received_at,
)
assert msg_id1 is not None
msg_id2 = await MessageRepository.create(
msg_type="PRIV",
text="hello",
conversation_key="abcd1234" * 8,
sender_timestamp=received_at,
received_at=received_at,
)
assert msg_id2 is None # duplicate rejected

View File

@@ -0,0 +1,64 @@
"""Tests for message pagination using cursor parameters."""
import pytest
from app.database import Database
from app.repository import MessageRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.mark.asyncio
async def test_cursor_pagination_avoids_overlap(test_db):
key = "ABC123DEF456ABC123DEF456ABC12345"
ids = []
for received_at, text in [(200, "m1"), (200, "m2"), (150, "m3"), (100, "m4")]:
msg_id = await MessageRepository.create(
msg_type="CHAN",
text=text,
conversation_key=key,
sender_timestamp=received_at,
received_at=received_at,
)
assert msg_id is not None
ids.append(msg_id)
page1 = await MessageRepository.get_all(
msg_type="CHAN",
conversation_key=key,
limit=2,
offset=0,
)
assert len(page1) == 2
oldest = page1[-1]
page2 = await MessageRepository.get_all(
msg_type="CHAN",
conversation_key=key,
limit=2,
offset=0,
before=oldest.received_at,
before_id=oldest.id,
)
assert len(page2) == 2
ids_page1 = {m.id for m in page1}
ids_page2 = {m.id for m in page2}
assert ids_page1.isdisjoint(ids_page2)

View File

@@ -0,0 +1,50 @@
"""Tests for prefix-claiming DM messages."""
import pytest
from app.database import Database
from app.repository import MessageRepository
@pytest.fixture
async def test_db():
"""Create an in-memory test database."""
import app.repository as repo_module
db = Database(":memory:")
await db.connect()
original_db = repo_module.db
repo_module.db = db
try:
yield db
finally:
repo_module.db = original_db
await db.disconnect()
@pytest.mark.asyncio
async def test_claim_prefix_promotes_dm_to_full_key(test_db):
full_key = "a1b2c3d3ba9f5fa8705b9845fe11cc6f01d1d49caaf4d122ac7121663c5beec7"
prefix = full_key[:6].upper()
msg_id = await MessageRepository.create(
msg_type="PRIV",
text="hello",
conversation_key=prefix,
sender_timestamp=123,
received_at=123,
)
assert msg_id is not None
updated = await MessageRepository.claim_prefix_messages(full_key)
assert updated == 1
messages = await MessageRepository.get_all(
msg_type="PRIV",
conversation_key=full_key,
limit=10,
)
assert len(messages) == 1
assert messages[0].conversation_key == full_key.lower()

View File

@@ -100,8 +100,8 @@ class TestMigration001:
# Run migrations
applied = await run_migrations(conn)
assert applied == 13 # All 13 migrations run
assert await get_version(conn) == 13
assert applied == 15 # All 15 migrations run
assert await get_version(conn) == 15
# Verify columns exist by inserting and selecting
await conn.execute(
@@ -183,9 +183,9 @@ class TestMigration001:
applied1 = await run_migrations(conn)
applied2 = await run_migrations(conn)
assert applied1 == 13 # All 13 migrations run
assert applied1 == 15 # All 15 migrations run
assert applied2 == 0 # No migrations on second run
assert await get_version(conn) == 13
assert await get_version(conn) == 15
finally:
await conn.close()
@@ -245,9 +245,9 @@ class TestMigration001:
# Run migrations - should not fail
applied = await run_migrations(conn)
# All 13 migrations applied (version incremented) but no error
assert applied == 13
assert await get_version(conn) == 13
# All 15 migrations applied (version incremented) but no error
assert applied == 15
assert await get_version(conn) == 15
finally:
await conn.close()
@@ -374,10 +374,10 @@ class TestMigration013:
)
await conn.commit()
# Run migration 13
# Run migration 13 (plus 14+15 which also run)
applied = await run_migrations(conn)
assert applied == 1
assert await get_version(conn) == 13
assert applied == 3
assert await get_version(conn) == 15
# Verify bots array was created with migrated data
cursor = await conn.execute("SELECT bots FROM app_settings WHERE id = 1")