Remove some unneeded duplication and fix up reconnection management

This commit is contained in:
Jack Kingsman
2026-01-30 21:03:58 -08:00
parent b6c3e13234
commit 1ea809c4e3
16 changed files with 50 additions and 152 deletions

View File

@@ -11,7 +11,6 @@ class Settings(BaseSettings):
serial_baudrate: int = 115200
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
database_path: str = "data/meshcore.db"
max_radio_contacts: int = 200 # Max non-repeater contacts to keep on radio for DM ACKs
settings = Settings()

View File

@@ -393,21 +393,6 @@ def parse_advertisement(payload: bytes) -> ParsedAdvertisement | None:
)
def try_parse_advertisement(raw_packet: bytes) -> ParsedAdvertisement | None:
"""
Try to parse a raw packet as an advertisement.
Returns parsed advertisement if successful, None otherwise.
"""
packet_info = parse_packet(raw_packet)
if packet_info is None:
return None
if packet_info.payload_type != PayloadType.ADVERT:
return None
return parse_advertisement(packet_info.payload)
# =============================================================================
# Direct Message (TEXT_MESSAGE) Decryption
# =============================================================================

View File

@@ -64,14 +64,6 @@ def has_private_key() -> bool:
return _private_key is not None
def clear_private_key() -> None:
"""Clear the stored private key from memory."""
global _private_key, _public_key
_private_key = None
_public_key = None
logger.info("Private key cleared from keystore")
async def export_and_store_private_key(mc: "MeshCore") -> bool:
"""Export private key from the radio and store it in the keystore.

View File

@@ -103,20 +103,6 @@ class Message(BaseModel):
acked: int = 0
class RawPacket(BaseModel):
"""Raw packet as stored in the database."""
id: int
timestamp: int
data: str = Field(description="Hex-encoded packet data")
message_id: int | None = None
@property
def decrypted(self) -> bool:
"""A packet is decrypted iff it has a linked message_id."""
return self.message_id is not None
class RawPacketDecryptedInfo(BaseModel):
"""Decryption info for a raw packet (when successfully decrypted)."""

View File

@@ -107,6 +107,20 @@ class RadioManager:
self._last_connected: bool = False
self._reconnect_lock: asyncio.Lock | None = None
async def post_connect_setup(self) -> None:
"""Register event handlers, export private key, and start message fetching.
Called after every successful connection or reconnection.
"""
from app.event_handlers import register_event_handlers
from app.keystore import export_and_store_private_key
if self._meshcore:
register_event_handlers(self._meshcore)
await export_and_store_private_key(self._meshcore)
await self._meshcore.start_auto_message_fetching()
logger.info("Post-connect setup complete (handlers, key export, message fetching)")
@property
def meshcore(self) -> MeshCore | None:
return self._meshcore
@@ -229,15 +243,7 @@ class RadioManager:
# Attempt reconnection
await asyncio.sleep(3) # Wait a bit before trying
if await self.reconnect():
# Re-register event handlers after successful reconnect
from app.event_handlers import register_event_handlers
from app.keystore import export_and_store_private_key
if self._meshcore:
register_event_handlers(self._meshcore)
await export_and_store_private_key(self._meshcore)
await self._meshcore.start_auto_message_fetching()
logger.info("Event handlers re-registered after auto-reconnect")
await self.post_connect_setup()
elif not self._last_connected and current_connected:
# Connection restored (might have reconnected automatically)

View File

@@ -16,10 +16,9 @@ from contextlib import asynccontextmanager
from meshcore import EventType
from app.config import settings
from app.models import Contact
from app.radio import radio_manager
from app.repository import ChannelRepository, ContactRepository
from app.repository import AppSettingsRepository, ChannelRepository, ContactRepository
logger = logging.getLogger(__name__)
@@ -347,8 +346,6 @@ async def send_advertisement(force: bool = False) -> bool:
Returns True if successful, False otherwise (including if throttled).
"""
from app.repository import AppSettingsRepository
if not radio_manager.is_connected or radio_manager.meshcore is None:
logger.debug("Cannot send advertisement: radio not connected")
return False
@@ -530,7 +527,8 @@ async def sync_recent_contacts_to_radio(force: bool = False) -> dict:
try:
# Get recent non-repeater contacts from database
max_contacts = settings.max_radio_contacts
app_settings = await AppSettingsRepository.get()
max_contacts = app_settings.max_radio_contacts
contacts = await ContactRepository.get_recent_non_repeaters(limit=max_contacts)
logger.debug("Found %d recent non-repeater contacts to sync", len(contacts))

View File

