Don't use prefix matching if we can help it

This commit is contained in:
Jack Kingsman
2026-02-10 22:05:59 -08:00
parent bfdccc4a94
commit 1aa26c05d0
13 changed files with 227 additions and 55 deletions

View File

@@ -7,7 +7,7 @@ from meshcore import EventType
from app.models import CONTACT_TYPE_REPEATER, Contact
from app.packet_processor import process_raw_packet
from app.repository import ContactRepository, MessageRepository
from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository
from app.websocket import broadcast_event
if TYPE_CHECKING:
@@ -74,7 +74,14 @@ async def on_contact_message(event: "Event") -> None:
# Look up contact from database - use prefix lookup only if needed
# (get_by_key_or_prefix does exact match first, then prefix fallback)
contact = await ContactRepository.get_by_key_or_prefix(sender_pubkey)
try:
contact = await ContactRepository.get_by_key_or_prefix(sender_pubkey)
except AmbiguousPublicKeyPrefixError:
logger.warning(
"DM sender prefix '%s' is ambiguous; storing under prefix until full key is known",
sender_pubkey,
)
contact = None
if contact:
sender_pubkey = contact.public_key.lower()

View File

@@ -132,7 +132,9 @@ class SendMessageRequest(BaseModel):
class SendDirectMessageRequest(SendMessageRequest):
destination: str = Field(description="Public key or prefix of recipient")
destination: str = Field(
description="Recipient public key (64-char hex preferred; prefix must resolve uniquely)"
)
class SendChannelMessageRequest(SendMessageRequest):

View File

