From 18e1408292885f3d3d5b945939458dd0539dfe58 Mon Sep 17 00:00:00 2001 From: Jack Kingsman Date: Mon, 9 Mar 2026 23:42:46 -0700 Subject: [PATCH] Be better about DB insertion shape --- app/event_handlers.py | 19 +++--- app/models.py | 94 +++++++++++++++++++++-------- app/packet_processor.py | 35 ++++++----- app/radio_sync.py | 4 +- app/repository/contacts.py | 49 +++++++++------ app/routers/contacts.py | 46 ++++---------- app/services/radio_runtime.py | 4 +- tests/test_radio_runtime_service.py | 23 +++++++ tests/test_repository.py | 32 ++++++++++ 9 files changed, 202 insertions(+), 104 deletions(-) diff --git a/app/event_handlers.py b/app/event_handlers.py index 173986b..d9361fe 100644 --- a/app/event_handlers.py +++ b/app/event_handlers.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from meshcore import EventType -from app.models import CONTACT_TYPE_REPEATER, Contact +from app.models import CONTACT_TYPE_REPEATER, Contact, ContactUpsert from app.packet_processor import process_raw_packet from app.repository import ( AmbiguousPublicKeyPrefixError, @@ -228,11 +228,9 @@ async def on_new_contact(event: "Event") -> None: logger.debug("New contact: %s", public_key[:12]) - contact_data = { - **Contact.from_radio_dict(public_key.lower(), payload, on_radio=True), - "last_seen": int(time.time()), - } - await ContactRepository.upsert(contact_data) + contact_upsert = ContactUpsert.from_radio_dict(public_key.lower(), payload, on_radio=True) + contact_upsert.last_seen = int(time.time()) + await ContactRepository.upsert(contact_upsert) adv_name = payload.get("adv_name") await record_contact_name_and_reconcile( @@ -245,7 +243,14 @@ async def on_new_contact(event: "Event") -> None: # Read back from DB so the broadcast includes all fields (last_contacted, # last_read_at, etc.) matching the REST Contact shape exactly. db_contact = await ContactRepository.get_by_key(public_key) - broadcast_event("contact", (db_contact.model_dump() if db_contact else contact_data)) + broadcast_event( + "contact", + ( + db_contact.model_dump() + if db_contact + else Contact(**contact_upsert.model_dump(exclude_none=True)).model_dump() + ), + ) async def on_ack(event: "Event") -> None: diff --git a/app/models.py b/app/models.py index 8c56f45..be428a0 100644 --- a/app/models.py +++ b/app/models.py @@ -5,6 +5,64 @@ from pydantic import BaseModel, Field from app.path_utils import normalize_contact_route +class ContactUpsert(BaseModel): + """Typed write contract for contacts persisted to SQLite.""" + + public_key: str = Field(description="Public key (64-char hex)") + name: str | None = None + type: int = 0 + flags: int = 0 + last_path: str | None = None + last_path_len: int = -1 + out_path_hash_mode: int | None = None + route_override_path: str | None = None + route_override_len: int | None = None + route_override_hash_mode: int | None = None + last_advert: int | None = None + lat: float | None = None + lon: float | None = None + last_seen: int | None = None + on_radio: bool | None = None + last_contacted: int | None = None + first_seen: int | None = None + + @classmethod + def from_contact(cls, contact: "Contact", **changes) -> "ContactUpsert": + return cls.model_validate( + { + **contact.model_dump(exclude={"last_read_at"}), + **changes, + } + ) + + @classmethod + def from_radio_dict( + cls, public_key: str, radio_data: dict, on_radio: bool = False + ) -> "ContactUpsert": + """Convert radio contact data to the contact-row write shape.""" + last_path, last_path_len, out_path_hash_mode = normalize_contact_route( + radio_data.get("out_path"), + radio_data.get("out_path_len", -1), + radio_data.get( + "out_path_hash_mode", + -1 if radio_data.get("out_path_len", -1) == -1 else 0, + ), + ) + return cls( + public_key=public_key, + name=radio_data.get("adv_name"), + type=radio_data.get("type", 0), + flags=radio_data.get("flags", 0), + last_path=last_path, + last_path_len=last_path_len, + out_path_hash_mode=out_path_hash_mode, + lat=radio_data.get("adv_lat"), + lon=radio_data.get("adv_lon"), + last_advert=radio_data.get("last_advert"), + on_radio=on_radio, + ) + + class Contact(BaseModel): public_key: str = Field(description="Public key (64-char hex)") name: str | None = None @@ -61,34 +119,18 @@ class Contact(BaseModel): "last_advert": self.last_advert if self.last_advert is not None else 0, } + def to_upsert(self, **changes) -> ContactUpsert: + """Convert the stored contact to the repository's write contract.""" + return ContactUpsert.from_contact(self, **changes) + @staticmethod def from_radio_dict(public_key: str, radio_data: dict, on_radio: bool = False) -> dict: - """Convert radio contact data to database format dict. - - This is the inverse of to_radio_dict(), used when syncing contacts - from radio to database. - """ - last_path, last_path_len, out_path_hash_mode = normalize_contact_route( - radio_data.get("out_path"), - radio_data.get("out_path_len", -1), - radio_data.get( - "out_path_hash_mode", - -1 if radio_data.get("out_path_len", -1) == -1 else 0, - ), - ) - return { - "public_key": public_key, - "name": radio_data.get("adv_name"), - "type": radio_data.get("type", 0), - "flags": radio_data.get("flags", 0), - "last_path": last_path, - "last_path_len": last_path_len, - "out_path_hash_mode": out_path_hash_mode, - "lat": radio_data.get("adv_lat"), - "lon": radio_data.get("adv_lon"), - "last_advert": radio_data.get("last_advert"), - "on_radio": on_radio, - } + """Backward-compatible dict wrapper over ContactUpsert.from_radio_dict().""" + return ContactUpsert.from_radio_dict( + public_key, + radio_data, + on_radio=on_radio, + ).model_dump() class CreateContactRequest(BaseModel): diff --git a/app/packet_processor.py b/app/packet_processor.py index d8e09b6..e119d20 100644 --- a/app/packet_processor.py +++ b/app/packet_processor.py @@ -30,6 +30,8 @@ from app.decoder import ( from app.keystore import get_private_key, get_public_key, has_private_key from app.models import ( CONTACT_TYPE_REPEATER, + Contact, + ContactUpsert, RawPacketBroadcast, RawPacketDecryptedInfo, ) @@ -489,21 +491,21 @@ async def _process_advertisement( hop_count=new_path_len, ) - contact_data = { - "public_key": advert.public_key.lower(), - "name": advert.name, - "type": contact_type, - "lat": advert.lat, - "lon": advert.lon, - "last_advert": advert.timestamp if advert.timestamp > 0 else timestamp, - "last_seen": timestamp, - "last_path": path_hex, - "last_path_len": path_len, - "out_path_hash_mode": out_path_hash_mode, - "first_seen": timestamp, # COALESCE in upsert preserves existing value - } + contact_upsert = ContactUpsert( + public_key=advert.public_key.lower(), + name=advert.name, + type=contact_type, + lat=advert.lat, + lon=advert.lon, + last_advert=advert.timestamp if advert.timestamp > 0 else timestamp, + last_seen=timestamp, + last_path=path_hex, + last_path_len=path_len, + out_path_hash_mode=out_path_hash_mode, + first_seen=timestamp, # COALESCE in upsert preserves existing value + ) - await ContactRepository.upsert(contact_data) + await ContactRepository.upsert(contact_upsert) await record_contact_name_and_reconcile( public_key=advert.public_key, contact_name=advert.name, @@ -517,7 +519,10 @@ async def _process_advertisement( if db_contact: broadcast_event("contact", db_contact.model_dump()) else: - broadcast_event("contact", contact_data) + broadcast_event( + "contact", + Contact(**contact_upsert.model_dump(exclude_none=True)).model_dump(), + ) # For new contacts, optionally attempt to decrypt any historical DMs we may have stored # This is controlled by the auto_decrypt_dm_on_advert setting diff --git a/app/radio_sync.py b/app/radio_sync.py index 1fb20cc..f6eb876 100644 --- a/app/radio_sync.py +++ b/app/radio_sync.py @@ -17,7 +17,7 @@ from contextlib import asynccontextmanager from meshcore import EventType, MeshCore from app.event_handlers import cleanup_expired_acks -from app.models import Contact +from app.models import Contact, ContactUpsert from app.radio import RadioOperationBusyError from app.repository import ( AmbiguousPublicKeyPrefixError, @@ -155,7 +155,7 @@ async def sync_and_offload_contacts(mc: MeshCore) -> dict: for public_key, contact_data in contacts.items(): # Save to database await ContactRepository.upsert( - Contact.from_radio_dict(public_key, contact_data, on_radio=False) + ContactUpsert.from_radio_dict(public_key, contact_data, on_radio=False) ) await reconcile_contact_messages( public_key=public_key, diff --git a/app/repository/contacts.py b/app/repository/contacts.py index 24c3e22..94c69ee 100644 --- a/app/repository/contacts.py +++ b/app/repository/contacts.py @@ -1,4 +1,5 @@ import time +from collections.abc import Mapping from typing import Any from app.database import db @@ -7,6 +8,7 @@ from app.models import ( ContactAdvertPath, ContactAdvertPathSummary, ContactNameHistory, + ContactUpsert, ) from app.path_utils import first_hop_hex, normalize_contact_route, normalize_route_override @@ -22,17 +24,28 @@ class AmbiguousPublicKeyPrefixError(ValueError): class ContactRepository: @staticmethod - async def upsert(contact: dict[str, Any]) -> None: + def _coerce_contact_upsert( + contact: ContactUpsert | Contact | Mapping[str, Any], + ) -> ContactUpsert: + if isinstance(contact, ContactUpsert): + return contact + if isinstance(contact, Contact): + return contact.to_upsert() + return ContactUpsert.model_validate(contact) + + @staticmethod + async def upsert(contact: ContactUpsert | Contact | Mapping[str, Any]) -> None: + contact_row = ContactRepository._coerce_contact_upsert(contact) last_path, last_path_len, out_path_hash_mode = normalize_contact_route( - contact.get("last_path"), - contact.get("last_path_len", -1), - contact.get("out_path_hash_mode"), + contact_row.last_path, + contact_row.last_path_len, + contact_row.out_path_hash_mode, ) route_override_path, route_override_len, route_override_hash_mode = ( normalize_route_override( - contact.get("route_override_path"), - contact.get("route_override_len"), - contact.get("route_override_hash_mode"), + contact_row.route_override_path, + contact_row.route_override_len, + contact_row.route_override_hash_mode, ) ) @@ -70,23 +83,23 @@ class ContactRepository: first_seen = COALESCE(contacts.first_seen, excluded.first_seen) """, ( - contact.get("public_key", "").lower(), - contact.get("name"), - contact.get("type", 0), - contact.get("flags", 0), + contact_row.public_key.lower(), + contact_row.name, + contact_row.type, + contact_row.flags, last_path, last_path_len, out_path_hash_mode, route_override_path, route_override_len, route_override_hash_mode, - contact.get("last_advert"), - contact.get("lat"), - contact.get("lon"), - contact.get("last_seen", int(time.time())), - contact.get("on_radio"), - contact.get("last_contacted"), - contact.get("first_seen"), + contact_row.last_advert, + contact_row.lat, + contact_row.lon, + contact_row.last_seen if contact_row.last_seen is not None else int(time.time()), + contact_row.on_radio, + contact_row.last_contacted, + contact_row.first_seen, ), ) await db.conn.commit() diff --git a/app/routers/contacts.py b/app/routers/contacts.py index c62f16f..3168736 100644 --- a/app/routers/contacts.py +++ b/app/routers/contacts.py @@ -12,6 +12,7 @@ from app.models import ( ContactAdvertPathSummary, ContactDetail, ContactRoutingOverrideRequest, + ContactUpsert, CreateContactRequest, NearestRepeater, TraceResponse, @@ -133,23 +134,7 @@ async def create_contact( if existing: # Update name if provided if request.name: - await ContactRepository.upsert( - { - "public_key": existing.public_key, - "name": request.name, - "type": existing.type, - "flags": existing.flags, - "last_path": existing.last_path, - "last_path_len": existing.last_path_len, - "out_path_hash_mode": existing.out_path_hash_mode, - "last_advert": existing.last_advert, - "lat": existing.lat, - "lon": existing.lon, - "last_seen": existing.last_seen, - "on_radio": existing.on_radio, - "last_contacted": existing.last_contacted, - } - ) + await ContactRepository.upsert(existing.to_upsert(name=request.name)) refreshed = await ContactRepository.get_by_key(request.public_key) if refreshed is not None: existing = refreshed @@ -164,22 +149,13 @@ async def create_contact( # Create new contact lower_key = request.public_key.lower() - contact_data = { - "public_key": lower_key, - "name": request.name, - "type": 0, # Unknown - "flags": 0, - "last_path": None, - "last_path_len": -1, - "out_path_hash_mode": -1, - "last_advert": None, - "lat": None, - "lon": None, - "last_seen": None, - "on_radio": False, - "last_contacted": None, - } - await ContactRepository.upsert(contact_data) + contact_upsert = ContactUpsert( + public_key=lower_key, + name=request.name, + out_path_hash_mode=-1, + on_radio=False, + ) + await ContactRepository.upsert(contact_upsert) logger.info("Created contact %s", lower_key[:12]) await reconcile_contact_messages( @@ -192,7 +168,7 @@ async def create_contact( if request.try_historical: await start_historical_dm_decryption(background_tasks, lower_key, request.name) - return Contact(**contact_data) + return Contact(**contact_upsert.model_dump()) @router.get("/{public_key}/detail", response_model=ContactDetail) @@ -309,7 +285,7 @@ async def sync_contacts_from_radio() -> dict: for public_key, contact_data in contacts.items(): lower_key = public_key.lower() await ContactRepository.upsert( - Contact.from_radio_dict(lower_key, contact_data, on_radio=True) + ContactUpsert.from_radio_dict(lower_key, contact_data, on_radio=True) ) synced_keys.append(lower_key) await reconcile_contact_messages( diff --git a/app/services/radio_runtime.py b/app/services/radio_runtime.py index eb10251..bb2b700 100644 --- a/app/services/radio_runtime.py +++ b/app/services/radio_runtime.py @@ -73,8 +73,10 @@ class RadioRuntime: """Return MeshCore when available, mirroring existing HTTP semantics.""" if self.is_setup_in_progress: raise HTTPException(status_code=503, detail="Radio is initializing") + if not self.is_connected: + raise HTTPException(status_code=503, detail="Radio not connected") mc = self.meshcore - if not self.is_connected or mc is None: + if mc is None: raise HTTPException(status_code=503, detail="Radio not connected") return mc diff --git a/tests/test_radio_runtime_service.py b/tests/test_radio_runtime_service.py index 8af1c84..41ae42e 100644 --- a/tests/test_radio_runtime_service.py +++ b/tests/test_radio_runtime_service.py @@ -65,6 +65,29 @@ def test_require_connected_preserves_http_semantics(): assert exc.value.status_code == 503 +def test_require_connected_returns_fresh_meshcore_after_connectivity_check(): + old_meshcore = object() + new_meshcore = object() + + class _SwappingManager: + def __init__(self): + self._meshcore = old_meshcore + self.is_setup_in_progress = False + + @property + def is_connected(self): + self._meshcore = new_meshcore + return True + + @property + def meshcore(self): + return self._meshcore + + runtime = RadioRuntime(_SwappingManager()) + + assert runtime.require_connected() is new_meshcore + + @pytest.mark.asyncio async def test_radio_operation_delegates_to_current_manager(): manager = _Manager(meshcore="meshcore", is_connected=True) diff --git a/tests/test_repository.py b/tests/test_repository.py index 7b137c9..9dce1ce 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from app.models import Contact, ContactUpsert from app.repository import ( ContactAdvertPathRepository, ContactNameHistoryRepository, @@ -643,3 +644,34 @@ class TestMessageRepositoryGetById: result = await MessageRepository.get_by_id(999999) assert result is None + + +class TestContactRepositoryUpsertContracts: + @pytest.mark.asyncio + async def test_accepts_contact_upsert_model(self, test_db): + await ContactRepository.upsert( + ContactUpsert(public_key="aa" * 32, name="Alice", type=1, on_radio=False) + ) + + contact = await ContactRepository.get_by_key("aa" * 32) + assert contact is not None + assert contact.name == "Alice" + assert contact.type == 1 + + @pytest.mark.asyncio + async def test_accepts_contact_model(self, test_db): + await ContactRepository.upsert( + Contact( + public_key="bb" * 32, + name="Bob", + type=2, + on_radio=True, + out_path_hash_mode=-1, + ) + ) + + contact = await ContactRepository.get_by_key("bb" * 32) + assert contact is not None + assert contact.name == "Bob" + assert contact.type == 2 + assert contact.on_radio is True