Be better about DB insertion shape

This commit is contained in:
Jack Kingsman
2026-03-09 23:42:46 -07:00
parent 3e941a5b20
commit 18e1408292
9 changed files with 202 additions and 104 deletions

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
from meshcore import EventType
from app.models import CONTACT_TYPE_REPEATER, Contact
from app.models import CONTACT_TYPE_REPEATER, Contact, ContactUpsert
from app.packet_processor import process_raw_packet
from app.repository import (
AmbiguousPublicKeyPrefixError,
@@ -228,11 +228,9 @@ async def on_new_contact(event: "Event") -> None:
logger.debug("New contact: %s", public_key[:12])
contact_data = {
**Contact.from_radio_dict(public_key.lower(), payload, on_radio=True),
"last_seen": int(time.time()),
}
await ContactRepository.upsert(contact_data)
contact_upsert = ContactUpsert.from_radio_dict(public_key.lower(), payload, on_radio=True)
contact_upsert.last_seen = int(time.time())
await ContactRepository.upsert(contact_upsert)
adv_name = payload.get("adv_name")
await record_contact_name_and_reconcile(
@@ -245,7 +243,14 @@ async def on_new_contact(event: "Event") -> None:
# Read back from DB so the broadcast includes all fields (last_contacted,
# last_read_at, etc.) matching the REST Contact shape exactly.
db_contact = await ContactRepository.get_by_key(public_key)
broadcast_event("contact", (db_contact.model_dump() if db_contact else contact_data))
broadcast_event(
"contact",
(
db_contact.model_dump()
if db_contact
else Contact(**contact_upsert.model_dump(exclude_none=True)).model_dump()
),
)
async def on_ack(event: "Event") -> None:

View File

@@ -5,6 +5,64 @@ from pydantic import BaseModel, Field
from app.path_utils import normalize_contact_route
class ContactUpsert(BaseModel):
"""Typed write contract for contacts persisted to SQLite."""
public_key: str = Field(description="Public key (64-char hex)")
name: str | None = None
type: int = 0
flags: int = 0
last_path: str | None = None
last_path_len: int = -1
out_path_hash_mode: int | None = None
route_override_path: str | None = None
route_override_len: int | None = None
route_override_hash_mode: int | None = None
last_advert: int | None = None
lat: float | None = None
lon: float | None = None
last_seen: int | None = None
on_radio: bool | None = None
last_contacted: int | None = None
first_seen: int | None = None
@classmethod
def from_contact(cls, contact: "Contact", **changes) -> "ContactUpsert":
return cls.model_validate(
{
**contact.model_dump(exclude={"last_read_at"}),
**changes,
}
)
@classmethod
def from_radio_dict(
cls, public_key: str, radio_data: dict, on_radio: bool = False
) -> "ContactUpsert":
"""Convert radio contact data to the contact-row write shape."""
last_path, last_path_len, out_path_hash_mode = normalize_contact_route(
radio_data.get("out_path"),
radio_data.get("out_path_len", -1),
radio_data.get(
"out_path_hash_mode",
-1 if radio_data.get("out_path_len", -1) == -1 else 0,
),
)
return cls(
public_key=public_key,
name=radio_data.get("adv_name"),
type=radio_data.get("type", 0),
flags=radio_data.get("flags", 0),
last_path=last_path,
last_path_len=last_path_len,
out_path_hash_mode=out_path_hash_mode,
lat=radio_data.get("adv_lat"),
lon=radio_data.get("adv_lon"),
last_advert=radio_data.get("last_advert"),
on_radio=on_radio,
)
class Contact(BaseModel):
public_key: str = Field(description="Public key (64-char hex)")
name: str | None = None
@@ -61,34 +119,18 @@ class Contact(BaseModel):
"last_advert": self.last_advert if self.last_advert is not None else 0,
}
def to_upsert(self, **changes) -> ContactUpsert:
"""Convert the stored contact to the repository's write contract."""
return ContactUpsert.from_contact(self, **changes)
@staticmethod
def from_radio_dict(public_key: str, radio_data: dict, on_radio: bool = False) -> dict:
"""Convert radio contact data to database format dict.
This is the inverse of to_radio_dict(), used when syncing contacts
from radio to database.
"""
last_path, last_path_len, out_path_hash_mode = normalize_contact_route(
radio_data.get("out_path"),
radio_data.get("out_path_len", -1),
radio_data.get(
"out_path_hash_mode",
-1 if radio_data.get("out_path_len", -1) == -1 else 0,
),
)
return {
"public_key": public_key,
"name": radio_data.get("adv_name"),
"type": radio_data.get("type", 0),
"flags": radio_data.get("flags", 0),
"last_path": last_path,
"last_path_len": last_path_len,
"out_path_hash_mode": out_path_hash_mode,
"lat": radio_data.get("adv_lat"),
"lon": radio_data.get("adv_lon"),
"last_advert": radio_data.get("last_advert"),
"on_radio": on_radio,
}
"""Backward-compatible dict wrapper over ContactUpsert.from_radio_dict()."""
return ContactUpsert.from_radio_dict(
public_key,
radio_data,
on_radio=on_radio,
).model_dump()
class CreateContactRequest(BaseModel):

View File

@@ -30,6 +30,8 @@ from app.decoder import (
from app.keystore import get_private_key, get_public_key, has_private_key
from app.models import (
CONTACT_TYPE_REPEATER,
Contact,
ContactUpsert,
RawPacketBroadcast,
RawPacketDecryptedInfo,
)
@@ -489,21 +491,21 @@ async def _process_advertisement(
hop_count=new_path_len,
)
contact_data = {
"public_key": advert.public_key.lower(),
"name": advert.name,
"type": contact_type,
"lat": advert.lat,
"lon": advert.lon,
"last_advert": advert.timestamp if advert.timestamp > 0 else timestamp,
"last_seen": timestamp,
"last_path": path_hex,
"last_path_len": path_len,
"out_path_hash_mode": out_path_hash_mode,
"first_seen": timestamp, # COALESCE in upsert preserves existing value
}
contact_upsert = ContactUpsert(
public_key=advert.public_key.lower(),
name=advert.name,
type=contact_type,
lat=advert.lat,
lon=advert.lon,
last_advert=advert.timestamp if advert.timestamp > 0 else timestamp,
last_seen=timestamp,
last_path=path_hex,
last_path_len=path_len,
out_path_hash_mode=out_path_hash_mode,
first_seen=timestamp, # COALESCE in upsert preserves existing value
)
await ContactRepository.upsert(contact_data)
await ContactRepository.upsert(contact_upsert)
await record_contact_name_and_reconcile(
public_key=advert.public_key,
contact_name=advert.name,
@@ -517,7 +519,10 @@ async def _process_advertisement(
if db_contact:
broadcast_event("contact", db_contact.model_dump())
else:
broadcast_event("contact", contact_data)
broadcast_event(
"contact",
Contact(**contact_upsert.model_dump(exclude_none=True)).model_dump(),
)
# For new contacts, optionally attempt to decrypt any historical DMs we may have stored
# This is controlled by the auto_decrypt_dm_on_advert setting

View File

@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from meshcore import EventType, MeshCore
from app.event_handlers import cleanup_expired_acks
from app.models import Contact
from app.models import Contact, ContactUpsert
from app.radio import RadioOperationBusyError
from app.repository import (
AmbiguousPublicKeyPrefixError,
@@ -155,7 +155,7 @@ async def sync_and_offload_contacts(mc: MeshCore) -> dict:
for public_key, contact_data in contacts.items():
# Save to database
await ContactRepository.upsert(
Contact.from_radio_dict(public_key, contact_data, on_radio=False)
ContactUpsert.from_radio_dict(public_key, contact_data, on_radio=False)
)
await reconcile_contact_messages(
public_key=public_key,

View File

@@ -1,4 +1,5 @@
import time
from collections.abc import Mapping
from typing import Any
from app.database import db
@@ -7,6 +8,7 @@ from app.models import (
ContactAdvertPath,
ContactAdvertPathSummary,
ContactNameHistory,
ContactUpsert,
)
from app.path_utils import first_hop_hex, normalize_contact_route, normalize_route_override
@@ -22,17 +24,28 @@ class AmbiguousPublicKeyPrefixError(ValueError):
class ContactRepository:
@staticmethod
async def upsert(contact: dict[str, Any]) -> None:
def _coerce_contact_upsert(
contact: ContactUpsert | Contact | Mapping[str, Any],
) -> ContactUpsert:
if isinstance(contact, ContactUpsert):
return contact
if isinstance(contact, Contact):
return contact.to_upsert()
return ContactUpsert.model_validate(contact)
@staticmethod
async def upsert(contact: ContactUpsert | Contact | Mapping[str, Any]) -> None:
contact_row = ContactRepository._coerce_contact_upsert(contact)
last_path, last_path_len, out_path_hash_mode = normalize_contact_route(
contact.get("last_path"),
contact.get("last_path_len", -1),
contact.get("out_path_hash_mode"),
contact_row.last_path,
contact_row.last_path_len,
contact_row.out_path_hash_mode,
)
route_override_path, route_override_len, route_override_hash_mode = (
normalize_route_override(
contact.get("route_override_path"),
contact.get("route_override_len"),
contact.get("route_override_hash_mode"),
contact_row.route_override_path,
contact_row.route_override_len,
contact_row.route_override_hash_mode,
)
)
@@ -70,23 +83,23 @@ class ContactRepository:
first_seen = COALESCE(contacts.first_seen, excluded.first_seen)
""",
(
contact.get("public_key", "").lower(),
contact.get("name"),
contact.get("type", 0),
contact.get("flags", 0),
contact_row.public_key.lower(),
contact_row.name,
contact_row.type,
contact_row.flags,
last_path,
last_path_len,
out_path_hash_mode,
route_override_path,
route_override_len,
route_override_hash_mode,
contact.get("last_advert"),
contact.get("lat"),
contact.get("lon"),
contact.get("last_seen", int(time.time())),
contact.get("on_radio"),
contact.get("last_contacted"),
contact.get("first_seen"),
contact_row.last_advert,
contact_row.lat,
contact_row.lon,
contact_row.last_seen if contact_row.last_seen is not None else int(time.time()),
contact_row.on_radio,
contact_row.last_contacted,
contact_row.first_seen,
),
)
await db.conn.commit()

View File

@@ -12,6 +12,7 @@ from app.models import (
ContactAdvertPathSummary,
ContactDetail,
ContactRoutingOverrideRequest,
ContactUpsert,
CreateContactRequest,
NearestRepeater,
TraceResponse,
@@ -133,23 +134,7 @@ async def create_contact(
if existing:
# Update name if provided
if request.name:
await ContactRepository.upsert(
{
"public_key": existing.public_key,
"name": request.name,
"type": existing.type,
"flags": existing.flags,
"last_path": existing.last_path,
"last_path_len": existing.last_path_len,
"out_path_hash_mode": existing.out_path_hash_mode,
"last_advert": existing.last_advert,
"lat": existing.lat,
"lon": existing.lon,
"last_seen": existing.last_seen,
"on_radio": existing.on_radio,
"last_contacted": existing.last_contacted,
}
)
await ContactRepository.upsert(existing.to_upsert(name=request.name))
refreshed = await ContactRepository.get_by_key(request.public_key)
if refreshed is not None:
existing = refreshed
@@ -164,22 +149,13 @@ async def create_contact(
# Create new contact
lower_key = request.public_key.lower()
contact_data = {
"public_key": lower_key,
"name": request.name,
"type": 0, # Unknown
"flags": 0,
"last_path": None,
"last_path_len": -1,
"out_path_hash_mode": -1,
"last_advert": None,
"lat": None,
"lon": None,
"last_seen": None,
"on_radio": False,
"last_contacted": None,
}
await ContactRepository.upsert(contact_data)
contact_upsert = ContactUpsert(
public_key=lower_key,
name=request.name,
out_path_hash_mode=-1,
on_radio=False,
)
await ContactRepository.upsert(contact_upsert)
logger.info("Created contact %s", lower_key[:12])
await reconcile_contact_messages(
@@ -192,7 +168,7 @@ async def create_contact(
if request.try_historical:
await start_historical_dm_decryption(background_tasks, lower_key, request.name)
return Contact(**contact_data)
return Contact(**contact_upsert.model_dump())
@router.get("/{public_key}/detail", response_model=ContactDetail)
@@ -309,7 +285,7 @@ async def sync_contacts_from_radio() -> dict:
for public_key, contact_data in contacts.items():
lower_key = public_key.lower()
await ContactRepository.upsert(
Contact.from_radio_dict(lower_key, contact_data, on_radio=True)
ContactUpsert.from_radio_dict(lower_key, contact_data, on_radio=True)
)
synced_keys.append(lower_key)
await reconcile_contact_messages(

View File

@@ -73,8 +73,10 @@ class RadioRuntime:
"""Return MeshCore when available, mirroring existing HTTP semantics."""
if self.is_setup_in_progress:
raise HTTPException(status_code=503, detail="Radio is initializing")
if not self.is_connected:
raise HTTPException(status_code=503, detail="Radio not connected")
mc = self.meshcore
if not self.is_connected or mc is None:
if mc is None:
raise HTTPException(status_code=503, detail="Radio not connected")
return mc

View File

@@ -65,6 +65,29 @@ def test_require_connected_preserves_http_semantics():
assert exc.value.status_code == 503
def test_require_connected_returns_fresh_meshcore_after_connectivity_check():
old_meshcore = object()
new_meshcore = object()
class _SwappingManager:
def __init__(self):
self._meshcore = old_meshcore
self.is_setup_in_progress = False
@property
def is_connected(self):
self._meshcore = new_meshcore
return True
@property
def meshcore(self):
return self._meshcore
runtime = RadioRuntime(_SwappingManager())
assert runtime.require_connected() is new_meshcore
@pytest.mark.asyncio
async def test_radio_operation_delegates_to_current_manager():
manager = _Manager(meshcore="meshcore", is_connected=True)

View File

@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.models import Contact, ContactUpsert
from app.repository import (
ContactAdvertPathRepository,
ContactNameHistoryRepository,
@@ -643,3 +644,34 @@ class TestMessageRepositoryGetById:
result = await MessageRepository.get_by_id(999999)
assert result is None
class TestContactRepositoryUpsertContracts:
@pytest.mark.asyncio
async def test_accepts_contact_upsert_model(self, test_db):
await ContactRepository.upsert(
ContactUpsert(public_key="aa" * 32, name="Alice", type=1, on_radio=False)
)
contact = await ContactRepository.get_by_key("aa" * 32)
assert contact is not None
assert contact.name == "Alice"
assert contact.type == 1
@pytest.mark.asyncio
async def test_accepts_contact_model(self, test_db):
await ContactRepository.upsert(
Contact(
public_key="bb" * 32,
name="Bob",
type=2,
on_radio=True,
out_path_hash_mode=-1,
)
)
contact = await ContactRepository.get_by_key("bb" * 32)
assert contact is not None
assert contact.name == "Bob"
assert contact.type == 2
assert contact.on_radio is True