mirror of
https://github.com/jkingsman/Remote-Terminal-for-MeshCore.git
synced 2026-03-28 17:43:05 +01:00
Be better about DB insertion shape
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from meshcore import EventType
|
||||
|
||||
from app.models import CONTACT_TYPE_REPEATER, Contact
|
||||
from app.models import CONTACT_TYPE_REPEATER, Contact, ContactUpsert
|
||||
from app.packet_processor import process_raw_packet
|
||||
from app.repository import (
|
||||
AmbiguousPublicKeyPrefixError,
|
||||
@@ -228,11 +228,9 @@ async def on_new_contact(event: "Event") -> None:
|
||||
|
||||
logger.debug("New contact: %s", public_key[:12])
|
||||
|
||||
contact_data = {
|
||||
**Contact.from_radio_dict(public_key.lower(), payload, on_radio=True),
|
||||
"last_seen": int(time.time()),
|
||||
}
|
||||
await ContactRepository.upsert(contact_data)
|
||||
contact_upsert = ContactUpsert.from_radio_dict(public_key.lower(), payload, on_radio=True)
|
||||
contact_upsert.last_seen = int(time.time())
|
||||
await ContactRepository.upsert(contact_upsert)
|
||||
|
||||
adv_name = payload.get("adv_name")
|
||||
await record_contact_name_and_reconcile(
|
||||
@@ -245,7 +243,14 @@ async def on_new_contact(event: "Event") -> None:
|
||||
# Read back from DB so the broadcast includes all fields (last_contacted,
|
||||
# last_read_at, etc.) matching the REST Contact shape exactly.
|
||||
db_contact = await ContactRepository.get_by_key(public_key)
|
||||
broadcast_event("contact", (db_contact.model_dump() if db_contact else contact_data))
|
||||
broadcast_event(
|
||||
"contact",
|
||||
(
|
||||
db_contact.model_dump()
|
||||
if db_contact
|
||||
else Contact(**contact_upsert.model_dump(exclude_none=True)).model_dump()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def on_ack(event: "Event") -> None:
|
||||
|
||||
@@ -5,6 +5,64 @@ from pydantic import BaseModel, Field
|
||||
from app.path_utils import normalize_contact_route
|
||||
|
||||
|
||||
class ContactUpsert(BaseModel):
|
||||
"""Typed write contract for contacts persisted to SQLite."""
|
||||
|
||||
public_key: str = Field(description="Public key (64-char hex)")
|
||||
name: str | None = None
|
||||
type: int = 0
|
||||
flags: int = 0
|
||||
last_path: str | None = None
|
||||
last_path_len: int = -1
|
||||
out_path_hash_mode: int | None = None
|
||||
route_override_path: str | None = None
|
||||
route_override_len: int | None = None
|
||||
route_override_hash_mode: int | None = None
|
||||
last_advert: int | None = None
|
||||
lat: float | None = None
|
||||
lon: float | None = None
|
||||
last_seen: int | None = None
|
||||
on_radio: bool | None = None
|
||||
last_contacted: int | None = None
|
||||
first_seen: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_contact(cls, contact: "Contact", **changes) -> "ContactUpsert":
|
||||
return cls.model_validate(
|
||||
{
|
||||
**contact.model_dump(exclude={"last_read_at"}),
|
||||
**changes,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_radio_dict(
|
||||
cls, public_key: str, radio_data: dict, on_radio: bool = False
|
||||
) -> "ContactUpsert":
|
||||
"""Convert radio contact data to the contact-row write shape."""
|
||||
last_path, last_path_len, out_path_hash_mode = normalize_contact_route(
|
||||
radio_data.get("out_path"),
|
||||
radio_data.get("out_path_len", -1),
|
||||
radio_data.get(
|
||||
"out_path_hash_mode",
|
||||
-1 if radio_data.get("out_path_len", -1) == -1 else 0,
|
||||
),
|
||||
)
|
||||
return cls(
|
||||
public_key=public_key,
|
||||
name=radio_data.get("adv_name"),
|
||||
type=radio_data.get("type", 0),
|
||||
flags=radio_data.get("flags", 0),
|
||||
last_path=last_path,
|
||||
last_path_len=last_path_len,
|
||||
out_path_hash_mode=out_path_hash_mode,
|
||||
lat=radio_data.get("adv_lat"),
|
||||
lon=radio_data.get("adv_lon"),
|
||||
last_advert=radio_data.get("last_advert"),
|
||||
on_radio=on_radio,
|
||||
)
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
public_key: str = Field(description="Public key (64-char hex)")
|
||||
name: str | None = None
|
||||
@@ -61,34 +119,18 @@ class Contact(BaseModel):
|
||||
"last_advert": self.last_advert if self.last_advert is not None else 0,
|
||||
}
|
||||
|
||||
def to_upsert(self, **changes) -> ContactUpsert:
|
||||
"""Convert the stored contact to the repository's write contract."""
|
||||
return ContactUpsert.from_contact(self, **changes)
|
||||
|
||||
@staticmethod
|
||||
def from_radio_dict(public_key: str, radio_data: dict, on_radio: bool = False) -> dict:
|
||||
"""Convert radio contact data to database format dict.
|
||||
|
||||
This is the inverse of to_radio_dict(), used when syncing contacts
|
||||
from radio to database.
|
||||
"""
|
||||
last_path, last_path_len, out_path_hash_mode = normalize_contact_route(
|
||||
radio_data.get("out_path"),
|
||||
radio_data.get("out_path_len", -1),
|
||||
radio_data.get(
|
||||
"out_path_hash_mode",
|
||||
-1 if radio_data.get("out_path_len", -1) == -1 else 0,
|
||||
),
|
||||
)
|
||||
return {
|
||||
"public_key": public_key,
|
||||
"name": radio_data.get("adv_name"),
|
||||
"type": radio_data.get("type", 0),
|
||||
"flags": radio_data.get("flags", 0),
|
||||
"last_path": last_path,
|
||||
"last_path_len": last_path_len,
|
||||
"out_path_hash_mode": out_path_hash_mode,
|
||||
"lat": radio_data.get("adv_lat"),
|
||||
"lon": radio_data.get("adv_lon"),
|
||||
"last_advert": radio_data.get("last_advert"),
|
||||
"on_radio": on_radio,
|
||||
}
|
||||
"""Backward-compatible dict wrapper over ContactUpsert.from_radio_dict()."""
|
||||
return ContactUpsert.from_radio_dict(
|
||||
public_key,
|
||||
radio_data,
|
||||
on_radio=on_radio,
|
||||
).model_dump()
|
||||
|
||||
|
||||
class CreateContactRequest(BaseModel):
|
||||
|
||||
@@ -30,6 +30,8 @@ from app.decoder import (
|
||||
from app.keystore import get_private_key, get_public_key, has_private_key
|
||||
from app.models import (
|
||||
CONTACT_TYPE_REPEATER,
|
||||
Contact,
|
||||
ContactUpsert,
|
||||
RawPacketBroadcast,
|
||||
RawPacketDecryptedInfo,
|
||||
)
|
||||
@@ -489,21 +491,21 @@ async def _process_advertisement(
|
||||
hop_count=new_path_len,
|
||||
)
|
||||
|
||||
contact_data = {
|
||||
"public_key": advert.public_key.lower(),
|
||||
"name": advert.name,
|
||||
"type": contact_type,
|
||||
"lat": advert.lat,
|
||||
"lon": advert.lon,
|
||||
"last_advert": advert.timestamp if advert.timestamp > 0 else timestamp,
|
||||
"last_seen": timestamp,
|
||||
"last_path": path_hex,
|
||||
"last_path_len": path_len,
|
||||
"out_path_hash_mode": out_path_hash_mode,
|
||||
"first_seen": timestamp, # COALESCE in upsert preserves existing value
|
||||
}
|
||||
contact_upsert = ContactUpsert(
|
||||
public_key=advert.public_key.lower(),
|
||||
name=advert.name,
|
||||
type=contact_type,
|
||||
lat=advert.lat,
|
||||
lon=advert.lon,
|
||||
last_advert=advert.timestamp if advert.timestamp > 0 else timestamp,
|
||||
last_seen=timestamp,
|
||||
last_path=path_hex,
|
||||
last_path_len=path_len,
|
||||
out_path_hash_mode=out_path_hash_mode,
|
||||
first_seen=timestamp, # COALESCE in upsert preserves existing value
|
||||
)
|
||||
|
||||
await ContactRepository.upsert(contact_data)
|
||||
await ContactRepository.upsert(contact_upsert)
|
||||
await record_contact_name_and_reconcile(
|
||||
public_key=advert.public_key,
|
||||
contact_name=advert.name,
|
||||
@@ -517,7 +519,10 @@ async def _process_advertisement(
|
||||
if db_contact:
|
||||
broadcast_event("contact", db_contact.model_dump())
|
||||
else:
|
||||
broadcast_event("contact", contact_data)
|
||||
broadcast_event(
|
||||
"contact",
|
||||
Contact(**contact_upsert.model_dump(exclude_none=True)).model_dump(),
|
||||
)
|
||||
|
||||
# For new contacts, optionally attempt to decrypt any historical DMs we may have stored
|
||||
# This is controlled by the auto_decrypt_dm_on_advert setting
|
||||
|
||||
@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
|
||||
from meshcore import EventType, MeshCore
|
||||
|
||||
from app.event_handlers import cleanup_expired_acks
|
||||
from app.models import Contact
|
||||
from app.models import Contact, ContactUpsert
|
||||
from app.radio import RadioOperationBusyError
|
||||
from app.repository import (
|
||||
AmbiguousPublicKeyPrefixError,
|
||||
@@ -155,7 +155,7 @@ async def sync_and_offload_contacts(mc: MeshCore) -> dict:
|
||||
for public_key, contact_data in contacts.items():
|
||||
# Save to database
|
||||
await ContactRepository.upsert(
|
||||
Contact.from_radio_dict(public_key, contact_data, on_radio=False)
|
||||
ContactUpsert.from_radio_dict(public_key, contact_data, on_radio=False)
|
||||
)
|
||||
await reconcile_contact_messages(
|
||||
public_key=public_key,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from app.database import db
|
||||
@@ -7,6 +8,7 @@ from app.models import (
|
||||
ContactAdvertPath,
|
||||
ContactAdvertPathSummary,
|
||||
ContactNameHistory,
|
||||
ContactUpsert,
|
||||
)
|
||||
from app.path_utils import first_hop_hex, normalize_contact_route, normalize_route_override
|
||||
|
||||
@@ -22,17 +24,28 @@ class AmbiguousPublicKeyPrefixError(ValueError):
|
||||
|
||||
class ContactRepository:
|
||||
@staticmethod
|
||||
async def upsert(contact: dict[str, Any]) -> None:
|
||||
def _coerce_contact_upsert(
|
||||
contact: ContactUpsert | Contact | Mapping[str, Any],
|
||||
) -> ContactUpsert:
|
||||
if isinstance(contact, ContactUpsert):
|
||||
return contact
|
||||
if isinstance(contact, Contact):
|
||||
return contact.to_upsert()
|
||||
return ContactUpsert.model_validate(contact)
|
||||
|
||||
@staticmethod
|
||||
async def upsert(contact: ContactUpsert | Contact | Mapping[str, Any]) -> None:
|
||||
contact_row = ContactRepository._coerce_contact_upsert(contact)
|
||||
last_path, last_path_len, out_path_hash_mode = normalize_contact_route(
|
||||
contact.get("last_path"),
|
||||
contact.get("last_path_len", -1),
|
||||
contact.get("out_path_hash_mode"),
|
||||
contact_row.last_path,
|
||||
contact_row.last_path_len,
|
||||
contact_row.out_path_hash_mode,
|
||||
)
|
||||
route_override_path, route_override_len, route_override_hash_mode = (
|
||||
normalize_route_override(
|
||||
contact.get("route_override_path"),
|
||||
contact.get("route_override_len"),
|
||||
contact.get("route_override_hash_mode"),
|
||||
contact_row.route_override_path,
|
||||
contact_row.route_override_len,
|
||||
contact_row.route_override_hash_mode,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -70,23 +83,23 @@ class ContactRepository:
|
||||
first_seen = COALESCE(contacts.first_seen, excluded.first_seen)
|
||||
""",
|
||||
(
|
||||
contact.get("public_key", "").lower(),
|
||||
contact.get("name"),
|
||||
contact.get("type", 0),
|
||||
contact.get("flags", 0),
|
||||
contact_row.public_key.lower(),
|
||||
contact_row.name,
|
||||
contact_row.type,
|
||||
contact_row.flags,
|
||||
last_path,
|
||||
last_path_len,
|
||||
out_path_hash_mode,
|
||||
route_override_path,
|
||||
route_override_len,
|
||||
route_override_hash_mode,
|
||||
contact.get("last_advert"),
|
||||
contact.get("lat"),
|
||||
contact.get("lon"),
|
||||
contact.get("last_seen", int(time.time())),
|
||||
contact.get("on_radio"),
|
||||
contact.get("last_contacted"),
|
||||
contact.get("first_seen"),
|
||||
contact_row.last_advert,
|
||||
contact_row.lat,
|
||||
contact_row.lon,
|
||||
contact_row.last_seen if contact_row.last_seen is not None else int(time.time()),
|
||||
contact_row.on_radio,
|
||||
contact_row.last_contacted,
|
||||
contact_row.first_seen,
|
||||
),
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.models import (
|
||||
ContactAdvertPathSummary,
|
||||
ContactDetail,
|
||||
ContactRoutingOverrideRequest,
|
||||
ContactUpsert,
|
||||
CreateContactRequest,
|
||||
NearestRepeater,
|
||||
TraceResponse,
|
||||
@@ -133,23 +134,7 @@ async def create_contact(
|
||||
if existing:
|
||||
# Update name if provided
|
||||
if request.name:
|
||||
await ContactRepository.upsert(
|
||||
{
|
||||
"public_key": existing.public_key,
|
||||
"name": request.name,
|
||||
"type": existing.type,
|
||||
"flags": existing.flags,
|
||||
"last_path": existing.last_path,
|
||||
"last_path_len": existing.last_path_len,
|
||||
"out_path_hash_mode": existing.out_path_hash_mode,
|
||||
"last_advert": existing.last_advert,
|
||||
"lat": existing.lat,
|
||||
"lon": existing.lon,
|
||||
"last_seen": existing.last_seen,
|
||||
"on_radio": existing.on_radio,
|
||||
"last_contacted": existing.last_contacted,
|
||||
}
|
||||
)
|
||||
await ContactRepository.upsert(existing.to_upsert(name=request.name))
|
||||
refreshed = await ContactRepository.get_by_key(request.public_key)
|
||||
if refreshed is not None:
|
||||
existing = refreshed
|
||||
@@ -164,22 +149,13 @@ async def create_contact(
|
||||
|
||||
# Create new contact
|
||||
lower_key = request.public_key.lower()
|
||||
contact_data = {
|
||||
"public_key": lower_key,
|
||||
"name": request.name,
|
||||
"type": 0, # Unknown
|
||||
"flags": 0,
|
||||
"last_path": None,
|
||||
"last_path_len": -1,
|
||||
"out_path_hash_mode": -1,
|
||||
"last_advert": None,
|
||||
"lat": None,
|
||||
"lon": None,
|
||||
"last_seen": None,
|
||||
"on_radio": False,
|
||||
"last_contacted": None,
|
||||
}
|
||||
await ContactRepository.upsert(contact_data)
|
||||
contact_upsert = ContactUpsert(
|
||||
public_key=lower_key,
|
||||
name=request.name,
|
||||
out_path_hash_mode=-1,
|
||||
on_radio=False,
|
||||
)
|
||||
await ContactRepository.upsert(contact_upsert)
|
||||
logger.info("Created contact %s", lower_key[:12])
|
||||
|
||||
await reconcile_contact_messages(
|
||||
@@ -192,7 +168,7 @@ async def create_contact(
|
||||
if request.try_historical:
|
||||
await start_historical_dm_decryption(background_tasks, lower_key, request.name)
|
||||
|
||||
return Contact(**contact_data)
|
||||
return Contact(**contact_upsert.model_dump())
|
||||
|
||||
|
||||
@router.get("/{public_key}/detail", response_model=ContactDetail)
|
||||
@@ -309,7 +285,7 @@ async def sync_contacts_from_radio() -> dict:
|
||||
for public_key, contact_data in contacts.items():
|
||||
lower_key = public_key.lower()
|
||||
await ContactRepository.upsert(
|
||||
Contact.from_radio_dict(lower_key, contact_data, on_radio=True)
|
||||
ContactUpsert.from_radio_dict(lower_key, contact_data, on_radio=True)
|
||||
)
|
||||
synced_keys.append(lower_key)
|
||||
await reconcile_contact_messages(
|
||||
|
||||
@@ -73,8 +73,10 @@ class RadioRuntime:
|
||||
"""Return MeshCore when available, mirroring existing HTTP semantics."""
|
||||
if self.is_setup_in_progress:
|
||||
raise HTTPException(status_code=503, detail="Radio is initializing")
|
||||
if not self.is_connected:
|
||||
raise HTTPException(status_code=503, detail="Radio not connected")
|
||||
mc = self.meshcore
|
||||
if not self.is_connected or mc is None:
|
||||
if mc is None:
|
||||
raise HTTPException(status_code=503, detail="Radio not connected")
|
||||
return mc
|
||||
|
||||
|
||||
@@ -65,6 +65,29 @@ def test_require_connected_preserves_http_semantics():
|
||||
assert exc.value.status_code == 503
|
||||
|
||||
|
||||
def test_require_connected_returns_fresh_meshcore_after_connectivity_check():
|
||||
old_meshcore = object()
|
||||
new_meshcore = object()
|
||||
|
||||
class _SwappingManager:
|
||||
def __init__(self):
|
||||
self._meshcore = old_meshcore
|
||||
self.is_setup_in_progress = False
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
self._meshcore = new_meshcore
|
||||
return True
|
||||
|
||||
@property
|
||||
def meshcore(self):
|
||||
return self._meshcore
|
||||
|
||||
runtime = RadioRuntime(_SwappingManager())
|
||||
|
||||
assert runtime.require_connected() is new_meshcore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_radio_operation_delegates_to_current_manager():
|
||||
manager = _Manager(meshcore="meshcore", is_connected=True)
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models import Contact, ContactUpsert
|
||||
from app.repository import (
|
||||
ContactAdvertPathRepository,
|
||||
ContactNameHistoryRepository,
|
||||
@@ -643,3 +644,34 @@ class TestMessageRepositoryGetById:
|
||||
result = await MessageRepository.get_by_id(999999)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestContactRepositoryUpsertContracts:
|
||||
@pytest.mark.asyncio
|
||||
async def test_accepts_contact_upsert_model(self, test_db):
|
||||
await ContactRepository.upsert(
|
||||
ContactUpsert(public_key="aa" * 32, name="Alice", type=1, on_radio=False)
|
||||
)
|
||||
|
||||
contact = await ContactRepository.get_by_key("aa" * 32)
|
||||
assert contact is not None
|
||||
assert contact.name == "Alice"
|
||||
assert contact.type == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accepts_contact_model(self, test_db):
|
||||
await ContactRepository.upsert(
|
||||
Contact(
|
||||
public_key="bb" * 32,
|
||||
name="Bob",
|
||||
type=2,
|
||||
on_radio=True,
|
||||
out_path_hash_mode=-1,
|
||||
)
|
||||
)
|
||||
|
||||
contact = await ContactRepository.get_by_key("bb" * 32)
|
||||
assert contact is not None
|
||||
assert contact.name == "Bob"
|
||||
assert contact.type == 2
|
||||
assert contact.on_radio is True
|
||||
|
||||
Reference in New Issue
Block a user