From 1aa26c05d01b74eb25bb1260b7067c83cd2bce85 Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Tue, 10 Feb 2026 22:05:59 -0800 Subject: [PATCH] Don't use prefix matching if we can help it --- app/event_handlers.py | 11 ++++- app/models.py | 4 +- app/packet_processor.py | 2 +- app/radio_sync.py | 10 ++++- app/repository.py | 67 ++++++++++++++++++++++++++----- app/routers/contacts.py | 57 +++++++++++++------------- app/routers/messages.py | 14 ++++++- app/routers/packets.py | 2 +- frontend/src/test/urlHash.test.ts | 5 +-- frontend/src/utils/urlHash.ts | 8 +++- tests/test_contacts_router.py | 30 ++++++++++++-- tests/test_key_normalization.py | 42 ++++++++++++++++++- tests/test_send_messages.py | 30 ++++++++++++++ 13 files changed, 227 insertions(+), 55 deletions(-) diff --git a/app/event_handlers.py b/app/event_handlers.py index dc56b54..6e1ee13 100644 --- a/app/event_handlers.py +++ b/app/event_handlers.py @@ -7,7 +7,7 @@ from meshcore import EventType from app.models import CONTACT_TYPE_REPEATER, Contact from app.packet_processor import process_raw_packet -from app.repository import ContactRepository, MessageRepository +from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository from app.websocket import broadcast_event if TYPE_CHECKING: @@ -74,7 +74,14 @@ async def on_contact_message(event: "Event") -> None: # Look up contact from database - use prefix lookup only if needed # (get_by_key_or_prefix does exact match first, then prefix fallback) - contact = await ContactRepository.get_by_key_or_prefix(sender_pubkey) + try: + contact = await ContactRepository.get_by_key_or_prefix(sender_pubkey) + except AmbiguousPublicKeyPrefixError: + logger.warning( + "DM sender prefix '%s' is ambiguous; storing under prefix until full key is known", + sender_pubkey, + ) + contact = None if contact: sender_pubkey = contact.public_key.lower() diff --git a/app/models.py b/app/models.py index d3b9be9..a3613bb 100644 --- a/app/models.py +++ b/app/models.py @@ -132,7 +132,9 @@ class SendMessageRequest(BaseModel): class SendDirectMessageRequest(SendMessageRequest): - destination: str = Field(description="Public key or prefix of recipient") + destination: str = Field( + description="Recipient public key (64-char hex preferred; prefix must resolve uniquely)" + ) class SendChannelMessageRequest(SendMessageRequest): diff --git a/app/packet_processor.py b/app/packet_processor.py index fec80ed..1755ab4 100644 --- a/app/packet_processor.py +++ b/app/packet_processor.py @@ -238,7 +238,7 @@ async def create_dm_message_from_decrypted( """ # Check if sender is a repeater - repeaters only send CLI responses, not chat messages. # CLI responses are handled by the command endpoint, not stored in chat history. - contact = await ContactRepository.get_by_key_or_prefix(their_public_key) + contact = await ContactRepository.get_by_key(their_public_key) if contact and contact.type == CONTACT_TYPE_REPEATER: logger.debug( "Skipping message from repeater %s (CLI responses not stored): %s", diff --git a/app/radio_sync.py b/app/radio_sync.py index 08e3945..467be8f 100644 --- a/app/radio_sync.py +++ b/app/radio_sync.py @@ -19,6 +19,7 @@ from meshcore import EventType from app.models import Contact from app.radio import RadioOperationBusyError, radio_manager from app.repository import ( + AmbiguousPublicKeyPrefixError, AppSettingsRepository, ChannelRepository, ContactRepository, @@ -585,7 +586,14 @@ async def sync_recent_contacts_to_radio(force: bool = False) -> dict: for favorite in app_settings.favorites: if favorite.type != "contact": continue - contact = await ContactRepository.get_by_key_or_prefix(favorite.id) + try: + contact = await ContactRepository.get_by_key_or_prefix(favorite.id) + except AmbiguousPublicKeyPrefixError: + logger.warning( + "Skipping favorite contact '%s': ambiguous key prefix; use full key", + favorite.id, + ) + continue if not contact: continue key = contact.public_key.lower() diff --git a/app/repository.py b/app/repository.py index 9d7b432..665d047 100644 --- a/app/repository.py +++ b/app/repository.py @@ -20,6 +20,15 @@ from app.models import ( logger = logging.getLogger(__name__) +class AmbiguousPublicKeyPrefixError(ValueError): + """Raised when a public key prefix matches multiple contacts.""" + + def __init__(self, prefix: str, matches: list[str]): + self.prefix = prefix.lower() + self.matches = matches + super().__init__(f"Ambiguous public key prefix '{self.prefix}'") + + class ContactRepository: @staticmethod async def upsert(contact: dict[str, Any]) -> None: @@ -89,12 +98,30 @@ class ContactRepository: @staticmethod async def get_by_key_prefix(prefix: str) -> Contact | None: + """Get a contact by key prefix only if it resolves uniquely. + + Returns None when no contacts match OR when multiple contacts match + the prefix (to avoid silently selecting the wrong contact). + """ + normalized_prefix = prefix.lower() cursor = await db.conn.execute( - "SELECT * FROM contacts WHERE public_key LIKE ? LIMIT 1", - (f"{prefix.lower()}%",), + "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2", + (f"{normalized_prefix}%",), ) - row = await cursor.fetchone() - return ContactRepository._row_to_contact(row) if row else None + rows = list(await cursor.fetchall()) + if len(rows) != 1: + return None + return ContactRepository._row_to_contact(rows[0]) + + @staticmethod + async def _get_prefix_matches(prefix: str, limit: int = 2) -> list[Contact]: + """Get contacts matching a key prefix, up to limit.""" + cursor = await db.conn.execute( + "SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT ?", + (f"{prefix.lower()}%", limit), + ) + rows = list(await cursor.fetchall()) + return [ContactRepository._row_to_contact(row) for row in rows] @staticmethod async def get_by_key_or_prefix(key_or_prefix: str) -> Contact | None: @@ -103,9 +130,18 @@ class ContactRepository: Useful when the input might be a full 64-char public key or a shorter prefix. """ contact = await ContactRepository.get_by_key(key_or_prefix) - if not contact: - contact = await ContactRepository.get_by_key_prefix(key_or_prefix) - return contact + if contact: + return contact + + matches = await ContactRepository._get_prefix_matches(key_or_prefix, limit=2) + if len(matches) == 1: + return matches[0] + if len(matches) > 1: + raise AmbiguousPublicKeyPrefixError( + key_or_prefix, + [m.public_key for m in matches], + ) + return None @staticmethod async def get_all(limit: int = 100, offset: int = 0) -> list[Contact]: @@ -416,9 +452,20 @@ class MessageRepository: query += " AND type = ?" params.append(msg_type) if conversation_key: - # Support both exact match and prefix match for DMs - query += " AND conversation_key LIKE ?" - params.append(f"{conversation_key}%") + normalized_key = conversation_key + # Prefer exact matching for full keys. + if len(conversation_key) == 64: + normalized_key = conversation_key.lower() + query += " AND conversation_key = ?" + params.append(normalized_key) + elif len(conversation_key) == 32: + normalized_key = conversation_key.upper() + query += " AND conversation_key = ?" + params.append(normalized_key) + else: + # Prefix match is only for legacy/partial key callers. + query += " AND conversation_key LIKE ?" + params.append(f"{conversation_key}%") if before is not None and before_id is not None: query += " AND (received_at < ? OR (received_at = ? AND id < ?))" diff --git a/app/routers/contacts.py b/app/routers/contacts.py index 80057dc..20c32a4 100644 --- a/app/routers/contacts.py +++ b/app/routers/contacts.py @@ -20,7 +20,7 @@ from app.models import ( ) from app.packet_processor import start_historical_dm_decryption from app.radio import radio_manager -from app.repository import ContactRepository, MessageRepository +from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository logger = logging.getLogger(__name__) @@ -37,6 +37,26 @@ router = APIRouter(prefix="/contacts", tags=["contacts"]) REPEATER_OP_DELAY_SECONDS = 2.0 +def _ambiguous_contact_detail(err: AmbiguousPublicKeyPrefixError) -> str: + sample = ", ".join(key[:12] for key in err.matches[:2]) + return ( + f"Ambiguous contact key prefix '{err.prefix}'. " + f"Use a full 64-character public key. Matching contacts: {sample}" + ) + + +async def _resolve_contact_or_404( + public_key: str, not_found_detail: str = "Contact not found" +) -> Contact: + try: + contact = await ContactRepository.get_by_key_or_prefix(public_key) + except AmbiguousPublicKeyPrefixError as err: + raise HTTPException(status_code=409, detail=_ambiguous_contact_detail(err)) from err + if not contact: + raise HTTPException(status_code=404, detail=not_found_detail) + return contact + + async def prepare_repeater_connection(mc, contact: Contact, password: str) -> None: """Prepare connection to a repeater by adding to radio and logging in. @@ -89,7 +109,7 @@ async def create_contact( raise HTTPException(status_code=400, detail="Invalid public key: must be valid hex") from e # Check if contact already exists - existing = await ContactRepository.get_by_key_or_prefix(request.public_key) + existing = await ContactRepository.get_by_key(request.public_key) if existing: # Update name if provided if request.name: @@ -153,10 +173,7 @@ async def create_contact( @router.get("/{public_key}", response_model=Contact) async def get_contact(public_key: str) -> Contact: """Get a specific contact by public key or prefix.""" - contact = await ContactRepository.get_by_key_or_prefix(public_key) - if not contact: - raise HTTPException(status_code=404, detail="Contact not found") - return contact + return await _resolve_contact_or_404(public_key) @router.post("/sync") @@ -189,9 +206,7 @@ async def remove_contact_from_radio(public_key: str) -> dict: """Remove a contact from the radio (keeps it in database).""" mc = require_connected() - contact = await ContactRepository.get_by_key_or_prefix(public_key) - if not contact: - raise HTTPException(status_code=404, detail="Contact not found") + contact = await _resolve_contact_or_404(public_key) # Get the contact from radio radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) @@ -216,9 +231,7 @@ async def add_contact_to_radio(public_key: str) -> dict: """Add a contact from the database to the radio.""" mc = require_connected() - contact = await ContactRepository.get_by_key_or_prefix(public_key) - if not contact: - raise HTTPException(status_code=404, detail="Contact not found in database") + contact = await _resolve_contact_or_404(public_key, "Contact not found in database") # Check if already on radio radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12]) @@ -239,9 +252,7 @@ async def add_contact_to_radio(public_key: str) -> dict: @router.post("/{public_key}/mark-read") async def mark_contact_read(public_key: str) -> dict: """Mark a contact conversation as read (update last_read_at timestamp).""" - contact = await ContactRepository.get_by_key_or_prefix(public_key) - if not contact: - raise HTTPException(status_code=404, detail="Contact not found") + contact = await _resolve_contact_or_404(public_key) updated = await ContactRepository.update_last_read_at(contact.public_key) if not updated: @@ -253,9 +264,7 @@ async def mark_contact_read(public_key: str) -> dict: @router.delete("/{public_key}") async def delete_contact(public_key: str) -> dict: """Delete a contact from the database (and radio if present).""" - contact = await ContactRepository.get_by_key_or_prefix(public_key) - if not contact: - raise HTTPException(status_code=404, detail="Contact not found") + contact = await _resolve_contact_or_404(public_key) # Remove from radio if connected and contact is on radio if radio_manager.is_connected and radio_manager.meshcore: @@ -282,9 +291,7 @@ async def request_telemetry(public_key: str, request: TelemetryRequest) -> Telem mc = require_connected() # Get contact from database - contact = await ContactRepository.get_by_key_or_prefix(public_key) - if not contact: - raise HTTPException(status_code=404, detail="Contact not found") + contact = await _resolve_contact_or_404(public_key) # Verify it's a repeater if contact.type != CONTACT_TYPE_REPEATER: @@ -459,9 +466,7 @@ async def send_repeater_command(public_key: str, request: CommandRequest) -> Com mc = require_connected() # Get contact from database - contact = await ContactRepository.get_by_key_or_prefix(public_key) - if not contact: - raise HTTPException(status_code=404, detail="Contact not found") + contact = await _resolve_contact_or_404(public_key) # Verify it's a repeater if contact.type != CONTACT_TYPE_REPEATER: @@ -540,9 +545,7 @@ async def request_trace(public_key: str) -> TraceResponse: """ mc = require_connected() - contact = await ContactRepository.get_by_key_or_prefix(public_key) - if not contact: - raise HTTPException(status_code=404, detail="Contact not found") + contact = await _resolve_contact_or_404(public_key) tag = random.randint(1, 0xFFFFFFFF) # First 2 hex chars of pubkey = 1-byte hash used by the trace protocol diff --git a/app/routers/messages.py b/app/routers/messages.py index 22ffbd5..41b701c 100644 --- a/app/routers/messages.py +++ b/app/routers/messages.py @@ -9,7 +9,7 @@ from app.dependencies import require_connected from app.event_handlers import track_pending_ack from app.models import Message, SendChannelMessageRequest, SendDirectMessageRequest from app.radio import radio_manager -from app.repository import MessageRepository +from app.repository import AmbiguousPublicKeyPrefixError, MessageRepository logger = logging.getLogger(__name__) router = APIRouter(prefix="/messages", tags=["messages"]) @@ -47,7 +47,17 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message: # First check our database for the contact from app.repository import ContactRepository - db_contact = await ContactRepository.get_by_key_or_prefix(request.destination) + try: + db_contact = await ContactRepository.get_by_key_or_prefix(request.destination) + except AmbiguousPublicKeyPrefixError as err: + sample = ", ".join(key[:12] for key in err.matches[:2]) + raise HTTPException( + status_code=409, + detail=( + f"Ambiguous destination key prefix '{err.prefix}'. " + f"Use a full 64-character public key. Matching contacts: {sample}" + ), + ) from err if not db_contact: raise HTTPException( status_code=404, detail=f"Contact not found in database: {request.destination}" diff --git a/app/routers/packets.py b/app/routers/packets.py index 395c492..554f5f3 100644 --- a/app/routers/packets.py +++ b/app/routers/packets.py @@ -211,7 +211,7 @@ async def decrypt_historical_packets( # Try to find contact name for display from app.repository import ContactRepository - contact = await ContactRepository.get_by_key_or_prefix(contact_public_key_hex) + contact = await ContactRepository.get_by_key(contact_public_key_hex) display_name = contact.name if contact else None background_tasks.add_task( diff --git a/frontend/src/test/urlHash.test.ts b/frontend/src/test/urlHash.test.ts index 4d71bd5..c29fc1a 100644 --- a/frontend/src/test/urlHash.test.ts +++ b/frontend/src/test/urlHash.test.ts @@ -289,10 +289,7 @@ describe('resolveChannelFromHashToken', () => { ]; it('prefers stable key lookup (case-insensitive)', () => { - const result = resolveChannelFromHashToken( - 'abcdef0123456789abcdef0123456789', - channels - ); + const result = resolveChannelFromHashToken('abcdef0123456789abcdef0123456789', channels); expect(result?.key).toBe('ABCDEF0123456789ABCDEF0123456789'); }); diff --git a/frontend/src/utils/urlHash.ts b/frontend/src/utils/urlHash.ts index 4657192..dede798 100644 --- a/frontend/src/utils/urlHash.ts +++ b/frontend/src/utils/urlHash.ts @@ -74,7 +74,9 @@ export function resolveChannelFromHashToken(token: string, channels: Channel[]): if (byKey) return byKey; // Backward compatibility for legacy name-based hashes. - return channels.find((c) => c.name === normalizedToken || c.name === `#${normalizedToken}`) || null; + return ( + channels.find((c) => c.name === normalizedToken || c.name === `#${normalizedToken}`) || null + ); } export function resolveContactFromHashToken(token: string, contacts: Contact[]): Contact | null { @@ -86,7 +88,9 @@ export function resolveContactFromHashToken(token: string, contacts: Contact[]): if (byKey) return byKey; // Backward compatibility for legacy name/prefix-based hashes. - return contacts.find((c) => getContactDisplayName(c.name, c.public_key) === normalizedToken) || null; + return ( + contacts.find((c) => getContactDisplayName(c.name, c.public_key) === normalizedToken) || null + ); } /** diff --git a/tests/test_contacts_router.py b/tests/test_contacts_router.py index dd0c35d..1b9b6d8 100644 --- a/tests/test_contacts_router.py +++ b/tests/test_contacts_router.py @@ -90,7 +90,7 @@ class TestCreateContact: with ( patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", + "app.routers.contacts.ContactRepository.get_by_key", new_callable=AsyncMock, return_value=None, ), @@ -123,7 +123,7 @@ class TestCreateContact: from fastapi.testclient import TestClient with patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", + "app.routers.contacts.ContactRepository.get_by_key", new_callable=AsyncMock, return_value=None, ): @@ -160,7 +160,7 @@ class TestCreateContact: with ( patch( - "app.routers.contacts.ContactRepository.get_by_key_or_prefix", + "app.routers.contacts.ContactRepository.get_by_key", new_callable=AsyncMock, return_value=existing, ), @@ -220,6 +220,30 @@ class TestGetContact: assert response.status_code == 404 + def test_get_ambiguous_prefix_returns_409(self): + from fastapi.testclient import TestClient + + from app.repository import AmbiguousPublicKeyPrefixError + + with patch( + "app.routers.contacts.ContactRepository.get_by_key_or_prefix", + new_callable=AsyncMock, + side_effect=AmbiguousPublicKeyPrefixError( + "abcd12", + [ + "abcd120000000000000000000000000000000000000000000000000000000000", + "abcd12ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + ], + ), + ): + from app.main import app + + client = TestClient(app) + response = client.get("/api/contacts/abcd12") + + assert response.status_code == 409 + assert "ambiguous" in response.json()["detail"].lower() + class TestMarkRead: """Test POST /api/contacts/{public_key}/mark-read.""" diff --git a/tests/test_key_normalization.py b/tests/test_key_normalization.py index 24ebdce..8017192 100644 --- a/tests/test_key_normalization.py +++ b/tests/test_key_normalization.py @@ -3,7 +3,7 @@ import pytest from app.database import Database -from app.repository import ContactRepository, MessageRepository +from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository @pytest.fixture @@ -117,3 +117,43 @@ async def test_duplicate_with_same_text_and_null_timestamp_rejected(test_db): received_at=received_at, ) assert msg_id2 is None # duplicate rejected + + +@pytest.mark.asyncio +async def test_get_by_key_prefix_returns_none_when_ambiguous(test_db): + """Ambiguous prefixes should not resolve to an arbitrary contact.""" + key1 = "abc1230000000000000000000000000000000000000000000000000000000000" + key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + + await ContactRepository.upsert({"public_key": key1, "name": "A"}) + await ContactRepository.upsert({"public_key": key2, "name": "B"}) + + contact = await ContactRepository.get_by_key_prefix("abc123") + assert contact is None + + +@pytest.mark.asyncio +async def test_get_by_key_or_prefix_raises_on_ambiguous_prefix(test_db): + """Prefix lookup should raise when multiple contacts match.""" + key1 = "abc1230000000000000000000000000000000000000000000000000000000000" + key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + + await ContactRepository.upsert({"public_key": key1, "name": "A"}) + await ContactRepository.upsert({"public_key": key2, "name": "B"}) + + with pytest.raises(AmbiguousPublicKeyPrefixError): + await ContactRepository.get_by_key_or_prefix("abc123") + + +@pytest.mark.asyncio +async def test_get_by_key_or_prefix_prefers_exact_full_key(test_db): + """Exact key lookup works even when the shorter prefix is ambiguous.""" + key1 = "abc1230000000000000000000000000000000000000000000000000000000000" + key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + + await ContactRepository.upsert({"public_key": key1, "name": "A"}) + await ContactRepository.upsert({"public_key": key2, "name": "B"}) + + contact = await ContactRepository.get_by_key_or_prefix(key2.upper()) + assert contact is not None + assert contact.public_key == key2 diff --git a/tests/test_send_messages.py b/tests/test_send_messages.py index 4e96b49..396105a 100644 --- a/tests/test_send_messages.py +++ b/tests/test_send_messages.py @@ -4,6 +4,7 @@ import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest +from fastapi import HTTPException from meshcore import EventType from app.models import ( @@ -13,6 +14,7 @@ from app.models import ( SendChannelMessageRequest, SendDirectMessageRequest, ) +from app.repository import AmbiguousPublicKeyPrefixError from app.routers.messages import send_channel_message, send_direct_message @@ -123,6 +125,34 @@ class TestOutgoingDMBotTrigger: call_kwargs = mock_bot.call_args[1] assert call_kwargs["sender_name"] is None + @pytest.mark.asyncio + async def test_send_dm_ambiguous_prefix_returns_409(self): + """Ambiguous destination prefix should fail instead of selecting a random contact.""" + mc = _make_mc() + + with ( + patch("app.routers.messages.require_connected", return_value=mc), + patch( + "app.repository.ContactRepository.get_by_key_or_prefix", + new=AsyncMock( + side_effect=AmbiguousPublicKeyPrefixError( + "abc123", + [ + "abc1230000000000000000000000000000000000000000000000000000000000", + "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + ], + ) + ), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await send_direct_message( + SendDirectMessageRequest(destination="abc123", text="Hello") + ) + + assert exc_info.value.status_code == 409 + assert "ambiguous" in exc_info.value.detail.lower() + class TestOutgoingChannelBotTrigger: """Test that sending a channel message triggers bots with is_outgoing=True."""