@@ -15,7 +15,6 @@ from app.models import (
Favorite,
Message,
MessagePath,
RawPacket,
)
logger = logging.getLogger(__name__)
@@ -271,13 +270,6 @@ class MessageRepository:
except (json.JSONDecodeError, TypeError, KeyError):
return None
@staticmethod
def _serialize_paths(paths: list[dict] | None) -> str | None:
"""Serialize paths list to JSON string."""
if not paths:
return None
return json.dumps(paths)
@staticmethod
async def create(
msg_type: str,
@@ -621,29 +613,6 @@ class RawPacketRepository:
)
await db.conn.commit()
@staticmethod
async def get_undecrypted(limit: int = 100) -> list[RawPacket]:
"""Get undecrypted packets (those without a linked message)."""
cursor = await db.conn.execute(
"""
SELECT id, timestamp, data, message_id FROM raw_packets
WHERE message_id IS NULL
ORDER BY timestamp DESC
LIMIT ?
""",
(limit,),
)
rows = await cursor.fetchall()
return [
RawPacket(
id=row["id"],
timestamp=row["timestamp"],
data=row["data"].hex(),
message_id=row["message_id"],
)
for row in rows
]
@staticmethod
async def prune_old_undecrypted(max_age_days: int) -> int:
"""Delete undecrypted packets older than max_age_days. Returns count deleted."""

View File

@@ -189,13 +189,7 @@ async def reboot_radio() -> dict:
success = await radio_manager.reconnect()
if success:
# Re-register event handlers after successful reconnect
from app.event_handlers import register_event_handlers
if radio_manager.meshcore:
register_event_handlers(radio_manager.meshcore)
await radio_manager.meshcore.start_auto_message_fetching()
logger.info("Event handlers re-registered and auto message fetching started")
await radio_manager.post_connect_setup()
return {"status": "ok", "message": "Reconnected successfully", "connected": True}
else:
@@ -228,14 +222,7 @@ async def reconnect_radio() -> dict:
success = await radio_manager.reconnect()
if success:
# Re-register event handlers after successful reconnect
from app.event_handlers import register_event_handlers
if radio_manager.meshcore:
register_event_handlers(radio_manager.meshcore)
# Restart auto message fetching
await radio_manager.meshcore.start_auto_message_fetching()
logger.info("Event handlers re-registered and auto message fetching started")
await radio_manager.post_connect_setup()
return {"status": "ok", "message": "Reconnected successfully", "connected": True}
else:

View File

@@ -130,20 +130,6 @@ async def update_settings(update: AppSettingsUpdate) -> AppSettings:
return await AppSettingsRepository.get()
@router.post("/favorites", response_model=AppSettings)
async def add_favorite(request: FavoriteRequest) -> AppSettings:
"""Add a conversation to favorites."""
logger.info("Adding favorite: %s %s", request.type, request.id[:12])
return await AppSettingsRepository.add_favorite(request.type, request.id)
@router.delete("/favorites", response_model=AppSettings)
async def remove_favorite(request: FavoriteRequest) -> AppSettings:
"""Remove a conversation from favorites."""
logger.info("Removing favorite: %s %s", request.type, request.id[:12])
return await AppSettingsRepository.remove_favorite(request.type, request.id)
@router.post("/favorites/toggle", response_model=AppSettings)
async def toggle_favorite(request: FavoriteRequest) -> AppSettings:
"""Toggle a conversation's favorite status."""

File diff suppressed because one or more lines are too long

View File

