mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-03-28 17:43:05 +01:00
412 lines
15 KiB
Python
412 lines
15 KiB
Python
import json
|
|
import time
|
|
from typing import Any
|
|
|
|
from app.database import db
|
|
from app.models import Message, MessagePath
|
|
|
|
|
|
class MessageRepository:
|
|
@staticmethod
|
|
def _parse_paths(paths_json: str | None) -> list[MessagePath] | None:
|
|
"""Parse paths JSON string to list of MessagePath objects."""
|
|
if not paths_json:
|
|
return None
|
|
try:
|
|
paths_data = json.loads(paths_json)
|
|
return [MessagePath(**p) for p in paths_data]
|
|
except (json.JSONDecodeError, TypeError, KeyError):
|
|
return None
|
|
|
|
@staticmethod
|
|
async def create(
|
|
msg_type: str,
|
|
text: str,
|
|
received_at: int,
|
|
conversation_key: str,
|
|
sender_timestamp: int | None = None,
|
|
path: str | None = None,
|
|
txt_type: int = 0,
|
|
signature: str | None = None,
|
|
outgoing: bool = False,
|
|
sender_name: str | None = None,
|
|
sender_key: str | None = None,
|
|
) -> int | None:
|
|
"""Create a message, returning the ID or None if duplicate.
|
|
|
|
Uses INSERT OR IGNORE to handle the UNIQUE constraint on
|
|
(type, conversation_key, text, sender_timestamp). This prevents
|
|
duplicate messages when the same message arrives via multiple RF paths.
|
|
|
|
The path parameter is converted to the paths JSON array format.
|
|
"""
|
|
# Convert single path to paths array format
|
|
paths_json = None
|
|
if path is not None:
|
|
paths_json = json.dumps([{"path": path, "received_at": received_at}])
|
|
|
|
cursor = await db.conn.execute(
|
|
"""
|
|
INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp,
|
|
received_at, paths, txt_type, signature, outgoing,
|
|
sender_name, sender_key)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
""",
|
|
(
|
|
msg_type,
|
|
conversation_key,
|
|
text,
|
|
sender_timestamp,
|
|
received_at,
|
|
paths_json,
|
|
txt_type,
|
|
signature,
|
|
outgoing,
|
|
sender_name,
|
|
sender_key,
|
|
),
|
|
)
|
|
await db.conn.commit()
|
|
# rowcount is 0 if INSERT was ignored due to UNIQUE constraint violation
|
|
if cursor.rowcount == 0:
|
|
return None
|
|
return cursor.lastrowid
|
|
|
|
@staticmethod
|
|
async def add_path(
|
|
message_id: int, path: str, received_at: int | None = None
|
|
) -> list[MessagePath]:
|
|
"""Add a new path to an existing message.
|
|
|
|
This is used when a repeat/echo of a message arrives via a different route.
|
|
Returns the updated list of paths.
|
|
"""
|
|
ts = received_at if received_at is not None else int(time.time())
|
|
|
|
# Atomic append: use json_insert to avoid read-modify-write race when
|
|
# multiple duplicate packets arrive concurrently for the same message.
|
|
new_entry = json.dumps({"path": path, "received_at": ts})
|
|
await db.conn.execute(
|
|
"""UPDATE messages SET paths = json_insert(
|
|
COALESCE(paths, '[]'), '$[#]', json(?)
|
|
) WHERE id = ?""",
|
|
(new_entry, message_id),
|
|
)
|
|
await db.conn.commit()
|
|
|
|
# Read back the full list for the return value
|
|
cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
|
|
row = await cursor.fetchone()
|
|
if not row or not row["paths"]:
|
|
return []
|
|
|
|
try:
|
|
all_paths = json.loads(row["paths"])
|
|
except json.JSONDecodeError:
|
|
return []
|
|
|
|
return [MessagePath(**p) for p in all_paths]
|
|
|
|
@staticmethod
|
|
async def claim_prefix_messages(full_key: str) -> int:
|
|
"""Promote prefix-stored messages to the full conversation key.
|
|
|
|
When a full key becomes known for a contact, any messages stored with
|
|
only a prefix as conversation_key are updated to use the full key.
|
|
"""
|
|
lower_key = full_key.lower()
|
|
cursor = await db.conn.execute(
|
|
"""UPDATE messages SET conversation_key = ?
|
|
WHERE type = 'PRIV' AND length(conversation_key) < 64
|
|
AND ? LIKE conversation_key || '%'
|
|
AND (
|
|
SELECT COUNT(*) FROM contacts
|
|
WHERE public_key LIKE messages.conversation_key || '%'
|
|
) = 1""",
|
|
(lower_key, lower_key),
|
|
)
|
|
await db.conn.commit()
|
|
return cursor.rowcount
|
|
|
|
@staticmethod
|
|
async def get_all(
|
|
limit: int = 100,
|
|
offset: int = 0,
|
|
msg_type: str | None = None,
|
|
conversation_key: str | None = None,
|
|
before: int | None = None,
|
|
before_id: int | None = None,
|
|
) -> list[Message]:
|
|
query = "SELECT * FROM messages WHERE 1=1"
|
|
params: list[Any] = []
|
|
|
|
if msg_type:
|
|
query += " AND type = ?"
|
|
params.append(msg_type)
|
|
if conversation_key:
|
|
normalized_key = conversation_key
|
|
# Prefer exact matching for full keys.
|
|
if len(conversation_key) == 64:
|
|
normalized_key = conversation_key.lower()
|
|
query += " AND conversation_key = ?"
|
|
params.append(normalized_key)
|
|
elif len(conversation_key) == 32:
|
|
normalized_key = conversation_key.upper()
|
|
query += " AND conversation_key = ?"
|
|
params.append(normalized_key)
|
|
else:
|
|
# Prefix match is only for legacy/partial key callers.
|
|
query += " AND conversation_key LIKE ?"
|
|
params.append(f"{conversation_key}%")
|
|
|
|
if before is not None and before_id is not None:
|
|
query += " AND (received_at < ? OR (received_at = ? AND id < ?))"
|
|
params.extend([before, before, before_id])
|
|
|
|
query += " ORDER BY received_at DESC, id DESC LIMIT ?"
|
|
params.append(limit)
|
|
if before is None or before_id is None:
|
|
query += " OFFSET ?"
|
|
params.append(offset)
|
|
|
|
cursor = await db.conn.execute(query, params)
|
|
rows = await cursor.fetchall()
|
|
return [
|
|
Message(
|
|
id=row["id"],
|
|
type=row["type"],
|
|
conversation_key=row["conversation_key"],
|
|
text=row["text"],
|
|
sender_timestamp=row["sender_timestamp"],
|
|
received_at=row["received_at"],
|
|
paths=MessageRepository._parse_paths(row["paths"]),
|
|
txt_type=row["txt_type"],
|
|
signature=row["signature"],
|
|
outgoing=bool(row["outgoing"]),
|
|
acked=row["acked"],
|
|
)
|
|
for row in rows
|
|
]
|
|
|
|
@staticmethod
|
|
async def increment_ack_count(message_id: int) -> int:
|
|
"""Increment ack count and return the new value."""
|
|
await db.conn.execute("UPDATE messages SET acked = acked + 1 WHERE id = ?", (message_id,))
|
|
await db.conn.commit()
|
|
cursor = await db.conn.execute("SELECT acked FROM messages WHERE id = ?", (message_id,))
|
|
row = await cursor.fetchone()
|
|
return row["acked"] if row else 1
|
|
|
|
@staticmethod
|
|
async def get_ack_and_paths(message_id: int) -> tuple[int, list[MessagePath] | None]:
|
|
"""Get the current ack count and paths for a message."""
|
|
cursor = await db.conn.execute(
|
|
"SELECT acked, paths FROM messages WHERE id = ?", (message_id,)
|
|
)
|
|
row = await cursor.fetchone()
|
|
if not row:
|
|
return 0, None
|
|
return row["acked"], MessageRepository._parse_paths(row["paths"])
|
|
|
|
@staticmethod
|
|
async def get_by_id(message_id: int) -> "Message | None":
|
|
"""Look up a message by its ID."""
|
|
cursor = await db.conn.execute(
|
|
"""
|
|
SELECT id, type, conversation_key, text, sender_timestamp, received_at,
|
|
paths, txt_type, signature, outgoing, acked
|
|
FROM messages
|
|
WHERE id = ?
|
|
""",
|
|
(message_id,),
|
|
)
|
|
row = await cursor.fetchone()
|
|
if not row:
|
|
return None
|
|
|
|
return Message(
|
|
id=row["id"],
|
|
type=row["type"],
|
|
conversation_key=row["conversation_key"],
|
|
text=row["text"],
|
|
sender_timestamp=row["sender_timestamp"],
|
|
received_at=row["received_at"],
|
|
paths=MessageRepository._parse_paths(row["paths"]),
|
|
txt_type=row["txt_type"],
|
|
signature=row["signature"],
|
|
outgoing=bool(row["outgoing"]),
|
|
acked=row["acked"],
|
|
)
|
|
|
|
@staticmethod
|
|
async def get_by_content(
|
|
msg_type: str,
|
|
conversation_key: str,
|
|
text: str,
|
|
sender_timestamp: int | None,
|
|
) -> "Message | None":
|
|
"""Look up a message by its unique content fields."""
|
|
cursor = await db.conn.execute(
|
|
"""
|
|
SELECT id, type, conversation_key, text, sender_timestamp, received_at,
|
|
paths, txt_type, signature, outgoing, acked
|
|
FROM messages
|
|
WHERE type = ? AND conversation_key = ? AND text = ?
|
|
AND (sender_timestamp = ? OR (sender_timestamp IS NULL AND ? IS NULL))
|
|
""",
|
|
(msg_type, conversation_key, text, sender_timestamp, sender_timestamp),
|
|
)
|
|
row = await cursor.fetchone()
|
|
if not row:
|
|
return None
|
|
|
|
paths = None
|
|
if row["paths"]:
|
|
try:
|
|
paths_data = json.loads(row["paths"])
|
|
paths = [
|
|
MessagePath(path=p["path"], received_at=p["received_at"]) for p in paths_data
|
|
]
|
|
except (json.JSONDecodeError, KeyError):
|
|
pass
|
|
|
|
return Message(
|
|
id=row["id"],
|
|
type=row["type"],
|
|
conversation_key=row["conversation_key"],
|
|
text=row["text"],
|
|
sender_timestamp=row["sender_timestamp"],
|
|
received_at=row["received_at"],
|
|
paths=paths,
|
|
txt_type=row["txt_type"],
|
|
signature=row["signature"],
|
|
outgoing=bool(row["outgoing"]),
|
|
acked=row["acked"],
|
|
)
|
|
|
|
@staticmethod
|
|
async def get_unread_counts(name: str | None = None) -> dict:
|
|
"""Get unread message counts, mention flags, and last message times for all conversations.
|
|
|
|
Args:
|
|
name: User's display name for @[name] mention detection. If None, mentions are skipped.
|
|
|
|
Returns:
|
|
Dict with 'counts', 'mentions', and 'last_message_times' keys.
|
|
"""
|
|
counts: dict[str, int] = {}
|
|
mention_flags: dict[str, bool] = {}
|
|
last_message_times: dict[str, int] = {}
|
|
|
|
mention_token = f"@[{name}]" if name else None
|
|
|
|
# Channel unreads
|
|
cursor = await db.conn.execute(
|
|
"""
|
|
SELECT m.conversation_key,
|
|
COUNT(*) as unread_count,
|
|
SUM(CASE
|
|
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
|
|
ELSE 0
|
|
END) > 0 as has_mention
|
|
FROM messages m
|
|
JOIN channels c ON m.conversation_key = c.key
|
|
WHERE m.type = 'CHAN' AND m.outgoing = 0
|
|
AND m.received_at > COALESCE(c.last_read_at, 0)
|
|
GROUP BY m.conversation_key
|
|
""",
|
|
(mention_token or "", mention_token or ""),
|
|
)
|
|
rows = await cursor.fetchall()
|
|
for row in rows:
|
|
state_key = f"channel-{row['conversation_key']}"
|
|
counts[state_key] = row["unread_count"]
|
|
if mention_token and row["has_mention"]:
|
|
mention_flags[state_key] = True
|
|
|
|
# Contact unreads
|
|
cursor = await db.conn.execute(
|
|
"""
|
|
SELECT m.conversation_key,
|
|
COUNT(*) as unread_count,
|
|
SUM(CASE
|
|
WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
|
|
ELSE 0
|
|
END) > 0 as has_mention
|
|
FROM messages m
|
|
JOIN contacts ct ON m.conversation_key = ct.public_key
|
|
WHERE m.type = 'PRIV' AND m.outgoing = 0
|
|
AND m.received_at > COALESCE(ct.last_read_at, 0)
|
|
GROUP BY m.conversation_key
|
|
""",
|
|
(mention_token or "", mention_token or ""),
|
|
)
|
|
rows = await cursor.fetchall()
|
|
for row in rows:
|
|
state_key = f"contact-{row['conversation_key']}"
|
|
counts[state_key] = row["unread_count"]
|
|
if mention_token and row["has_mention"]:
|
|
mention_flags[state_key] = True
|
|
|
|
# Last message times for all conversations (including read ones)
|
|
cursor = await db.conn.execute(
|
|
"""
|
|
SELECT type, conversation_key, MAX(received_at) as last_message_time
|
|
FROM messages
|
|
GROUP BY type, conversation_key
|
|
"""
|
|
)
|
|
rows = await cursor.fetchall()
|
|
for row in rows:
|
|
prefix = "channel" if row["type"] == "CHAN" else "contact"
|
|
state_key = f"{prefix}-{row['conversation_key']}"
|
|
last_message_times[state_key] = row["last_message_time"]
|
|
|
|
return {
|
|
"counts": counts,
|
|
"mentions": mention_flags,
|
|
"last_message_times": last_message_times,
|
|
}
|
|
|
|
@staticmethod
|
|
async def count_dm_messages(contact_key: str) -> int:
|
|
"""Count total DM messages for a contact."""
|
|
cursor = await db.conn.execute(
|
|
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?",
|
|
(contact_key.lower(),),
|
|
)
|
|
row = await cursor.fetchone()
|
|
return row["cnt"] if row else 0
|
|
|
|
@staticmethod
|
|
async def count_channel_messages_by_sender(sender_key: str) -> int:
|
|
"""Count channel messages sent by a specific contact."""
|
|
cursor = await db.conn.execute(
|
|
"SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?",
|
|
(sender_key.lower(),),
|
|
)
|
|
row = await cursor.fetchone()
|
|
return row["cnt"] if row else 0
|
|
|
|
@staticmethod
|
|
async def get_most_active_rooms(sender_key: str, limit: int = 5) -> list[tuple[str, str, int]]:
|
|
"""Get channels where a contact has sent the most messages.
|
|
|
|
Returns list of (channel_key, channel_name, message_count) tuples.
|
|
"""
|
|
cursor = await db.conn.execute(
|
|
"""
|
|
SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
|
|
COUNT(*) AS cnt
|
|
FROM messages m
|
|
LEFT JOIN channels c ON m.conversation_key = c.key
|
|
WHERE m.type = 'CHAN' AND m.sender_key = ?
|
|
GROUP BY m.conversation_key
|
|
ORDER BY cnt DESC
|
|
LIMIT ?
|
|
""",
|
|
(sender_key.lower(), limit),
|
|
)
|
|
rows = await cursor.fetchall()
|
|
return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows]
|