mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-06-11 00:44:48 +02:00
Cleanups: Normalize pub keys, prefix message claiming, cursor + null timestamp DB cleanups
This commit is contained in:
@@ -77,6 +77,9 @@ async def on_contact_message(event: "Event") -> None:
|
||||
if contact:
|
||||
sender_pubkey = contact.public_key.lower()
|
||||
|
||||
# Promote any prefix-stored messages to this full key
|
||||
await MessageRepository.claim_prefix_messages(sender_pubkey)
|
||||
|
||||
# Skip messages from repeaters - they only send CLI responses, not chat messages.
|
||||
# CLI responses are handled by the command endpoint and txt_type filter above.
|
||||
if contact.type == CONTACT_TYPE_REPEATER:
|
||||
@@ -92,7 +95,7 @@ async def on_contact_message(event: "Event") -> None:
|
||||
msg_type="PRIV",
|
||||
text=payload.get("text", ""),
|
||||
conversation_key=sender_pubkey,
|
||||
sender_timestamp=payload.get("sender_timestamp"),
|
||||
sender_timestamp=payload.get("sender_timestamp") or received_at,
|
||||
received_at=received_at,
|
||||
path=payload.get("path"),
|
||||
txt_type=txt_type,
|
||||
@@ -132,7 +135,7 @@ async def on_contact_message(event: "Event") -> None:
|
||||
|
||||
# Update contact last_contacted (contact was already fetched above)
|
||||
if contact:
|
||||
await ContactRepository.update_last_contacted(contact.public_key, received_at)
|
||||
await ContactRepository.update_last_contacted(sender_pubkey, received_at)
|
||||
|
||||
# Run bot if enabled
|
||||
from app.bot import run_bot_for_message
|
||||
|
||||
@@ -128,6 +128,20 @@ async def run_migrations(conn: aiosqlite.Connection) -> int:
|
||||
await set_version(conn, 13)
|
||||
applied += 1
|
||||
|
||||
# Migration 14: Lowercase all contact public keys and related data
|
||||
if version < 14:
|
||||
logger.info("Applying migration 14: lowercase all contact public keys")
|
||||
await _migrate_014_lowercase_public_keys(conn)
|
||||
await set_version(conn, 14)
|
||||
applied += 1
|
||||
|
||||
# Migration 15: Fix NULL sender_timestamp and add null-safe dedup index
|
||||
if version < 15:
|
||||
logger.info("Applying migration 15: fix NULL sender_timestamp values")
|
||||
await _migrate_015_fix_null_sender_timestamp(conn)
|
||||
await set_version(conn, 15)
|
||||
applied += 1
|
||||
|
||||
if applied > 0:
|
||||
logger.info(
|
||||
"Applied %d migration(s), schema now at version %d", applied, await get_version(conn)
|
||||
@@ -793,3 +807,189 @@ async def _migrate_013_convert_to_multi_bot(conn: aiosqlite.Connection) -> None:
|
||||
raise
|
||||
|
||||
await conn.commit()
|
||||
|
||||
|
||||
async def _migrate_014_lowercase_public_keys(conn: aiosqlite.Connection) -> None:
|
||||
"""
|
||||
Lowercase all contact public keys and related data for case-insensitive matching.
|
||||
|
||||
Updates:
|
||||
- contacts.public_key (PRIMARY KEY) via temp table swap
|
||||
- messages.conversation_key for PRIV messages
|
||||
- app_settings.favorites (contact IDs)
|
||||
- app_settings.last_message_times (contact- prefixed keys)
|
||||
|
||||
Handles case collisions by keeping the most-recently-seen contact.
|
||||
"""
|
||||
import json
|
||||
|
||||
# 1. Lowercase message conversation keys for private messages
|
||||
try:
|
||||
await conn.execute(
|
||||
"UPDATE messages SET conversation_key = lower(conversation_key) WHERE type = 'PRIV'"
|
||||
)
|
||||
logger.debug("Lowercased PRIV message conversation_keys")
|
||||
except aiosqlite.OperationalError as e:
|
||||
if "no such table" in str(e).lower():
|
||||
logger.debug("messages table does not exist yet, skipping conversation_key lowercase")
|
||||
else:
|
||||
raise
|
||||
|
||||
# 2. Check if contacts table exists before proceeding
|
||||
cursor = await conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='contacts'"
|
||||
)
|
||||
if not await cursor.fetchone():
|
||||
logger.debug("contacts table does not exist yet, skipping key lowercase")
|
||||
await conn.commit()
|
||||
return
|
||||
|
||||
# 3. Handle contacts table - check for case collisions first
|
||||
cursor = await conn.execute(
|
||||
"SELECT lower(public_key) as lk, COUNT(*) as cnt "
|
||||
"FROM contacts GROUP BY lower(public_key) HAVING COUNT(*) > 1"
|
||||
)
|
||||
collisions = list(await cursor.fetchall())
|
||||
|
||||
if collisions:
|
||||
logger.warning(
|
||||
"Found %d case-colliding contact groups, keeping most-recently-seen",
|
||||
len(collisions),
|
||||
)
|
||||
for row in collisions:
|
||||
lower_key = row[0]
|
||||
# Delete all but the most recently seen
|
||||
await conn.execute(
|
||||
"""DELETE FROM contacts WHERE public_key IN (
|
||||
SELECT public_key FROM contacts
|
||||
WHERE lower(public_key) = ?
|
||||
ORDER BY COALESCE(last_seen, 0) DESC
|
||||
LIMIT -1 OFFSET 1
|
||||
)""",
|
||||
(lower_key,),
|
||||
)
|
||||
|
||||
# 3. Rebuild contacts with lowercased keys
|
||||
# Get the actual column names from the table (handles different schema versions)
|
||||
cursor = await conn.execute("PRAGMA table_info(contacts)")
|
||||
columns_info = await cursor.fetchall()
|
||||
all_columns = [col[1] for col in columns_info] # col[1] is column name
|
||||
|
||||
# Build column lists, lowering public_key
|
||||
select_cols = ", ".join(f"lower({c})" if c == "public_key" else c for c in all_columns)
|
||||
col_defs = []
|
||||
for col in columns_info:
|
||||
name, col_type, _notnull, default, pk = col[1], col[2], col[3], col[4], col[5]
|
||||
parts = [name, col_type or "TEXT"]
|
||||
if pk:
|
||||
parts.append("PRIMARY KEY")
|
||||
if default is not None:
|
||||
parts.append(f"DEFAULT {default}")
|
||||
col_defs.append(" ".join(parts))
|
||||
|
||||
create_sql = f"CREATE TABLE contacts_new ({', '.join(col_defs)})"
|
||||
await conn.execute(create_sql)
|
||||
await conn.execute(f"INSERT INTO contacts_new SELECT {select_cols} FROM contacts")
|
||||
await conn.execute("DROP TABLE contacts")
|
||||
await conn.execute("ALTER TABLE contacts_new RENAME TO contacts")
|
||||
|
||||
# Recreate the on_radio index (if column exists)
|
||||
if "on_radio" in all_columns:
|
||||
await conn.execute("CREATE INDEX IF NOT EXISTS idx_contacts_on_radio ON contacts(on_radio)")
|
||||
|
||||
# 4. Lowercase contact IDs in favorites JSON (if app_settings exists)
|
||||
cursor = await conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='app_settings'"
|
||||
)
|
||||
if not await cursor.fetchone():
|
||||
await conn.commit()
|
||||
logger.info("Lowercased all contact public keys (no app_settings table)")
|
||||
return
|
||||
|
||||
cursor = await conn.execute("SELECT favorites FROM app_settings WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
if row and row[0]:
|
||||
try:
|
||||
favorites = json.loads(row[0])
|
||||
updated = False
|
||||
for fav in favorites:
|
||||
if fav.get("type") == "contact" and fav.get("id"):
|
||||
new_id = fav["id"].lower()
|
||||
if new_id != fav["id"]:
|
||||
fav["id"] = new_id
|
||||
updated = True
|
||||
if updated:
|
||||
await conn.execute(
|
||||
"UPDATE app_settings SET favorites = ? WHERE id = 1",
|
||||
(json.dumps(favorites),),
|
||||
)
|
||||
logger.debug("Lowercased contact IDs in favorites")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 5. Lowercase contact keys in last_message_times JSON
|
||||
cursor = await conn.execute("SELECT last_message_times FROM app_settings WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
if row and row[0]:
|
||||
try:
|
||||
times = json.loads(row[0])
|
||||
new_times = {}
|
||||
updated = False
|
||||
for key, val in times.items():
|
||||
if key.startswith("contact-"):
|
||||
new_key = "contact-" + key[8:].lower()
|
||||
if new_key != key:
|
||||
updated = True
|
||||
new_times[new_key] = val
|
||||
else:
|
||||
new_times[key] = val
|
||||
if updated:
|
||||
await conn.execute(
|
||||
"UPDATE app_settings SET last_message_times = ? WHERE id = 1",
|
||||
(json.dumps(new_times),),
|
||||
)
|
||||
logger.debug("Lowercased contact keys in last_message_times")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
await conn.commit()
|
||||
logger.info("Lowercased all contact public keys")
|
||||
|
||||
|
||||
async def _migrate_015_fix_null_sender_timestamp(conn: aiosqlite.Connection) -> None:
|
||||
"""
|
||||
Fix NULL sender_timestamp values and add null-safe dedup index.
|
||||
|
||||
1. Set sender_timestamp = received_at for any messages with NULL sender_timestamp
|
||||
2. Create a null-safe unique index as belt-and-suspenders protection
|
||||
"""
|
||||
# Check if messages table exists
|
||||
cursor = await conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='messages'"
|
||||
)
|
||||
if not await cursor.fetchone():
|
||||
logger.debug("messages table does not exist yet, skipping NULL sender_timestamp fix")
|
||||
await conn.commit()
|
||||
return
|
||||
|
||||
# Backfill NULL sender_timestamps with received_at
|
||||
cursor = await conn.execute(
|
||||
"UPDATE messages SET sender_timestamp = received_at WHERE sender_timestamp IS NULL"
|
||||
)
|
||||
if cursor.rowcount > 0:
|
||||
logger.info("Backfilled %d messages with NULL sender_timestamp", cursor.rowcount)
|
||||
|
||||
# Try to create null-safe dedup index (may fail if existing duplicates exist)
|
||||
try:
|
||||
await conn.execute(
|
||||
"""CREATE UNIQUE INDEX IF NOT EXISTS idx_messages_dedup_null_safe
|
||||
ON messages(type, conversation_key, text, COALESCE(sender_timestamp, 0))"""
|
||||
)
|
||||
logger.debug("Created null-safe dedup index")
|
||||
except aiosqlite.IntegrityError:
|
||||
logger.warning(
|
||||
"Could not create null-safe dedup index due to existing duplicates - "
|
||||
"the application-level dedup will handle these"
|
||||
)
|
||||
|
||||
await conn.commit()
|
||||
|
||||
@@ -636,7 +636,7 @@ async def _process_advertisement(
|
||||
new_path_hex = packet_info.path.hex() if packet_info.path else ""
|
||||
|
||||
# Try to find existing contact
|
||||
existing = await ContactRepository.get_by_key(advert.public_key)
|
||||
existing = await ContactRepository.get_by_key(advert.public_key.lower())
|
||||
|
||||
# Determine which path to use: keep shorter path if heard recently (within 60s)
|
||||
# This handles advertisement echoes through different routes
|
||||
@@ -683,7 +683,7 @@ async def _process_advertisement(
|
||||
)
|
||||
|
||||
contact_data = {
|
||||
"public_key": advert.public_key,
|
||||
"public_key": advert.public_key.lower(),
|
||||
"name": advert.name,
|
||||
"type": contact_type,
|
||||
"lat": advert.lat,
|
||||
@@ -700,7 +700,7 @@ async def _process_advertisement(
|
||||
broadcast_event(
|
||||
"contact",
|
||||
{
|
||||
"public_key": advert.public_key,
|
||||
"public_key": advert.public_key.lower(),
|
||||
"name": advert.name,
|
||||
"type": contact_type,
|
||||
"flags": existing.flags if existing else 0,
|
||||
@@ -721,7 +721,7 @@ async def _process_advertisement(
|
||||
|
||||
settings = await AppSettingsRepository.get()
|
||||
if settings.auto_decrypt_dm_on_advert:
|
||||
await start_historical_dm_decryption(None, advert.public_key, advert.name)
|
||||
await start_historical_dm_decryption(None, advert.public_key.lower(), advert.name)
|
||||
|
||||
# If this is not a repeater, trigger recent contacts sync to radio
|
||||
# This ensures we can auto-ACK DMs from recent contacts
|
||||
@@ -793,9 +793,8 @@ async def _process_direct_message(
|
||||
# For outgoing: match dest_hash (recipient's first byte)
|
||||
match_hash = dest_hash if is_outgoing else src_hash
|
||||
|
||||
# Get all contacts and filter by first byte of public key
|
||||
contacts = await ContactRepository.get_all(limit=1000)
|
||||
candidate_contacts = [c for c in contacts if c.public_key.lower().startswith(match_hash)]
|
||||
# Get contacts matching the first byte of public key via targeted SQL query
|
||||
candidate_contacts = await ContactRepository.get_by_pubkey_first_byte(match_hash)
|
||||
|
||||
if not candidate_contacts:
|
||||
logger.debug(
|
||||
|
||||
+48
-10
@@ -42,7 +42,7 @@ class ContactRepository:
|
||||
last_contacted = COALESCE(excluded.last_contacted, contacts.last_contacted)
|
||||
""",
|
||||
(
|
||||
contact.get("public_key"),
|
||||
contact.get("public_key", "").lower(),
|
||||
contact.get("name") or contact.get("adv_name"),
|
||||
contact.get("type", 0),
|
||||
contact.get("flags", 0),
|
||||
@@ -81,7 +81,9 @@ class ContactRepository:
|
||||
|
||||
@staticmethod
|
||||
async def get_by_key(public_key: str) -> Contact | None:
|
||||
cursor = await db.conn.execute("SELECT * FROM contacts WHERE public_key = ?", (public_key,))
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT * FROM contacts WHERE public_key = ?", (public_key.lower(),)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return ContactRepository._row_to_contact(row) if row else None
|
||||
|
||||
@@ -89,7 +91,7 @@ class ContactRepository:
|
||||
async def get_by_key_prefix(prefix: str) -> Contact | None:
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT * FROM contacts WHERE public_key LIKE ? LIMIT 1",
|
||||
(f"{prefix}%",),
|
||||
(f"{prefix.lower()}%",),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return ContactRepository._row_to_contact(row) if row else None
|
||||
@@ -137,7 +139,7 @@ class ContactRepository:
|
||||
async def update_path(public_key: str, path: str, path_len: int) -> None:
|
||||
await db.conn.execute(
|
||||
"UPDATE contacts SET last_path = ?, last_path_len = ?, last_seen = ? WHERE public_key = ?",
|
||||
(path, path_len, int(time.time()), public_key),
|
||||
(path, path_len, int(time.time()), public_key.lower()),
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
@@ -145,7 +147,7 @@ class ContactRepository:
|
||||
async def set_on_radio(public_key: str, on_radio: bool) -> None:
|
||||
await db.conn.execute(
|
||||
"UPDATE contacts SET on_radio = ? WHERE public_key = ?",
|
||||
(on_radio, public_key),
|
||||
(on_radio, public_key.lower()),
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
@@ -153,7 +155,7 @@ class ContactRepository:
|
||||
async def delete(public_key: str) -> None:
|
||||
await db.conn.execute(
|
||||
"DELETE FROM contacts WHERE public_key = ?",
|
||||
(public_key,),
|
||||
(public_key.lower(),),
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
@@ -163,7 +165,7 @@ class ContactRepository:
|
||||
ts = timestamp or int(time.time())
|
||||
await db.conn.execute(
|
||||
"UPDATE contacts SET last_contacted = ?, last_seen = ? WHERE public_key = ?",
|
||||
(ts, ts, public_key),
|
||||
(ts, ts, public_key.lower()),
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
@@ -176,11 +178,21 @@ class ContactRepository:
|
||||
ts = timestamp or int(time.time())
|
||||
cursor = await db.conn.execute(
|
||||
"UPDATE contacts SET last_read_at = ? WHERE public_key = ?",
|
||||
(ts, public_key),
|
||||
(ts, public_key.lower()),
|
||||
)
|
||||
await db.conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
@staticmethod
|
||||
async def get_by_pubkey_first_byte(hex_byte: str) -> list[Contact]:
|
||||
"""Get contacts whose public key starts with the given hex byte (2 chars)."""
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT * FROM contacts WHERE substr(public_key, 1, 2) = ?",
|
||||
(hex_byte.lower(),),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [ContactRepository._row_to_contact(row) for row in rows]
|
||||
|
||||
|
||||
class ChannelRepository:
|
||||
@staticmethod
|
||||
@@ -357,12 +369,31 @@ class MessageRepository:
|
||||
|
||||
return [MessagePath(**p) for p in existing_paths]
|
||||
|
||||
@staticmethod
|
||||
async def claim_prefix_messages(full_key: str) -> int:
|
||||
"""Promote prefix-stored messages to the full conversation key.
|
||||
|
||||
When a full key becomes known for a contact, any messages stored with
|
||||
only a prefix as conversation_key are updated to use the full key.
|
||||
"""
|
||||
lower_key = full_key.lower()
|
||||
cursor = await db.conn.execute(
|
||||
"""UPDATE messages SET conversation_key = ?
|
||||
WHERE type = 'PRIV' AND length(conversation_key) < 64
|
||||
AND ? LIKE conversation_key || '%'""",
|
||||
(lower_key, lower_key),
|
||||
)
|
||||
await db.conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
@staticmethod
|
||||
async def get_all(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
msg_type: str | None = None,
|
||||
conversation_key: str | None = None,
|
||||
before: int | None = None,
|
||||
before_id: int | None = None,
|
||||
) -> list[Message]:
|
||||
query = "SELECT * FROM messages WHERE 1=1"
|
||||
params: list[Any] = []
|
||||
@@ -375,8 +406,15 @@ class MessageRepository:
|
||||
query += " AND conversation_key LIKE ?"
|
||||
params.append(f"{conversation_key}%")
|
||||
|
||||
query += " ORDER BY received_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
if before is not None and before_id is not None:
|
||||
query += " AND (received_at < ? OR (received_at = ? AND id < ?))"
|
||||
params.extend([before, before, before_id])
|
||||
|
||||
query += " ORDER BY received_at DESC, id DESC LIMIT ?"
|
||||
params.append(limit)
|
||||
if before is None or before_id is None:
|
||||
query += " OFFSET ?"
|
||||
params.append(offset)
|
||||
|
||||
cursor = await db.conn.execute(query, params)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
+10
-4
@@ -19,7 +19,7 @@ from app.models import (
|
||||
from app.packet_processor import start_historical_dm_decryption
|
||||
from app.radio import radio_manager
|
||||
from app.radio_sync import pause_polling
|
||||
from app.repository import ContactRepository
|
||||
from app.repository import ContactRepository, MessageRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -119,8 +119,9 @@ async def create_contact(
|
||||
return existing
|
||||
|
||||
# Create new contact
|
||||
lower_key = request.public_key.lower()
|
||||
contact_data = {
|
||||
"public_key": request.public_key,
|
||||
"public_key": lower_key,
|
||||
"name": request.name,
|
||||
"type": 0, # Unknown
|
||||
"flags": 0,
|
||||
@@ -134,11 +135,16 @@ async def create_contact(
|
||||
"last_contacted": None,
|
||||
}
|
||||
await ContactRepository.upsert(contact_data)
|
||||
logger.info("Created contact %s", request.public_key[:12])
|
||||
logger.info("Created contact %s", lower_key[:12])
|
||||
|
||||
# Promote any prefix-stored messages to this full key
|
||||
claimed = await MessageRepository.claim_prefix_messages(lower_key)
|
||||
if claimed > 0:
|
||||
logger.info("Claimed %d prefix messages for contact %s", claimed, lower_key[:12])
|
||||
|
||||
# Trigger historical decryption if requested
|
||||
if request.try_historical:
|
||||
await start_historical_dm_decryption(background_tasks, request.public_key, request.name)
|
||||
await start_historical_dm_decryption(background_tasks, lower_key, request.name)
|
||||
|
||||
return Contact(**contact_data)
|
||||
|
||||
|
||||
+10
-4
@@ -22,6 +22,10 @@ async def list_messages(
|
||||
conversation_key: str | None = Query(
|
||||
default=None, description="Filter by conversation key (channel key or contact pubkey)"
|
||||
),
|
||||
before: int | None = Query(
|
||||
default=None, description="Cursor: received_at of last seen message"
|
||||
),
|
||||
before_id: int | None = Query(default=None, description="Cursor: id of last seen message"),
|
||||
) -> list[Message]:
|
||||
"""List messages from the database."""
|
||||
return await MessageRepository.get_all(
|
||||
@@ -29,6 +33,8 @@ async def list_messages(
|
||||
offset=offset,
|
||||
msg_type=type,
|
||||
conversation_key=conversation_key,
|
||||
before=before,
|
||||
before_id=before_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -94,7 +100,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
|
||||
message_id = await MessageRepository.create(
|
||||
msg_type="PRIV",
|
||||
text=request.text,
|
||||
conversation_key=db_contact.public_key,
|
||||
conversation_key=db_contact.public_key.lower(),
|
||||
sender_timestamp=now,
|
||||
received_at=now,
|
||||
outgoing=True,
|
||||
@@ -106,7 +112,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
|
||||
)
|
||||
|
||||
# Update last_contacted for the contact
|
||||
await ContactRepository.update_last_contacted(db_contact.public_key, now)
|
||||
await ContactRepository.update_last_contacted(db_contact.public_key.lower(), now)
|
||||
|
||||
# Track the expected ACK for this message
|
||||
expected_ack = result.payload.get("expected_ack")
|
||||
@@ -119,7 +125,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
|
||||
message = Message(
|
||||
id=message_id,
|
||||
type="PRIV",
|
||||
conversation_key=db_contact.public_key,
|
||||
conversation_key=db_contact.public_key.lower(),
|
||||
text=request.text,
|
||||
sender_timestamp=now,
|
||||
received_at=now,
|
||||
@@ -133,7 +139,7 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
|
||||
asyncio.create_task(
|
||||
run_bot_for_message(
|
||||
sender_name=None,
|
||||
sender_key=db_contact.public_key,
|
||||
sender_key=db_contact.public_key.lower(),
|
||||
message_text=request.text,
|
||||
is_dm=True,
|
||||
channel_key=None,
|
||||
|
||||
Vendored
+2
-2
File diff suppressed because one or more lines are too long
+1
-1
File diff suppressed because one or more lines are too long
+19
-19
File diff suppressed because one or more lines are too long
+1
File diff suppressed because one or more lines are too long
-1
File diff suppressed because one or more lines are too long
Vendored
+1
-1
@@ -13,7 +13,7 @@
|
||||
<link rel="shortcut icon" href="/favicon.ico" />
|
||||
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
|
||||
<link rel="manifest" href="/site.webmanifest" />
|
||||
<script type="module" crossorigin src="/assets/index-DafoZZfC.js"></script>
|
||||
<script type="module" crossorigin src="/assets/index-CL0FbnNJ.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-DJA5wYVF.css">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
@@ -138,6 +138,8 @@ export const api = {
|
||||
offset?: number;
|
||||
type?: 'PRIV' | 'CHAN';
|
||||
conversation_key?: string;
|
||||
before?: number;
|
||||
before_id?: number;
|
||||
},
|
||||
signal?: AbortSignal
|
||||
) => {
|
||||
@@ -146,6 +148,8 @@ export const api = {
|
||||
if (params?.offset) searchParams.set('offset', params.offset.toString());
|
||||
if (params?.type) searchParams.set('type', params.type);
|
||||
if (params?.conversation_key) searchParams.set('conversation_key', params.conversation_key);
|
||||
if (params?.before !== undefined) searchParams.set('before', params.before.toString());
|
||||
if (params?.before_id !== undefined) searchParams.set('before_id', params.before_id.toString());
|
||||
const query = searchParams.toString();
|
||||
return fetchJson<Message[]>(`/messages${query ? `?${query}` : ''}`, { signal });
|
||||
},
|
||||
|
||||
@@ -102,7 +102,7 @@ export function useConversationMessages(
|
||||
[activeConversation]
|
||||
);
|
||||
|
||||
// Fetch older messages (pagination)
|
||||
// Fetch older messages (cursor-based pagination)
|
||||
const fetchOlderMessages = useCallback(async () => {
|
||||
if (
|
||||
!activeConversation ||
|
||||
@@ -112,13 +112,18 @@ export function useConversationMessages(
|
||||
)
|
||||
return;
|
||||
|
||||
// Get the oldest message as cursor for the next page
|
||||
const oldestMessage = messages[messages.length - 1];
|
||||
if (!oldestMessage) return;
|
||||
|
||||
setLoadingOlder(true);
|
||||
try {
|
||||
const data = await api.getMessages({
|
||||
type: activeConversation.type === 'channel' ? 'CHAN' : 'PRIV',
|
||||
conversation_key: activeConversation.id,
|
||||
limit: MESSAGE_PAGE_SIZE,
|
||||
offset: messages.length,
|
||||
before: oldestMessage.received_at,
|
||||
before_id: oldestMessage.id,
|
||||
});
|
||||
|
||||
if (data.length > 0) {
|
||||
@@ -139,7 +144,7 @@ export function useConversationMessages(
|
||||
} finally {
|
||||
setLoadingOlder(false);
|
||||
}
|
||||
}, [activeConversation, loadingOlder, hasOlderMessages, messages.length]);
|
||||
}, [activeConversation, loadingOlder, hasOlderMessages, messages]);
|
||||
|
||||
// Fetch messages when conversation changes, with proper cancellation
|
||||
useEffect(() => {
|
||||
|
||||
@@ -98,6 +98,11 @@ class TestCreateContact:
|
||||
"app.routers.contacts.ContactRepository.upsert",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_upsert,
|
||||
patch(
|
||||
"app.routers.contacts.MessageRepository.claim_prefix_messages",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
),
|
||||
):
|
||||
from app.main import app
|
||||
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
"""Tests for public key case normalization."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.database import Database
|
||||
from app.repository import ContactRepository, MessageRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db():
|
||||
"""Create an in-memory test database."""
|
||||
import app.repository as repo_module
|
||||
|
||||
db = Database(":memory:")
|
||||
await db.connect()
|
||||
|
||||
original_db = repo_module.db
|
||||
repo_module.db = db
|
||||
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
repo_module.db = original_db
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_stores_lowercase_key(test_db):
|
||||
await ContactRepository.upsert(
|
||||
{"public_key": "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2"}
|
||||
)
|
||||
contact = await ContactRepository.get_by_key(
|
||||
"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||
)
|
||||
assert contact is not None
|
||||
assert contact.public_key == "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_key_case_insensitive(test_db):
|
||||
await ContactRepository.upsert(
|
||||
{"public_key": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"}
|
||||
)
|
||||
contact = await ContactRepository.get_by_key(
|
||||
"A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2"
|
||||
)
|
||||
assert contact is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_contacted_case_insensitive(test_db):
|
||||
key = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||
await ContactRepository.upsert({"public_key": key})
|
||||
|
||||
await ContactRepository.update_last_contacted(key.upper(), 12345)
|
||||
contact = await ContactRepository.get_by_key(key)
|
||||
assert contact is not None
|
||||
assert contact.last_contacted == 12345
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_pubkey_first_byte(test_db):
|
||||
key1 = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||
key2 = "a1ffddeeaabb1122334455667788990011223344556677889900aabbccddeeff00"
|
||||
key3 = "b2b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||
|
||||
for key in [key1, key2, key3]:
|
||||
await ContactRepository.upsert({"public_key": key})
|
||||
|
||||
results = await ContactRepository.get_by_pubkey_first_byte("a1")
|
||||
assert len(results) == 2
|
||||
result_keys = {c.public_key for c in results}
|
||||
assert key1 in result_keys
|
||||
assert key2 in result_keys
|
||||
|
||||
results = await ContactRepository.get_by_pubkey_first_byte("A1")
|
||||
assert len(results) == 2 # case insensitive
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_sender_timestamp_defaults_to_received_at(test_db):
|
||||
"""Verify that a None/0 sender_timestamp is replaced by received_at."""
|
||||
msg_id = await MessageRepository.create(
|
||||
msg_type="PRIV",
|
||||
text="hello",
|
||||
conversation_key="abcd1234" * 8,
|
||||
sender_timestamp=500, # simulates fallback: `payload.get("sender_timestamp") or received_at`
|
||||
received_at=500,
|
||||
)
|
||||
assert msg_id is not None
|
||||
|
||||
messages = await MessageRepository.get_all(
|
||||
msg_type="PRIV", conversation_key="abcd1234" * 8, limit=10
|
||||
)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].sender_timestamp == 500
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_with_same_text_and_null_timestamp_rejected(test_db):
|
||||
"""Two messages with same content and sender_timestamp should be deduped."""
|
||||
received_at = 600
|
||||
msg_id1 = await MessageRepository.create(
|
||||
msg_type="PRIV",
|
||||
text="hello",
|
||||
conversation_key="abcd1234" * 8,
|
||||
sender_timestamp=received_at,
|
||||
received_at=received_at,
|
||||
)
|
||||
assert msg_id1 is not None
|
||||
|
||||
msg_id2 = await MessageRepository.create(
|
||||
msg_type="PRIV",
|
||||
text="hello",
|
||||
conversation_key="abcd1234" * 8,
|
||||
sender_timestamp=received_at,
|
||||
received_at=received_at,
|
||||
)
|
||||
assert msg_id2 is None # duplicate rejected
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Tests for message pagination using cursor parameters."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.database import Database
|
||||
from app.repository import MessageRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db():
|
||||
"""Create an in-memory test database."""
|
||||
import app.repository as repo_module
|
||||
|
||||
db = Database(":memory:")
|
||||
await db.connect()
|
||||
|
||||
original_db = repo_module.db
|
||||
repo_module.db = db
|
||||
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
repo_module.db = original_db
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cursor_pagination_avoids_overlap(test_db):
|
||||
key = "ABC123DEF456ABC123DEF456ABC12345"
|
||||
|
||||
ids = []
|
||||
for received_at, text in [(200, "m1"), (200, "m2"), (150, "m3"), (100, "m4")]:
|
||||
msg_id = await MessageRepository.create(
|
||||
msg_type="CHAN",
|
||||
text=text,
|
||||
conversation_key=key,
|
||||
sender_timestamp=received_at,
|
||||
received_at=received_at,
|
||||
)
|
||||
assert msg_id is not None
|
||||
ids.append(msg_id)
|
||||
|
||||
page1 = await MessageRepository.get_all(
|
||||
msg_type="CHAN",
|
||||
conversation_key=key,
|
||||
limit=2,
|
||||
offset=0,
|
||||
)
|
||||
assert len(page1) == 2
|
||||
|
||||
oldest = page1[-1]
|
||||
page2 = await MessageRepository.get_all(
|
||||
msg_type="CHAN",
|
||||
conversation_key=key,
|
||||
limit=2,
|
||||
offset=0,
|
||||
before=oldest.received_at,
|
||||
before_id=oldest.id,
|
||||
)
|
||||
assert len(page2) == 2
|
||||
|
||||
ids_page1 = {m.id for m in page1}
|
||||
ids_page2 = {m.id for m in page2}
|
||||
assert ids_page1.isdisjoint(ids_page2)
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for prefix-claiming DM messages."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.database import Database
|
||||
from app.repository import MessageRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db():
|
||||
"""Create an in-memory test database."""
|
||||
import app.repository as repo_module
|
||||
|
||||
db = Database(":memory:")
|
||||
await db.connect()
|
||||
|
||||
original_db = repo_module.db
|
||||
repo_module.db = db
|
||||
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
repo_module.db = original_db
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_prefix_promotes_dm_to_full_key(test_db):
|
||||
full_key = "a1b2c3d3ba9f5fa8705b9845fe11cc6f01d1d49caaf4d122ac7121663c5beec7"
|
||||
prefix = full_key[:6].upper()
|
||||
|
||||
msg_id = await MessageRepository.create(
|
||||
msg_type="PRIV",
|
||||
text="hello",
|
||||
conversation_key=prefix,
|
||||
sender_timestamp=123,
|
||||
received_at=123,
|
||||
)
|
||||
assert msg_id is not None
|
||||
|
||||
updated = await MessageRepository.claim_prefix_messages(full_key)
|
||||
assert updated == 1
|
||||
|
||||
messages = await MessageRepository.get_all(
|
||||
msg_type="PRIV",
|
||||
conversation_key=full_key,
|
||||
limit=10,
|
||||
)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].conversation_key == full_key.lower()
|
||||
+10
-10
@@ -100,8 +100,8 @@ class TestMigration001:
|
||||
# Run migrations
|
||||
applied = await run_migrations(conn)
|
||||
|
||||
assert applied == 13 # All 13 migrations run
|
||||
assert await get_version(conn) == 13
|
||||
assert applied == 15 # All 15 migrations run
|
||||
assert await get_version(conn) == 15
|
||||
|
||||
# Verify columns exist by inserting and selecting
|
||||
await conn.execute(
|
||||
@@ -183,9 +183,9 @@ class TestMigration001:
|
||||
applied1 = await run_migrations(conn)
|
||||
applied2 = await run_migrations(conn)
|
||||
|
||||
assert applied1 == 13 # All 13 migrations run
|
||||
assert applied1 == 15 # All 15 migrations run
|
||||
assert applied2 == 0 # No migrations on second run
|
||||
assert await get_version(conn) == 13
|
||||
assert await get_version(conn) == 15
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@@ -245,9 +245,9 @@ class TestMigration001:
|
||||
# Run migrations - should not fail
|
||||
applied = await run_migrations(conn)
|
||||
|
||||
# All 13 migrations applied (version incremented) but no error
|
||||
assert applied == 13
|
||||
assert await get_version(conn) == 13
|
||||
# All 15 migrations applied (version incremented) but no error
|
||||
assert applied == 15
|
||||
assert await get_version(conn) == 15
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@@ -374,10 +374,10 @@ class TestMigration013:
|
||||
)
|
||||
await conn.commit()
|
||||
|
||||
# Run migration 13
|
||||
# Run migration 13 (plus 14+15 which also run)
|
||||
applied = await run_migrations(conn)
|
||||
assert applied == 1
|
||||
assert await get_version(conn) == 13
|
||||
assert applied == 3
|
||||
assert await get_version(conn) == 15
|
||||
|
||||
# Verify bots array was created with migrated data
|
||||
cursor = await conn.execute("SELECT bots FROM app_settings WHERE id = 1")
|
||||
|
||||
Reference in New Issue
Block a user