import json
import time
from typing import Any

from app.database import db
from app.models import Message, MessagePath


class MessageRepository:
    """Data-access helpers for the ``messages`` table.

    Every method is a static coroutine (except the private parsing helpers)
    that runs SQL through the shared ``db.conn`` connection.
    """

    @staticmethod
    def _parse_paths(paths_json: str | None) -> list[MessagePath] | None:
        """Parse a stored ``paths`` JSON string into MessagePath objects.

        Returns None when the column is empty/NULL, or when its contents
        cannot be turned into MessagePath objects (malformed JSON, non-list
        data, entries that are not mappings, or entries the model rejects).
        """
        if not paths_json:
            return None
        # ValueError covers json.JSONDecodeError (a subclass) and validation
        # errors that subclass ValueError; TypeError covers non-iterable data,
        # non-mapping entries and missing/unexpected constructor arguments.
        try:
            paths_data = json.loads(paths_json)
            return [MessagePath(**p) for p in paths_data]
        except (ValueError, TypeError, KeyError):
            return None

    @staticmethod
    def _row_to_message(row: Any) -> Message:
        """Build a Message from a ``messages`` row.

        The row must be indexable by column name and contain id, type,
        conversation_key, text, sender_timestamp, received_at, paths,
        txt_type, signature, outgoing and acked.
        """
        return Message(
            id=row["id"],
            type=row["type"],
            conversation_key=row["conversation_key"],
            text=row["text"],
            sender_timestamp=row["sender_timestamp"],
            received_at=row["received_at"],
            paths=MessageRepository._parse_paths(row["paths"]),
            txt_type=row["txt_type"],
            signature=row["signature"],
            outgoing=bool(row["outgoing"]),
            acked=row["acked"],
        )

    @staticmethod
    async def create(
        msg_type: str,
        text: str,
        received_at: int,
        conversation_key: str,
        sender_timestamp: int | None = None,
        path: str | None = None,
        txt_type: int = 0,
        signature: str | None = None,
        outgoing: bool = False,
        sender_name: str | None = None,
        sender_key: str | None = None,
    ) -> int | None:
        """Create a message, returning the ID or None if duplicate.

        Uses INSERT OR IGNORE to handle the UNIQUE constraint on
        (type, conversation_key, text, sender_timestamp). This prevents
        duplicate messages when the same message arrives via multiple RF paths.

        The path parameter is converted to the paths JSON array format.
        """
        # Convert single path to paths array format
        paths_json = None
        if path is not None:
            paths_json = json.dumps([{"path": path, "received_at": received_at}])

        cursor = await db.conn.execute(
            """
            INSERT OR IGNORE INTO messages (type, conversation_key, text, sender_timestamp,
                                            received_at, paths, txt_type, signature, outgoing,
                                            sender_name, sender_key)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                msg_type,
                conversation_key,
                text,
                sender_timestamp,
                received_at,
                paths_json,
                txt_type,
                signature,
                outgoing,
                sender_name,
                sender_key,
            ),
        )
        await db.conn.commit()
        # rowcount is 0 if INSERT was ignored due to UNIQUE constraint violation
        if cursor.rowcount == 0:
            return None
        return cursor.lastrowid

    @staticmethod
    async def add_path(
        message_id: int, path: str, received_at: int | None = None
    ) -> list[MessagePath]:
        """Add a new path to an existing message.

        This is used when a repeat/echo of a message arrives via a different route.
        ``received_at`` defaults to the current Unix time in whole seconds.

        Returns the updated list of paths ([] if the message is missing or its
        stored paths cannot be parsed).
        """
        ts = received_at if received_at is not None else int(time.time())

        # Atomic append: use json_insert to avoid read-modify-write race when
        # multiple duplicate packets arrive concurrently for the same message.
        new_entry = json.dumps({"path": path, "received_at": ts})
        await db.conn.execute(
            """UPDATE messages SET paths = json_insert(
                COALESCE(paths, '[]'), '$[#]', json(?)
            ) WHERE id = ?""",
            (new_entry, message_id),
        )
        await db.conn.commit()

        # Read back the full list for the return value
        cursor = await db.conn.execute("SELECT paths FROM messages WHERE id = ?", (message_id,))
        row = await cursor.fetchone()
        if not row:
            return []
        # Shared parser tolerates malformed entries; it returns None for
        # empty/bad data, but this method's contract is always a list.
        return MessageRepository._parse_paths(row["paths"]) or []

    @staticmethod
    async def claim_prefix_messages(full_key: str) -> int:
        """Promote prefix-stored messages to the full conversation key.

        When a full key becomes known for a contact, any messages stored with
        only a prefix as conversation_key are updated to use the full key.
        A prefix is only promoted when exactly one contact's public_key starts
        with it, so ambiguous prefixes are left alone.

        Returns the number of rows updated.
        """
        lower_key = full_key.lower()
        cursor = await db.conn.execute(
            """UPDATE messages SET conversation_key = ?
               WHERE type = 'PRIV' AND length(conversation_key) < 64
               AND ? LIKE conversation_key || '%'
               AND (
                 SELECT COUNT(*) FROM contacts
                 WHERE public_key LIKE messages.conversation_key || '%'
               ) = 1""",
            (lower_key, lower_key),
        )
        await db.conn.commit()
        return cursor.rowcount

    @staticmethod
    async def get_all(
        limit: int = 100,
        offset: int = 0,
        msg_type: str | None = None,
        conversation_key: str | None = None,
        before: int | None = None,
        before_id: int | None = None,
    ) -> list[Message]:
        """List messages newest-first with optional filters.

        Args:
            limit: Maximum number of rows to return.
            offset: Rows to skip; ignored when keyset pagination is used.
            msg_type: Exact ``type`` filter.
            conversation_key: 64-char keys match exactly (lowercased),
                32-char keys match exactly (uppercased), anything else is
                treated as a prefix.
            before, before_id: When both are given, return only rows strictly
                older than (before, before_id) in (received_at, id) order.
        """
        query = "SELECT * FROM messages WHERE 1=1"
        params: list[Any] = []

        if msg_type:
            query += " AND type = ?"
            params.append(msg_type)
        if conversation_key:
            # Prefer exact matching for full keys.
            if len(conversation_key) == 64:
                # Stored lowercase (count_dm_messages also lowercases).
                query += " AND conversation_key = ?"
                params.append(conversation_key.lower())
            elif len(conversation_key) == 32:
                # NOTE(review): presumably channel keys, stored uppercase — confirm.
                query += " AND conversation_key = ?"
                params.append(conversation_key.upper())
            else:
                # Prefix match is only for legacy/partial key callers.
                query += " AND conversation_key LIKE ?"
                params.append(f"{conversation_key}%")

        keyset = before is not None and before_id is not None
        if keyset:
            # Keyset pagination; the id tie-break keeps ordering stable for
            # rows sharing the same received_at.
            query += " AND (received_at < ? OR (received_at = ? AND id < ?))"
            params.extend([before, before, before_id])

        query += " ORDER BY received_at DESC, id DESC LIMIT ?"
        params.append(limit)
        if not keyset:
            query += " OFFSET ?"
            params.append(offset)

        cursor = await db.conn.execute(query, params)
        rows = await cursor.fetchall()
        return [MessageRepository._row_to_message(row) for row in rows]

    @staticmethod
    async def increment_ack_count(message_id: int) -> int:
        """Increment ack count and return the new value.

        Returns 1 if no row with ``message_id`` is found after the update.
        """
        await db.conn.execute("UPDATE messages SET acked = acked + 1 WHERE id = ?", (message_id,))
        await db.conn.commit()
        cursor = await db.conn.execute("SELECT acked FROM messages WHERE id = ?", (message_id,))
        row = await cursor.fetchone()
        return row["acked"] if row else 1

    @staticmethod
    async def get_ack_and_paths(message_id: int) -> tuple[int, list[MessagePath] | None]:
        """Get the current ack count and paths for a message.

        Returns (0, None) when the message does not exist.
        """
        cursor = await db.conn.execute(
            "SELECT acked, paths FROM messages WHERE id = ?", (message_id,)
        )
        row = await cursor.fetchone()
        if not row:
            return 0, None
        return row["acked"], MessageRepository._parse_paths(row["paths"])

    @staticmethod
    async def get_by_id(message_id: int) -> "Message | None":
        """Look up a message by its ID; returns None if it does not exist."""
        cursor = await db.conn.execute(
            """
            SELECT id, type, conversation_key, text, sender_timestamp, received_at,
                   paths, txt_type, signature, outgoing, acked
            FROM messages
            WHERE id = ?
            """,
            (message_id,),
        )
        row = await cursor.fetchone()
        if not row:
            return None
        return MessageRepository._row_to_message(row)

    @staticmethod
    async def get_by_content(
        msg_type: str,
        conversation_key: str,
        text: str,
        sender_timestamp: int | None,
    ) -> "Message | None":
        """Look up a message by its unique content fields.

        A None ``sender_timestamp`` matches rows whose sender_timestamp is
        NULL (plain ``=`` never matches NULL in SQL).
        """
        cursor = await db.conn.execute(
            """
            SELECT id, type, conversation_key, text, sender_timestamp, received_at,
                   paths, txt_type, signature, outgoing, acked
            FROM messages
            WHERE type = ? AND conversation_key = ? AND text = ?
              AND (sender_timestamp = ? OR (sender_timestamp IS NULL AND ? IS NULL))
            """,
            (msg_type, conversation_key, text, sender_timestamp, sender_timestamp),
        )
        row = await cursor.fetchone()
        if not row:
            return None
        return MessageRepository._row_to_message(row)

    @staticmethod
    async def get_unread_counts(name: str | None = None) -> dict[str, Any]:
        """Get unread message counts, mention flags, and last message times for all conversations.

        Only incoming messages (outgoing = 0) received after the conversation's
        last_read_at are counted as unread. Mention matching is
        case-insensitive.

        Args:
            name: User's display name for @[name] mention detection.
                  If None, mentions are skipped.

        Returns:
            Dict with 'counts', 'mentions', and 'last_message_times' keys,
            each keyed by "channel-<key>" or "contact-<key>".
        """
        counts: dict[str, int] = {}
        mention_flags: dict[str, bool] = {}
        last_message_times: dict[str, int] = {}

        mention_token = f"@[{name}]" if name else None

        # Channel unreads
        cursor = await db.conn.execute(
            """
            SELECT m.conversation_key,
                   COUNT(*) as unread_count,
                   SUM(CASE
                       WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
                       ELSE 0
                   END) > 0 as has_mention
            FROM messages m
            JOIN channels c ON m.conversation_key = c.key
            WHERE m.type = 'CHAN' AND m.outgoing = 0
              AND m.received_at > COALESCE(c.last_read_at, 0)
            GROUP BY m.conversation_key
            """,
            (mention_token or "", mention_token or ""),
        )
        rows = await cursor.fetchall()
        for row in rows:
            state_key = f"channel-{row['conversation_key']}"
            counts[state_key] = row["unread_count"]
            if mention_token and row["has_mention"]:
                mention_flags[state_key] = True

        # Contact unreads
        cursor = await db.conn.execute(
            """
            SELECT m.conversation_key,
                   COUNT(*) as unread_count,
                   SUM(CASE
                       WHEN ? <> '' AND INSTR(LOWER(m.text), LOWER(?)) > 0 THEN 1
                       ELSE 0
                   END) > 0 as has_mention
            FROM messages m
            JOIN contacts ct ON m.conversation_key = ct.public_key
            WHERE m.type = 'PRIV' AND m.outgoing = 0
              AND m.received_at > COALESCE(ct.last_read_at, 0)
            GROUP BY m.conversation_key
            """,
            (mention_token or "", mention_token or ""),
        )
        rows = await cursor.fetchall()
        for row in rows:
            state_key = f"contact-{row['conversation_key']}"
            counts[state_key] = row["unread_count"]
            if mention_token and row["has_mention"]:
                mention_flags[state_key] = True

        # Last message times for all conversations (including read ones)
        cursor = await db.conn.execute(
            """
            SELECT type, conversation_key, MAX(received_at) as last_message_time
            FROM messages
            GROUP BY type, conversation_key
            """
        )
        rows = await cursor.fetchall()
        for row in rows:
            # Any type other than 'CHAN' is reported under the contact prefix.
            prefix = "channel" if row["type"] == "CHAN" else "contact"
            state_key = f"{prefix}-{row['conversation_key']}"
            last_message_times[state_key] = row["last_message_time"]

        return {
            "counts": counts,
            "mentions": mention_flags,
            "last_message_times": last_message_times,
        }

    @staticmethod
    async def count_dm_messages(contact_key: str) -> int:
        """Count total DM messages for a contact (key is lowercased first)."""
        cursor = await db.conn.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE type = 'PRIV' AND conversation_key = ?",
            (contact_key.lower(),),
        )
        row = await cursor.fetchone()
        return row["cnt"] if row else 0

    @staticmethod
    async def count_channel_messages_by_sender(sender_key: str) -> int:
        """Count channel messages sent by a specific contact (key is lowercased first)."""
        cursor = await db.conn.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE type = 'CHAN' AND sender_key = ?",
            (sender_key.lower(),),
        )
        row = await cursor.fetchone()
        return row["cnt"] if row else 0

    @staticmethod
    async def get_most_active_rooms(sender_key: str, limit: int = 5) -> list[tuple[str, str, int]]:
        """Get channels where a contact has sent the most messages.

        Returns list of (channel_key, channel_name, message_count) tuples,
        ordered by message_count descending. channel_name falls back to the
        channel key when the channel is not in the channels table.
        """
        cursor = await db.conn.execute(
            """
            SELECT m.conversation_key, COALESCE(c.name, m.conversation_key) AS channel_name,
                   COUNT(*) AS cnt
            FROM messages m
            LEFT JOIN channels c ON m.conversation_key = c.key
            WHERE m.type = 'CHAN' AND m.sender_key = ?
            GROUP BY m.conversation_key
            ORDER BY cnt DESC
            LIMIT ?
            """,
            (sender_key.lower(), limit),
        )
        rows = await cursor.fetchall()
        return [(row["conversation_key"], row["channel_name"], row["cnt"]) for row in rows]