@@ -238,7 +238,7 @@ async def create_dm_message_from_decrypted(
"""
# Check if sender is a repeater - repeaters only send CLI responses, not chat messages.
# CLI responses are handled by the command endpoint, not stored in chat history.
contact = await ContactRepository.get_by_key_or_prefix(their_public_key)
contact = await ContactRepository.get_by_key(their_public_key)
if contact and contact.type == CONTACT_TYPE_REPEATER:
logger.debug(
"Skipping message from repeater %s (CLI responses not stored): %s",

View File

@@ -19,6 +19,7 @@ from meshcore import EventType
from app.models import Contact
from app.radio import RadioOperationBusyError, radio_manager
from app.repository import (
AmbiguousPublicKeyPrefixError,
AppSettingsRepository,
ChannelRepository,
ContactRepository,
@@ -585,7 +586,14 @@ async def sync_recent_contacts_to_radio(force: bool = False) -> dict:
for favorite in app_settings.favorites:
if favorite.type != "contact":
continue
contact = await ContactRepository.get_by_key_or_prefix(favorite.id)
try:
contact = await ContactRepository.get_by_key_or_prefix(favorite.id)
except AmbiguousPublicKeyPrefixError:
logger.warning(
"Skipping favorite contact '%s': ambiguous key prefix; use full key",
favorite.id,
)
continue
if not contact:
continue
key = contact.public_key.lower()

View File

@@ -20,6 +20,15 @@ from app.models import (
logger = logging.getLogger(__name__)
class AmbiguousPublicKeyPrefixError(ValueError):
"""Raised when a public key prefix matches multiple contacts."""
def __init__(self, prefix: str, matches: list[str]):
self.prefix = prefix.lower()
self.matches = matches
super().__init__(f"Ambiguous public key prefix '{self.prefix}'")
class ContactRepository:
@staticmethod
async def upsert(contact: dict[str, Any]) -> None:
@@ -89,12 +98,30 @@ class ContactRepository:
@staticmethod
async def get_by_key_prefix(prefix: str) -> Contact | None:
"""Get a contact by key prefix only if it resolves uniquely.
Returns None when no contacts match OR when multiple contacts match
the prefix (to avoid silently selecting the wrong contact).
"""
normalized_prefix = prefix.lower()
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE public_key LIKE ? LIMIT 1",
(f"{prefix.lower()}%",),
"SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2",
(f"{normalized_prefix}%",),
)
row = await cursor.fetchone()
return ContactRepository._row_to_contact(row) if row else None
rows = list(await cursor.fetchall())
if len(rows) != 1:
return None
return ContactRepository._row_to_contact(rows[0])
@staticmethod
async def _get_prefix_matches(prefix: str, limit: int = 2) -> list[Contact]:
"""Get contacts matching a key prefix, up to limit."""
cursor = await db.conn.execute(
"SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT ?",
(f"{prefix.lower()}%", limit),
)
rows = list(await cursor.fetchall())
return [ContactRepository._row_to_contact(row) for row in rows]
@staticmethod
async def get_by_key_or_prefix(key_or_prefix: str) -> Contact | None:
@@ -103,9 +130,18 @@ class ContactRepository:
Useful when the input might be a full 64-char public key or a shorter prefix.
"""
contact = await ContactRepository.get_by_key(key_or_prefix)
if not contact:
contact = await ContactRepository.get_by_key_prefix(key_or_prefix)
return contact
if contact:
return contact
matches = await ContactRepository._get_prefix_matches(key_or_prefix, limit=2)
if len(matches) == 1:
return matches[0]
if len(matches) > 1:
raise AmbiguousPublicKeyPrefixError(
key_or_prefix,
[m.public_key for m in matches],
)
return None
@staticmethod
async def get_all(limit: int = 100, offset: int = 0) -> list[Contact]:
@@ -416,9 +452,20 @@ class MessageRepository:
query += " AND type = ?"
params.append(msg_type)
if conversation_key:
# Support both exact match and prefix match for DMs
query += " AND conversation_key LIKE ?"
params.append(f"{conversation_key}%")
normalized_key = conversation_key
# Prefer exact matching for full keys.
if len(conversation_key) == 64:
normalized_key = conversation_key.lower()
query += " AND conversation_key = ?"
params.append(normalized_key)
elif len(conversation_key) == 32:
normalized_key = conversation_key.upper()
query += " AND conversation_key = ?"
params.append(normalized_key)
else:
# Prefix match is only for legacy/partial key callers.
query += " AND conversation_key LIKE ?"
params.append(f"{conversation_key}%")
if before is not None and before_id is not None:
query += " AND (received_at < ? OR (received_at = ? AND id < ?))"

View File

@@ -20,7 +20,7 @@ from app.models import (
)
from app.packet_processor import start_historical_dm_decryption
from app.radio import radio_manager
from app.repository import ContactRepository, MessageRepository
from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository
logger = logging.getLogger(__name__)
@@ -37,6 +37,26 @@ router = APIRouter(prefix="/contacts", tags=["contacts"])
REPEATER_OP_DELAY_SECONDS = 2.0
def _ambiguous_contact_detail(err: AmbiguousPublicKeyPrefixError) -> str:
sample = ", ".join(key[:12] for key in err.matches[:2])
return (
f"Ambiguous contact key prefix '{err.prefix}'. "
f"Use a full 64-character public key. Matching contacts: {sample}"
)
async def _resolve_contact_or_404(
public_key: str, not_found_detail: str = "Contact not found"
) -> Contact:
try:
contact = await ContactRepository.get_by_key_or_prefix(public_key)
except AmbiguousPublicKeyPrefixError as err:
raise HTTPException(status_code=409, detail=_ambiguous_contact_detail(err)) from err
if not contact:
raise HTTPException(status_code=404, detail=not_found_detail)
return contact
async def prepare_repeater_connection(mc, contact: Contact, password: str) -> None:
"""Prepare connection to a repeater by adding to radio and logging in.
@@ -89,7 +109,7 @@ async def create_contact(
raise HTTPException(status_code=400, detail="Invalid public key: must be valid hex") from e
# Check if contact already exists
existing = await ContactRepository.get_by_key_or_prefix(request.public_key)
existing = await ContactRepository.get_by_key(request.public_key)
if existing:
# Update name if provided
if request.name:
@@ -153,10 +173,7 @@ async def create_contact(
@router.get("/{public_key}", response_model=Contact)
async def get_contact(public_key: str) -> Contact:
"""Get a specific contact by public key or prefix."""
contact = await ContactRepository.get_by_key_or_prefix(public_key)
if not contact:
raise HTTPException(status_code=404, detail="Contact not found")
return contact
return await _resolve_contact_or_404(public_key)
@router.post("/sync")
@@ -189,9 +206,7 @@ async def remove_contact_from_radio(public_key: str) -> dict:
"""Remove a contact from the radio (keeps it in database)."""
mc = require_connected()
contact = await ContactRepository.get_by_key_or_prefix(public_key)
if not contact:
raise HTTPException(status_code=404, detail="Contact not found")
contact = await _resolve_contact_or_404(public_key)
# Get the contact from radio
radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12])
@@ -216,9 +231,7 @@ async def add_contact_to_radio(public_key: str) -> dict:
"""Add a contact from the database to the radio."""
mc = require_connected()
contact = await ContactRepository.get_by_key_or_prefix(public_key)
if not contact:
raise HTTPException(status_code=404, detail="Contact not found in database")
contact = await _resolve_contact_or_404(public_key, "Contact not found in database")
# Check if already on radio
radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12])
@@ -239,9 +252,7 @@ async def add_contact_to_radio(public_key: str) -> dict:
@router.post("/{public_key}/mark-read")
async def mark_contact_read(public_key: str) -> dict:
"""Mark a contact conversation as read (update last_read_at timestamp)."""
contact = await ContactRepository.get_by_key_or_prefix(public_key)
if not contact:
raise HTTPException(status_code=404, detail="Contact not found")
contact = await _resolve_contact_or_404(public_key)
updated = await ContactRepository.update_last_read_at(contact.public_key)
if not updated:
@@ -253,9 +264,7 @@ async def mark_contact_read(public_key: str) -> dict:
@router.delete("/{public_key}")
async def delete_contact(public_key: str) -> dict:
"""Delete a contact from the database (and radio if present)."""
contact = await ContactRepository.get_by_key_or_prefix(public_key)
if not contact:
raise HTTPException(status_code=404, detail="Contact not found")
contact = await _resolve_contact_or_404(public_key)
# Remove from radio if connected and contact is on radio
if radio_manager.is_connected and radio_manager.meshcore:
@@ -282,9 +291,7 @@ async def request_telemetry(public_key: str, request: TelemetryRequest) -> Telem
mc = require_connected()
# Get contact from database
contact = await ContactRepository.get_by_key_or_prefix(public_key)
if not contact:
raise HTTPException(status_code=404, detail="Contact not found")
contact = await _resolve_contact_or_404(public_key)
# Verify it's a repeater
if contact.type != CONTACT_TYPE_REPEATER:
@@ -459,9 +466,7 @@ async def send_repeater_command(public_key: str, request: CommandRequest) -> Com
mc = require_connected()
# Get contact from database
contact = await ContactRepository.get_by_key_or_prefix(public_key)
if not contact:
raise HTTPException(status_code=404, detail="Contact not found")
contact = await _resolve_contact_or_404(public_key)
# Verify it's a repeater
if contact.type != CONTACT_TYPE_REPEATER:
@@ -540,9 +545,7 @@ async def request_trace(public_key: str) -> TraceResponse:
"""
mc = require_connected()
contact = await ContactRepository.get_by_key_or_prefix(public_key)
if not contact:
raise HTTPException(status_code=404, detail="Contact not found")
contact = await _resolve_contact_or_404(public_key)
tag = random.randint(1, 0xFFFFFFFF)
# First 2 hex chars of pubkey = 1-byte hash used by the trace protocol

View File

@@ -9,7 +9,7 @@ from app.dependencies import require_connected
from app.event_handlers import track_pending_ack
from app.models import Message, SendChannelMessageRequest, SendDirectMessageRequest
from app.radio import radio_manager
from app.repository import MessageRepository
from app.repository import AmbiguousPublicKeyPrefixError, MessageRepository
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/messages", tags=["messages"])
@@ -47,7 +47,17 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
# First check our database for the contact
from app.repository import ContactRepository
db_contact = await ContactRepository.get_by_key_or_prefix(request.destination)
try:
db_contact = await ContactRepository.get_by_key_or_prefix(request.destination)
except AmbiguousPublicKeyPrefixError as err:
sample = ", ".join(key[:12] for key in err.matches[:2])
raise HTTPException(
status_code=409,
detail=(
f"Ambiguous destination key prefix '{err.prefix}'. "
f"Use a full 64-character public key. Matching contacts: {sample}"
),
) from err
if not db_contact:
raise HTTPException(
status_code=404, detail=f"Contact not found in database: {request.destination}"

View File

@@ -211,7 +211,7 @@ async def decrypt_historical_packets(
# Try to find contact name for display
from app.repository import ContactRepository
contact = await ContactRepository.get_by_key_or_prefix(contact_public_key_hex)
contact = await ContactRepository.get_by_key(contact_public_key_hex)
display_name = contact.name if contact else None
background_tasks.add_task(

View File

@@ -289,10 +289,7 @@ describe('resolveChannelFromHashToken', () => {
];
it('prefers stable key lookup (case-insensitive)', () => {
const result = resolveChannelFromHashToken(
'abcdef0123456789abcdef0123456789',
channels
);
const result = resolveChannelFromHashToken('abcdef0123456789abcdef0123456789', channels);
expect(result?.key).toBe('ABCDEF0123456789ABCDEF0123456789');
});

View File

@@ -74,7 +74,9 @@ export function resolveChannelFromHashToken(token: string, channels: Channel[]):
if (byKey) return byKey;
// Backward compatibility for legacy name-based hashes.
return channels.find((c) => c.name === normalizedToken || c.name === `#${normalizedToken}`) || null;
return (
channels.find((c) => c.name === normalizedToken || c.name === `#${normalizedToken}`) || null
);
}
export function resolveContactFromHashToken(token: string, contacts: Contact[]): Contact | null {
@@ -86,7 +88,9 @@ export function resolveContactFromHashToken(token: string, contacts: Contact[]):
if (byKey) return byKey;
// Backward compatibility for legacy name/prefix-based hashes.
return contacts.find((c) => getContactDisplayName(c.name, c.public_key) === normalizedToken) || null;
return (
contacts.find((c) => getContactDisplayName(c.name, c.public_key) === normalizedToken) || null
);
}
/**

View File

@@ -90,7 +90,7 @@ class TestCreateContact:
with (
patch(
"app.routers.contacts.ContactRepository.get_by_key_or_prefix",
"app.routers.contacts.ContactRepository.get_by_key",
new_callable=AsyncMock,
return_value=None,
),
@@ -123,7 +123,7 @@ class TestCreateContact:
from fastapi.testclient import TestClient
with patch(
"app.routers.contacts.ContactRepository.get_by_key_or_prefix",
"app.routers.contacts.ContactRepository.get_by_key",
new_callable=AsyncMock,
return_value=None,
):
@@ -160,7 +160,7 @@ class TestCreateContact:
with (
patch(
"app.routers.contacts.ContactRepository.get_by_key_or_prefix",
"app.routers.contacts.ContactRepository.get_by_key",
new_callable=AsyncMock,
return_value=existing,
),
@@ -220,6 +220,30 @@ class TestGetContact:
assert response.status_code == 404
def test_get_ambiguous_prefix_returns_409(self):
from fastapi.testclient import TestClient
from app.repository import AmbiguousPublicKeyPrefixError
with patch(
"app.routers.contacts.ContactRepository.get_by_key_or_prefix",
new_callable=AsyncMock,
side_effect=AmbiguousPublicKeyPrefixError(
"abcd12",
[
"abcd120000000000000000000000000000000000000000000000000000000000",
"abcd12ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
],
),
):
from app.main import app
client = TestClient(app)
response = client.get("/api/contacts/abcd12")
assert response.status_code == 409
assert "ambiguous" in response.json()["detail"].lower()
class TestMarkRead:
"""Test POST /api/contacts/{public_key}/mark-read."""

View File

@@ -3,7 +3,7 @@
import pytest
from app.database import Database
from app.repository import ContactRepository, MessageRepository
from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository
@pytest.fixture
@@ -117,3 +117,43 @@ async def test_duplicate_with_same_text_and_null_timestamp_rejected(test_db):
received_at=received_at,
)
assert msg_id2 is None # duplicate rejected
@pytest.mark.asyncio
async def test_get_by_key_prefix_returns_none_when_ambiguous(test_db):
"""Ambiguous prefixes should not resolve to an arbitrary contact."""
key1 = "abc1230000000000000000000000000000000000000000000000000000000000"
key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
await ContactRepository.upsert({"public_key": key1, "name": "A"})
await ContactRepository.upsert({"public_key": key2, "name": "B"})
contact = await ContactRepository.get_by_key_prefix("abc123")
assert contact is None
@pytest.mark.asyncio
async def test_get_by_key_or_prefix_raises_on_ambiguous_prefix(test_db):
"""Prefix lookup should raise when multiple contacts match."""
key1 = "abc1230000000000000000000000000000000000000000000000000000000000"
key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
await ContactRepository.upsert({"public_key": key1, "name": "A"})
await ContactRepository.upsert({"public_key": key2, "name": "B"})
with pytest.raises(AmbiguousPublicKeyPrefixError):
await ContactRepository.get_by_key_or_prefix("abc123")
@pytest.mark.asyncio
async def test_get_by_key_or_prefix_prefers_exact_full_key(test_db):
"""Exact key lookup works even when the shorter prefix is ambiguous."""
key1 = "abc1230000000000000000000000000000000000000000000000000000000000"
key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
await ContactRepository.upsert({"public_key": key1, "name": "A"})
await ContactRepository.upsert({"public_key": key2, "name": "B"})
contact = await ContactRepository.get_by_key_or_prefix(key2.upper())
assert contact is not None
assert contact.public_key == key2

View File

@@ -4,6 +4,7 @@ import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from meshcore import EventType
from app.models import (
@@ -13,6 +14,7 @@ from app.models import (
SendChannelMessageRequest,
SendDirectMessageRequest,
)
from app.repository import AmbiguousPublicKeyPrefixError
from app.routers.messages import send_channel_message, send_direct_message
@@ -123,6 +125,34 @@ class TestOutgoingDMBotTrigger:
call_kwargs = mock_bot.call_args[1]
assert call_kwargs["sender_name"] is None
@pytest.mark.asyncio
async def test_send_dm_ambiguous_prefix_returns_409(self):
"""Ambiguous destination prefix should fail instead of selecting a random contact."""
mc = _make_mc()
with (
patch("app.routers.messages.require_connected", return_value=mc),
patch(
"app.repository.ContactRepository.get_by_key_or_prefix",
new=AsyncMock(
side_effect=AmbiguousPublicKeyPrefixError(
"abc123",
[
"abc1230000000000000000000000000000000000000000000000000000000000",
"abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
],
)
),
),
):
with pytest.raises(HTTPException) as exc_info:
await send_direct_message(
SendDirectMessageRequest(destination="abc123", text="Hello")
)
assert exc_info.value.status_code == 409
assert "ambiguous" in exc_info.value.detail.lower()
class TestOutgoingChannelBotTrigger:
"""Test that sending a channel message triggers bots with is_outgoing=True."""