@@ -4,13 +4,12 @@ import {
findContactsByPrefix,
calculateDistance,
sortContactsByDistance,
getHopCount,
resolvePath,
formatDistance,
formatHopCounts,
} from '../utils/pathUtils';
import type { Contact, RadioConfig } from '../types';
import { CONTACT_TYPE_REPEATER, CONTACT_TYPE_CLIENT } from '../types';
import { CONTACT_TYPE_REPEATER } from '../types';
// Helper to create mock contacts
function createContact(overrides: Partial<Contact> = {}): Contact {
@@ -90,7 +89,7 @@ describe('findContactsByPrefix', () => {
createContact({
public_key: '1ACCCC' + 'C'.repeat(52),
name: 'Client1',
type: CONTACT_TYPE_CLIENT,
type: 1, // client
}),
];
@@ -195,20 +194,6 @@ describe('sortContactsByDistance', () => {
});
});
describe('getHopCount', () => {
it('returns 0 for null/empty', () => {
expect(getHopCount(null)).toBe(0);
expect(getHopCount(undefined)).toBe(0);
expect(getHopCount('')).toBe(0);
});
it('counts hops correctly', () => {
expect(getHopCount('1A')).toBe(1);
expect(getHopCount('1A2B')).toBe(2);
expect(getHopCount('1A2B3C')).toBe(3);
});
});
describe('resolvePath', () => {
const repeater1 = createContact({
public_key: '1A' + 'A'.repeat(62),

View File

@@ -10,7 +10,7 @@
import { describe, it, expect } from 'vitest';
import { parseSenderFromText } from '../utils/messageParser';
import { CONTACT_TYPE_REPEATER, CONTACT_TYPE_CLIENT } from '../types';
import { CONTACT_TYPE_REPEATER } from '../types';
describe('Repeater message sender parsing', () => {
/**
@@ -52,7 +52,7 @@ describe('Repeater message sender parsing', () => {
it('non-repeater messages still get sender parsed', () => {
const channelMessage = 'Alice: Hello everyone!';
const contactType: number = CONTACT_TYPE_CLIENT;
const contactType: number = 1; // client
const isRepeater = contactType === CONTACT_TYPE_REPEATER;
const { sender, content } = isRepeater

View File

@@ -164,7 +164,6 @@ export interface MigratePreferencesResponse {
}
/** Contact type constants */
export const CONTACT_TYPE_CLIENT = 1;
export const CONTACT_TYPE_REPEATER = 2;
export interface NeighborInfo {

View File

@@ -148,7 +148,7 @@ export function sortContactsByDistance(
/**
* Get simple hop count from path string
*/
export function getHopCount(path: string | null | undefined): number {
function getHopCount(path: string | null | undefined): number {
if (!path || path.length === 0) {
return 0;
}

View File

@@ -254,7 +254,7 @@ class TestAdvertisementParsing:
def test_parse_repeater_advertisement_with_gps(self):
"""Parse a repeater advertisement with GPS coordinates."""
from app.decoder import try_parse_advertisement
from app.decoder import parse_advertisement, parse_packet
# Repeater packet with lat/lon of 49.02056 / -123.82935
# Flags 0x92: Role=Repeater (2), Location=Yes, Name=Yes
@@ -266,7 +266,9 @@ class TestAdvertisementParsing:
)
packet = bytes.fromhex(packet_hex)
result = try_parse_advertisement(packet)
info = parse_packet(packet)
assert info is not None
result = parse_advertisement(info.payload)
assert result is not None
assert (
@@ -282,7 +284,7 @@ class TestAdvertisementParsing:
def test_parse_chat_node_advertisement_with_gps(self):
"""Parse a chat node advertisement with GPS coordinates."""
from app.decoder import try_parse_advertisement
from app.decoder import parse_advertisement, parse_packet
# Chat node packet with lat/lon of 47.786038 / -122.344096
# Flags 0x91: Role=Chat (1), Location=Yes, Name=Yes
@@ -294,7 +296,9 @@ class TestAdvertisementParsing:
)
packet = bytes.fromhex(packet_hex)
result = try_parse_advertisement(packet)
info = parse_packet(packet)
assert info is not None
result = parse_advertisement(info.payload)
assert result is not None
assert (
@@ -310,7 +314,7 @@ class TestAdvertisementParsing:
def test_parse_advertisement_without_gps(self):
"""Parse an advertisement without GPS coordinates."""
from app.decoder import try_parse_advertisement
from app.decoder import parse_advertisement, parse_packet
# Chat node packet without location
# Flags 0x81: Role=Chat (1), Location=No, Name=Yes
@@ -322,7 +326,9 @@ class TestAdvertisementParsing:
)
packet = bytes.fromhex(packet_hex)
result = try_parse_advertisement(packet)
info = parse_packet(packet)
assert info is not None
result = parse_advertisement(info.payload)
assert result is not None
assert (
@@ -352,15 +358,15 @@ class TestAdvertisementParsing:
assert info.payload_type == PayloadType.ADVERT
def test_non_advertisement_returns_none(self):
"""Non-advertisement packets return None from try_parse_advertisement."""
from app.decoder import try_parse_advertisement
"""Non-advertisement packets return None when parsed as advertisement."""
from app.decoder import PayloadType, parse_packet
# GROUP_TEXT packet, not an advertisement
packet = bytes([0x15, 0x00]) + bytes(50)
result = try_parse_advertisement(packet)
assert result is None
info = parse_packet(packet)
assert info is not None
assert info.payload_type != PayloadType.ADVERT
class TestScalarClamping:

View File

@@ -168,7 +168,7 @@ class TestChannelMessagePipeline:
assert result is not None
# Raw packet should be stored
raw_packets = await RawPacketRepository.get_undecrypted(limit=10)
raw_packets = await RawPacketRepository.get_all_undecrypted()
assert len(raw_packets) >= 1
# No message broadcast (only raw_packet broadcast)
@@ -576,8 +576,8 @@ class TestCreateMessageFromDecrypted:
)
# Verify packet is marked decrypted (has message_id set)
undecrypted = await RawPacketRepository.get_undecrypted(limit=100)
packet_ids = [p.id for p in undecrypted]
undecrypted = await RawPacketRepository.get_all_undecrypted()
packet_ids = [p[0] for p in undecrypted]
assert packet_id not in packet_ids # Should be marked as decrypted
@@ -831,8 +831,8 @@ class TestCreateDMMessageFromDecrypted:
)
# Verify packet is marked decrypted
undecrypted = await RawPacketRepository.get_undecrypted(limit=100)
packet_ids = [p.id for p in undecrypted]
undecrypted = await RawPacketRepository.get_all_undecrypted()
packet_ids = [p[0] for p in undecrypted]
assert packet_id not in packet_ids
@pytest.mark.asyncio
@@ -939,8 +939,8 @@ class TestDMDecryptionFunction:
assert messages[0].outgoing is False
# Verify raw packet is linked
undecrypted = await RawPacketRepository.get_undecrypted(limit=100)
assert packet_id not in [p.id for p in undecrypted]
undecrypted = await RawPacketRepository.get_all_undecrypted()
assert packet_id not in [p[0] for p in undecrypted]
class TestRepeaterMessageFiltering: