mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-03-28 17:43:05 +01:00
Don't use prefix matching if we can help it
This commit is contained in:
@@ -7,7 +7,7 @@ from meshcore import EventType
|
||||
|
||||
from app.models import CONTACT_TYPE_REPEATER, Contact
|
||||
from app.packet_processor import process_raw_packet
|
||||
from app.repository import ContactRepository, MessageRepository
|
||||
from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository
|
||||
from app.websocket import broadcast_event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -74,7 +74,14 @@ async def on_contact_message(event: "Event") -> None:
|
||||
|
||||
# Look up contact from database - use prefix lookup only if needed
|
||||
# (get_by_key_or_prefix does exact match first, then prefix fallback)
|
||||
contact = await ContactRepository.get_by_key_or_prefix(sender_pubkey)
|
||||
try:
|
||||
contact = await ContactRepository.get_by_key_or_prefix(sender_pubkey)
|
||||
except AmbiguousPublicKeyPrefixError:
|
||||
logger.warning(
|
||||
"DM sender prefix '%s' is ambiguous; storing under prefix until full key is known",
|
||||
sender_pubkey,
|
||||
)
|
||||
contact = None
|
||||
if contact:
|
||||
sender_pubkey = contact.public_key.lower()
|
||||
|
||||
|
||||
@@ -132,7 +132,9 @@ class SendMessageRequest(BaseModel):
|
||||
|
||||
|
||||
class SendDirectMessageRequest(SendMessageRequest):
|
||||
destination: str = Field(description="Public key or prefix of recipient")
|
||||
destination: str = Field(
|
||||
description="Recipient public key (64-char hex preferred; prefix must resolve uniquely)"
|
||||
)
|
||||
|
||||
|
||||
class SendChannelMessageRequest(SendMessageRequest):
|
||||
|
||||
@@ -238,7 +238,7 @@ async def create_dm_message_from_decrypted(
|
||||
"""
|
||||
# Check if sender is a repeater - repeaters only send CLI responses, not chat messages.
|
||||
# CLI responses are handled by the command endpoint, not stored in chat history.
|
||||
contact = await ContactRepository.get_by_key_or_prefix(their_public_key)
|
||||
contact = await ContactRepository.get_by_key(their_public_key)
|
||||
if contact and contact.type == CONTACT_TYPE_REPEATER:
|
||||
logger.debug(
|
||||
"Skipping message from repeater %s (CLI responses not stored): %s",
|
||||
|
||||
@@ -19,6 +19,7 @@ from meshcore import EventType
|
||||
from app.models import Contact
|
||||
from app.radio import RadioOperationBusyError, radio_manager
|
||||
from app.repository import (
|
||||
AmbiguousPublicKeyPrefixError,
|
||||
AppSettingsRepository,
|
||||
ChannelRepository,
|
||||
ContactRepository,
|
||||
@@ -585,7 +586,14 @@ async def sync_recent_contacts_to_radio(force: bool = False) -> dict:
|
||||
for favorite in app_settings.favorites:
|
||||
if favorite.type != "contact":
|
||||
continue
|
||||
contact = await ContactRepository.get_by_key_or_prefix(favorite.id)
|
||||
try:
|
||||
contact = await ContactRepository.get_by_key_or_prefix(favorite.id)
|
||||
except AmbiguousPublicKeyPrefixError:
|
||||
logger.warning(
|
||||
"Skipping favorite contact '%s': ambiguous key prefix; use full key",
|
||||
favorite.id,
|
||||
)
|
||||
continue
|
||||
if not contact:
|
||||
continue
|
||||
key = contact.public_key.lower()
|
||||
|
||||
@@ -20,6 +20,15 @@ from app.models import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AmbiguousPublicKeyPrefixError(ValueError):
|
||||
"""Raised when a public key prefix matches multiple contacts."""
|
||||
|
||||
def __init__(self, prefix: str, matches: list[str]):
|
||||
self.prefix = prefix.lower()
|
||||
self.matches = matches
|
||||
super().__init__(f"Ambiguous public key prefix '{self.prefix}'")
|
||||
|
||||
|
||||
class ContactRepository:
|
||||
@staticmethod
|
||||
async def upsert(contact: dict[str, Any]) -> None:
|
||||
@@ -89,12 +98,30 @@ class ContactRepository:
|
||||
|
||||
@staticmethod
|
||||
async def get_by_key_prefix(prefix: str) -> Contact | None:
|
||||
"""Get a contact by key prefix only if it resolves uniquely.
|
||||
|
||||
Returns None when no contacts match OR when multiple contacts match
|
||||
the prefix (to avoid silently selecting the wrong contact).
|
||||
"""
|
||||
normalized_prefix = prefix.lower()
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT * FROM contacts WHERE public_key LIKE ? LIMIT 1",
|
||||
(f"{prefix.lower()}%",),
|
||||
"SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT 2",
|
||||
(f"{normalized_prefix}%",),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return ContactRepository._row_to_contact(row) if row else None
|
||||
rows = list(await cursor.fetchall())
|
||||
if len(rows) != 1:
|
||||
return None
|
||||
return ContactRepository._row_to_contact(rows[0])
|
||||
|
||||
@staticmethod
|
||||
async def _get_prefix_matches(prefix: str, limit: int = 2) -> list[Contact]:
|
||||
"""Get contacts matching a key prefix, up to limit."""
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT * FROM contacts WHERE public_key LIKE ? ORDER BY public_key LIMIT ?",
|
||||
(f"{prefix.lower()}%", limit),
|
||||
)
|
||||
rows = list(await cursor.fetchall())
|
||||
return [ContactRepository._row_to_contact(row) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
async def get_by_key_or_prefix(key_or_prefix: str) -> Contact | None:
|
||||
@@ -103,9 +130,18 @@ class ContactRepository:
|
||||
Useful when the input might be a full 64-char public key or a shorter prefix.
|
||||
"""
|
||||
contact = await ContactRepository.get_by_key(key_or_prefix)
|
||||
if not contact:
|
||||
contact = await ContactRepository.get_by_key_prefix(key_or_prefix)
|
||||
return contact
|
||||
if contact:
|
||||
return contact
|
||||
|
||||
matches = await ContactRepository._get_prefix_matches(key_or_prefix, limit=2)
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
if len(matches) > 1:
|
||||
raise AmbiguousPublicKeyPrefixError(
|
||||
key_or_prefix,
|
||||
[m.public_key for m in matches],
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_all(limit: int = 100, offset: int = 0) -> list[Contact]:
|
||||
@@ -416,9 +452,20 @@ class MessageRepository:
|
||||
query += " AND type = ?"
|
||||
params.append(msg_type)
|
||||
if conversation_key:
|
||||
# Support both exact match and prefix match for DMs
|
||||
query += " AND conversation_key LIKE ?"
|
||||
params.append(f"{conversation_key}%")
|
||||
normalized_key = conversation_key
|
||||
# Prefer exact matching for full keys.
|
||||
if len(conversation_key) == 64:
|
||||
normalized_key = conversation_key.lower()
|
||||
query += " AND conversation_key = ?"
|
||||
params.append(normalized_key)
|
||||
elif len(conversation_key) == 32:
|
||||
normalized_key = conversation_key.upper()
|
||||
query += " AND conversation_key = ?"
|
||||
params.append(normalized_key)
|
||||
else:
|
||||
# Prefix match is only for legacy/partial key callers.
|
||||
query += " AND conversation_key LIKE ?"
|
||||
params.append(f"{conversation_key}%")
|
||||
|
||||
if before is not None and before_id is not None:
|
||||
query += " AND (received_at < ? OR (received_at = ? AND id < ?))"
|
||||
|
||||
@@ -20,7 +20,7 @@ from app.models import (
|
||||
)
|
||||
from app.packet_processor import start_historical_dm_decryption
|
||||
from app.radio import radio_manager
|
||||
from app.repository import ContactRepository, MessageRepository
|
||||
from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,6 +37,26 @@ router = APIRouter(prefix="/contacts", tags=["contacts"])
|
||||
REPEATER_OP_DELAY_SECONDS = 2.0
|
||||
|
||||
|
||||
def _ambiguous_contact_detail(err: AmbiguousPublicKeyPrefixError) -> str:
|
||||
sample = ", ".join(key[:12] for key in err.matches[:2])
|
||||
return (
|
||||
f"Ambiguous contact key prefix '{err.prefix}'. "
|
||||
f"Use a full 64-character public key. Matching contacts: {sample}"
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_contact_or_404(
|
||||
public_key: str, not_found_detail: str = "Contact not found"
|
||||
) -> Contact:
|
||||
try:
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
except AmbiguousPublicKeyPrefixError as err:
|
||||
raise HTTPException(status_code=409, detail=_ambiguous_contact_detail(err)) from err
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail=not_found_detail)
|
||||
return contact
|
||||
|
||||
|
||||
async def prepare_repeater_connection(mc, contact: Contact, password: str) -> None:
|
||||
"""Prepare connection to a repeater by adding to radio and logging in.
|
||||
|
||||
@@ -89,7 +109,7 @@ async def create_contact(
|
||||
raise HTTPException(status_code=400, detail="Invalid public key: must be valid hex") from e
|
||||
|
||||
# Check if contact already exists
|
||||
existing = await ContactRepository.get_by_key_or_prefix(request.public_key)
|
||||
existing = await ContactRepository.get_by_key(request.public_key)
|
||||
if existing:
|
||||
# Update name if provided
|
||||
if request.name:
|
||||
@@ -153,10 +173,7 @@ async def create_contact(
|
||||
@router.get("/{public_key}", response_model=Contact)
|
||||
async def get_contact(public_key: str) -> Contact:
|
||||
"""Get a specific contact by public key or prefix."""
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
return contact
|
||||
return await _resolve_contact_or_404(public_key)
|
||||
|
||||
|
||||
@router.post("/sync")
|
||||
@@ -189,9 +206,7 @@ async def remove_contact_from_radio(public_key: str) -> dict:
|
||||
"""Remove a contact from the radio (keeps it in database)."""
|
||||
mc = require_connected()
|
||||
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
contact = await _resolve_contact_or_404(public_key)
|
||||
|
||||
# Get the contact from radio
|
||||
radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12])
|
||||
@@ -216,9 +231,7 @@ async def add_contact_to_radio(public_key: str) -> dict:
|
||||
"""Add a contact from the database to the radio."""
|
||||
mc = require_connected()
|
||||
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found in database")
|
||||
contact = await _resolve_contact_or_404(public_key, "Contact not found in database")
|
||||
|
||||
# Check if already on radio
|
||||
radio_contact = mc.get_contact_by_key_prefix(contact.public_key[:12])
|
||||
@@ -239,9 +252,7 @@ async def add_contact_to_radio(public_key: str) -> dict:
|
||||
@router.post("/{public_key}/mark-read")
|
||||
async def mark_contact_read(public_key: str) -> dict:
|
||||
"""Mark a contact conversation as read (update last_read_at timestamp)."""
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
contact = await _resolve_contact_or_404(public_key)
|
||||
|
||||
updated = await ContactRepository.update_last_read_at(contact.public_key)
|
||||
if not updated:
|
||||
@@ -253,9 +264,7 @@ async def mark_contact_read(public_key: str) -> dict:
|
||||
@router.delete("/{public_key}")
|
||||
async def delete_contact(public_key: str) -> dict:
|
||||
"""Delete a contact from the database (and radio if present)."""
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
contact = await _resolve_contact_or_404(public_key)
|
||||
|
||||
# Remove from radio if connected and contact is on radio
|
||||
if radio_manager.is_connected and radio_manager.meshcore:
|
||||
@@ -282,9 +291,7 @@ async def request_telemetry(public_key: str, request: TelemetryRequest) -> Telem
|
||||
mc = require_connected()
|
||||
|
||||
# Get contact from database
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
contact = await _resolve_contact_or_404(public_key)
|
||||
|
||||
# Verify it's a repeater
|
||||
if contact.type != CONTACT_TYPE_REPEATER:
|
||||
@@ -459,9 +466,7 @@ async def send_repeater_command(public_key: str, request: CommandRequest) -> Com
|
||||
mc = require_connected()
|
||||
|
||||
# Get contact from database
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
contact = await _resolve_contact_or_404(public_key)
|
||||
|
||||
# Verify it's a repeater
|
||||
if contact.type != CONTACT_TYPE_REPEATER:
|
||||
@@ -540,9 +545,7 @@ async def request_trace(public_key: str) -> TraceResponse:
|
||||
"""
|
||||
mc = require_connected()
|
||||
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
contact = await _resolve_contact_or_404(public_key)
|
||||
|
||||
tag = random.randint(1, 0xFFFFFFFF)
|
||||
# First 2 hex chars of pubkey = 1-byte hash used by the trace protocol
|
||||
|
||||
@@ -9,7 +9,7 @@ from app.dependencies import require_connected
|
||||
from app.event_handlers import track_pending_ack
|
||||
from app.models import Message, SendChannelMessageRequest, SendDirectMessageRequest
|
||||
from app.radio import radio_manager
|
||||
from app.repository import MessageRepository
|
||||
from app.repository import AmbiguousPublicKeyPrefixError, MessageRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/messages", tags=["messages"])
|
||||
@@ -47,7 +47,17 @@ async def send_direct_message(request: SendDirectMessageRequest) -> Message:
|
||||
# First check our database for the contact
|
||||
from app.repository import ContactRepository
|
||||
|
||||
db_contact = await ContactRepository.get_by_key_or_prefix(request.destination)
|
||||
try:
|
||||
db_contact = await ContactRepository.get_by_key_or_prefix(request.destination)
|
||||
except AmbiguousPublicKeyPrefixError as err:
|
||||
sample = ", ".join(key[:12] for key in err.matches[:2])
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(
|
||||
f"Ambiguous destination key prefix '{err.prefix}'. "
|
||||
f"Use a full 64-character public key. Matching contacts: {sample}"
|
||||
),
|
||||
) from err
|
||||
if not db_contact:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Contact not found in database: {request.destination}"
|
||||
|
||||
@@ -211,7 +211,7 @@ async def decrypt_historical_packets(
|
||||
# Try to find contact name for display
|
||||
from app.repository import ContactRepository
|
||||
|
||||
contact = await ContactRepository.get_by_key_or_prefix(contact_public_key_hex)
|
||||
contact = await ContactRepository.get_by_key(contact_public_key_hex)
|
||||
display_name = contact.name if contact else None
|
||||
|
||||
background_tasks.add_task(
|
||||
|
||||
@@ -289,10 +289,7 @@ describe('resolveChannelFromHashToken', () => {
|
||||
];
|
||||
|
||||
it('prefers stable key lookup (case-insensitive)', () => {
|
||||
const result = resolveChannelFromHashToken(
|
||||
'abcdef0123456789abcdef0123456789',
|
||||
channels
|
||||
);
|
||||
const result = resolveChannelFromHashToken('abcdef0123456789abcdef0123456789', channels);
|
||||
expect(result?.key).toBe('ABCDEF0123456789ABCDEF0123456789');
|
||||
});
|
||||
|
||||
|
||||
@@ -74,7 +74,9 @@ export function resolveChannelFromHashToken(token: string, channels: Channel[]):
|
||||
if (byKey) return byKey;
|
||||
|
||||
// Backward compatibility for legacy name-based hashes.
|
||||
return channels.find((c) => c.name === normalizedToken || c.name === `#${normalizedToken}`) || null;
|
||||
return (
|
||||
channels.find((c) => c.name === normalizedToken || c.name === `#${normalizedToken}`) || null
|
||||
);
|
||||
}
|
||||
|
||||
export function resolveContactFromHashToken(token: string, contacts: Contact[]): Contact | null {
|
||||
@@ -86,7 +88,9 @@ export function resolveContactFromHashToken(token: string, contacts: Contact[]):
|
||||
if (byKey) return byKey;
|
||||
|
||||
// Backward compatibility for legacy name/prefix-based hashes.
|
||||
return contacts.find((c) => getContactDisplayName(c.name, c.public_key) === normalizedToken) || null;
|
||||
return (
|
||||
contacts.find((c) => getContactDisplayName(c.name, c.public_key) === normalizedToken) || null
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -90,7 +90,7 @@ class TestCreateContact:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.routers.contacts.ContactRepository.get_by_key_or_prefix",
|
||||
"app.routers.contacts.ContactRepository.get_by_key",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
@@ -123,7 +123,7 @@ class TestCreateContact:
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
with patch(
|
||||
"app.routers.contacts.ContactRepository.get_by_key_or_prefix",
|
||||
"app.routers.contacts.ContactRepository.get_by_key",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
@@ -160,7 +160,7 @@ class TestCreateContact:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.routers.contacts.ContactRepository.get_by_key_or_prefix",
|
||||
"app.routers.contacts.ContactRepository.get_by_key",
|
||||
new_callable=AsyncMock,
|
||||
return_value=existing,
|
||||
),
|
||||
@@ -220,6 +220,30 @@ class TestGetContact:
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_ambiguous_prefix_returns_409(self):
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.repository import AmbiguousPublicKeyPrefixError
|
||||
|
||||
with patch(
|
||||
"app.routers.contacts.ContactRepository.get_by_key_or_prefix",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=AmbiguousPublicKeyPrefixError(
|
||||
"abcd12",
|
||||
[
|
||||
"abcd120000000000000000000000000000000000000000000000000000000000",
|
||||
"abcd12ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
|
||||
],
|
||||
),
|
||||
):
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/contacts/abcd12")
|
||||
|
||||
assert response.status_code == 409
|
||||
assert "ambiguous" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestMarkRead:
|
||||
"""Test POST /api/contacts/{public_key}/mark-read."""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
|
||||
from app.database import Database
|
||||
from app.repository import ContactRepository, MessageRepository
|
||||
from app.repository import AmbiguousPublicKeyPrefixError, ContactRepository, MessageRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -117,3 +117,43 @@ async def test_duplicate_with_same_text_and_null_timestamp_rejected(test_db):
|
||||
received_at=received_at,
|
||||
)
|
||||
assert msg_id2 is None # duplicate rejected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_key_prefix_returns_none_when_ambiguous(test_db):
|
||||
"""Ambiguous prefixes should not resolve to an arbitrary contact."""
|
||||
key1 = "abc1230000000000000000000000000000000000000000000000000000000000"
|
||||
key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
|
||||
|
||||
await ContactRepository.upsert({"public_key": key1, "name": "A"})
|
||||
await ContactRepository.upsert({"public_key": key2, "name": "B"})
|
||||
|
||||
contact = await ContactRepository.get_by_key_prefix("abc123")
|
||||
assert contact is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_key_or_prefix_raises_on_ambiguous_prefix(test_db):
|
||||
"""Prefix lookup should raise when multiple contacts match."""
|
||||
key1 = "abc1230000000000000000000000000000000000000000000000000000000000"
|
||||
key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
|
||||
|
||||
await ContactRepository.upsert({"public_key": key1, "name": "A"})
|
||||
await ContactRepository.upsert({"public_key": key2, "name": "B"})
|
||||
|
||||
with pytest.raises(AmbiguousPublicKeyPrefixError):
|
||||
await ContactRepository.get_by_key_or_prefix("abc123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_key_or_prefix_prefers_exact_full_key(test_db):
|
||||
"""Exact key lookup works even when the shorter prefix is ambiguous."""
|
||||
key1 = "abc1230000000000000000000000000000000000000000000000000000000000"
|
||||
key2 = "abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
|
||||
|
||||
await ContactRepository.upsert({"public_key": key1, "name": "A"})
|
||||
await ContactRepository.upsert({"public_key": key2, "name": "B"})
|
||||
|
||||
contact = await ContactRepository.get_by_key_or_prefix(key2.upper())
|
||||
assert contact is not None
|
||||
assert contact.public_key == key2
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from meshcore import EventType
|
||||
|
||||
from app.models import (
|
||||
@@ -13,6 +14,7 @@ from app.models import (
|
||||
SendChannelMessageRequest,
|
||||
SendDirectMessageRequest,
|
||||
)
|
||||
from app.repository import AmbiguousPublicKeyPrefixError
|
||||
from app.routers.messages import send_channel_message, send_direct_message
|
||||
|
||||
|
||||
@@ -123,6 +125,34 @@ class TestOutgoingDMBotTrigger:
|
||||
call_kwargs = mock_bot.call_args[1]
|
||||
assert call_kwargs["sender_name"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_ambiguous_prefix_returns_409(self):
|
||||
"""Ambiguous destination prefix should fail instead of selecting a random contact."""
|
||||
mc = _make_mc()
|
||||
|
||||
with (
|
||||
patch("app.routers.messages.require_connected", return_value=mc),
|
||||
patch(
|
||||
"app.repository.ContactRepository.get_by_key_or_prefix",
|
||||
new=AsyncMock(
|
||||
side_effect=AmbiguousPublicKeyPrefixError(
|
||||
"abc123",
|
||||
[
|
||||
"abc1230000000000000000000000000000000000000000000000000000000000",
|
||||
"abc123ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
|
||||
],
|
||||
)
|
||||
),
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await send_direct_message(
|
||||
SendDirectMessageRequest(destination="abc123", text="Hello")
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 409
|
||||
assert "ambiguous" in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
class TestOutgoingChannelBotTrigger:
|
||||
"""Test that sending a channel message triggers bots with is_outgoing=True."""
|
||||
|
||||
Reference in New Issue
Block a user