forked from iarv/Remote-Terminal-for-MeshCore
Add server-side read management
This commit is contained in:
33
CLAUDE.md
33
CLAUDE.md
@@ -152,7 +152,8 @@ Key test files:
|
||||
- `tests/test_decoder.py` - Channel + direct message decryption, key exchange
|
||||
- `tests/test_keystore.py` - Ephemeral key store
|
||||
- `tests/test_event_handlers.py` - ACK tracking, repeat detection
|
||||
- `tests/test_api.py` - API endpoints
|
||||
- `tests/test_api.py` - API endpoints, read state tracking
|
||||
- `tests/test_migrations.py` - Database migration system
|
||||
|
||||
### Frontend (Vitest)
|
||||
|
||||
@@ -165,6 +166,23 @@ npm run test:run
|
||||
|
||||
Open `integration_test.html` in a browser with the backend running.
|
||||
|
||||
### Before Completing Changes
|
||||
|
||||
**Always run both backend and frontend validation before finishing any changes:**
|
||||
|
||||
```bash
|
||||
# From project root - run backend tests
|
||||
PYTHONPATH=. uv run pytest tests/ -v
|
||||
|
||||
# From project root - run frontend tests and build
|
||||
cd frontend && npm run test:run && npm run build
|
||||
```
|
||||
|
||||
This catches:
|
||||
- Type mismatches between frontend and backend (e.g., missing fields in TypeScript interfaces)
|
||||
- Breaking changes to shared types or API contracts
|
||||
- Runtime errors that only surface during compilation
|
||||
|
||||
## API Summary
|
||||
|
||||
All endpoints are prefixed with `/api` (e.g., `/api/health`).
|
||||
@@ -188,6 +206,9 @@ All endpoints are prefixed with `/api` (e.g., `/api/health`).
|
||||
| POST | `/api/messages/direct` | Send direct message |
|
||||
| POST | `/api/messages/channel` | Send channel message |
|
||||
| POST | `/api/packets/decrypt/historical` | Decrypt stored packets |
|
||||
| POST | `/api/contacts/{key}/mark-read` | Mark contact conversation as read |
|
||||
| POST | `/api/channels/{key}/mark-read` | Mark channel as read |
|
||||
| POST | `/api/read-state/mark-all-read` | Mark all conversations as read |
|
||||
| GET | `/api/settings` | Get app settings |
|
||||
| PATCH | `/api/settings` | Update app settings |
|
||||
| WS | `/api/ws` | Real-time updates |
|
||||
@@ -219,9 +240,15 @@ All endpoints are prefixed with `/api` (e.g., `/api/health`).
|
||||
- `CHAN` - Channel messages
|
||||
- Both use `conversation_key` (user pubkey for PRIV, channel key for CHAN)
|
||||
|
||||
### State Tracking Keys (Frontend)
|
||||
### Read State Tracking
|
||||
|
||||
Generated by `getStateKey()` for unread tracking and message times:
|
||||
Read state (`last_read_at`) is tracked **server-side** for consistency across devices:
|
||||
- Stored as Unix timestamp in `contacts.last_read_at` and `channels.last_read_at`
|
||||
- Updated via `POST /api/contacts/{key}/mark-read` and `POST /api/channels/{key}/mark-read`
|
||||
- Bulk update via `POST /api/read-state/mark-all-read`
|
||||
- Frontend compares `last_read_at` with message `received_at` to count unreads
|
||||
|
||||
**State Tracking Keys (Frontend)**: Generated by `getStateKey()` for message times (sidebar sorting):
|
||||
- Channels: `channel-{channel_key}`
|
||||
- Contacts: `contact-{12-char-pubkey-prefix}`
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ This document provides context for AI assistants and developers working on the F
|
||||
app/
|
||||
├── main.py # FastAPI app, lifespan, router registration, static file serving
|
||||
├── config.py # Pydantic settings (env vars: MESHCORE_*)
|
||||
├── database.py # SQLite schema, connection management
|
||||
├── database.py # SQLite schema, connection management, runs migrations
|
||||
├── migrations.py # Database migrations using SQLite user_version pragma
|
||||
├── models.py # Pydantic models for API request/response
|
||||
├── repository.py # Database CRUD (ContactRepository, ChannelRepository, etc.)
|
||||
├── radio.py # RadioManager - serial connection to MeshCore device
|
||||
@@ -30,10 +31,11 @@ app/
|
||||
└── routers/ # All routes prefixed with /api
|
||||
├── health.py # GET /api/health
|
||||
├── radio.py # Radio config, advertise, private key, reboot
|
||||
├── contacts.py # Contact CRUD and radio sync
|
||||
├── channels.py # Channel CRUD and radio sync
|
||||
├── contacts.py # Contact CRUD, radio sync, mark-read
|
||||
├── channels.py # Channel CRUD, radio sync, mark-read
|
||||
├── messages.py # Message list and send (direct/channel)
|
||||
├── packets.py # Raw packet endpoints, historical decryption
|
||||
├── read_state.py # Bulk read state operations (mark-all-read)
|
||||
├── settings.py # App settings (max_radio_contacts)
|
||||
└── ws.py # WebSocket endpoint at /api/ws
|
||||
```
|
||||
@@ -138,14 +140,16 @@ contacts (
|
||||
lat REAL, lon REAL,
|
||||
last_seen INTEGER,
|
||||
on_radio INTEGER DEFAULT 0, -- Boolean: contact loaded on radio
|
||||
last_contacted INTEGER -- Unix timestamp of last message sent/received
|
||||
last_contacted INTEGER, -- Unix timestamp of last message sent/received
|
||||
last_read_at INTEGER -- Unix timestamp when conversation was last read
|
||||
)
|
||||
|
||||
channels (
|
||||
key TEXT PRIMARY KEY, -- 32-char hex channel key
|
||||
name TEXT NOT NULL,
|
||||
is_hashtag INTEGER DEFAULT 0, -- Key derived from SHA256(name)[:16]
|
||||
on_radio INTEGER DEFAULT 0
|
||||
on_radio INTEGER DEFAULT 0,
|
||||
last_read_at INTEGER -- Unix timestamp when channel was last read
|
||||
)
|
||||
|
||||
messages (
|
||||
@@ -175,6 +179,49 @@ raw_packets (
|
||||
)
|
||||
```
|
||||
|
||||
## Database Migrations (`migrations.py`)
|
||||
|
||||
Schema migrations use SQLite's `user_version` pragma for version tracking:
|
||||
|
||||
```python
|
||||
from app.migrations import get_version, set_version, run_migrations
|
||||
|
||||
# Check current schema version
|
||||
version = await get_version(conn) # Returns int (0 for new/unmigrated DB)
|
||||
|
||||
# Run pending migrations (called automatically on startup)
|
||||
applied = await run_migrations(conn) # Returns number of migrations applied
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
1. `database.py` calls `run_migrations()` after schema initialization
|
||||
2. Each migration checks `user_version` and runs if needed
|
||||
3. Migrations are idempotent (safe to run multiple times)
|
||||
4. `ALTER TABLE ADD COLUMN` handles existing columns gracefully
|
||||
|
||||
### Adding a New Migration
|
||||
|
||||
```python
|
||||
# In migrations.py
|
||||
async def run_migrations(conn: aiosqlite.Connection) -> int:
|
||||
version = await get_version(conn)
|
||||
applied = 0
|
||||
|
||||
if version < 1:
|
||||
await _migrate_001_add_last_read_at(conn)
|
||||
await set_version(conn, 1)
|
||||
applied += 1
|
||||
|
||||
# Add new migrations here:
|
||||
# if version < 2:
|
||||
# await _migrate_002_something(conn)
|
||||
# await set_version(conn, 2)
|
||||
# applied += 1
|
||||
|
||||
return applied
|
||||
```
|
||||
|
||||
## Packet Decryption (`decoder.py`)
|
||||
|
||||
The decoder handles MeshCore packet decryption for historical packet analysis:
|
||||
@@ -326,6 +373,7 @@ All endpoints are prefixed with `/api`.
|
||||
- `POST /api/contacts/sync` - Pull from radio to database
|
||||
- `POST /api/contacts/{key}/add-to-radio` - Push to radio
|
||||
- `POST /api/contacts/{key}/remove-from-radio` - Remove from radio
|
||||
- `POST /api/contacts/{key}/mark-read` - Mark conversation as read (updates last_read_at)
|
||||
- `POST /api/contacts/{key}/telemetry` - Request telemetry from repeater (see below)
|
||||
|
||||
### Channels
|
||||
@@ -333,8 +381,12 @@ All endpoints are prefixed with `/api`.
|
||||
- `GET /api/channels/{key}` - Get by channel key
|
||||
- `POST /api/channels` - Create (hashtag if name starts with # or no key provided)
|
||||
- `POST /api/channels/sync` - Pull from radio
|
||||
- `POST /api/channels/{key}/mark-read` - Mark channel as read (updates last_read_at)
|
||||
- `DELETE /api/channels/{key}` - Delete channel
|
||||
|
||||
### Read State
|
||||
- `POST /api/read-state/mark-all-read` - Mark all contacts and channels as read
|
||||
|
||||
### Messages
|
||||
- `GET /api/messages?type=&conversation_key=&limit=&offset=` - List with filters
|
||||
- `POST /api/messages/direct` - Send direct message
|
||||
@@ -368,7 +420,8 @@ Key test files:
|
||||
- `tests/test_decoder.py` - Channel + direct message decryption, key exchange, real-world test vectors
|
||||
- `tests/test_keystore.py` - Ephemeral key store operations
|
||||
- `tests/test_event_handlers.py` - ACK tracking, repeat detection, CLI response filtering
|
||||
- `tests/test_api.py` - API endpoint tests
|
||||
- `tests/test_api.py` - API endpoint tests, read state tracking
|
||||
- `tests/test_migrations.py` - Migration system, schema versioning
|
||||
|
||||
## Common Tasks
|
||||
|
||||
|
||||
@@ -77,6 +77,10 @@ class Database:
|
||||
await self._connection.commit()
|
||||
logger.debug("Database schema initialized")
|
||||
|
||||
# Run any pending migrations
|
||||
from app.migrations import run_migrations
|
||||
await run_migrations(self._connection)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._connection:
|
||||
await self._connection.close()
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.radio_sync import (
|
||||
stop_periodic_sync,
|
||||
sync_and_offload_all,
|
||||
)
|
||||
from app.routers import channels, contacts, health, messages, packets, radio, settings, ws
|
||||
from app.routers import channels, contacts, health, messages, packets, radio, read_state, settings, ws
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -100,6 +100,7 @@ app.include_router(contacts.router, prefix="/api")
|
||||
app.include_router(channels.router, prefix="/api")
|
||||
app.include_router(messages.router, prefix="/api")
|
||||
app.include_router(packets.router, prefix="/api")
|
||||
app.include_router(read_state.router, prefix="/api")
|
||||
app.include_router(settings.router, prefix="/api")
|
||||
app.include_router(ws.router, prefix="/api")
|
||||
|
||||
|
||||
90
app/migrations.py
Normal file
90
app/migrations.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Database migrations using SQLite's user_version pragma.
|
||||
|
||||
Migrations run automatically on startup. The user_version pragma tracks
|
||||
which migrations have been applied (defaults to 0 for existing databases).
|
||||
|
||||
This approach is safe for existing users - their databases have user_version=0,
|
||||
so all migrations run in order on first startup after upgrade.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_version(conn: aiosqlite.Connection) -> int:
|
||||
"""Get current schema version from SQLite user_version pragma."""
|
||||
cursor = await conn.execute("PRAGMA user_version")
|
||||
row = await cursor.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
async def set_version(conn: aiosqlite.Connection, version: int) -> None:
|
||||
"""Set schema version using SQLite user_version pragma."""
|
||||
await conn.execute(f"PRAGMA user_version = {version}")
|
||||
|
||||
|
||||
async def run_migrations(conn: aiosqlite.Connection) -> int:
|
||||
"""
|
||||
Run all pending migrations.
|
||||
|
||||
Returns the number of migrations applied.
|
||||
"""
|
||||
version = await get_version(conn)
|
||||
applied = 0
|
||||
|
||||
# Migration 1: Add last_read_at columns for server-side read tracking
|
||||
if version < 1:
|
||||
logger.info("Applying migration 1: add last_read_at columns")
|
||||
await _migrate_001_add_last_read_at(conn)
|
||||
await set_version(conn, 1)
|
||||
applied += 1
|
||||
|
||||
# Future migrations go here:
|
||||
# if version < 2:
|
||||
# await _migrate_002_something(conn)
|
||||
# await set_version(conn, 2)
|
||||
# applied += 1
|
||||
|
||||
if applied > 0:
|
||||
logger.info("Applied %d migration(s), schema now at version %d", applied, await get_version(conn))
|
||||
else:
|
||||
logger.debug("Schema up to date at version %d", version)
|
||||
|
||||
return applied
|
||||
|
||||
|
||||
async def _migrate_001_add_last_read_at(conn: aiosqlite.Connection) -> None:
|
||||
"""
|
||||
Add last_read_at column to contacts and channels tables.
|
||||
|
||||
This enables server-side read state tracking, replacing the localStorage
|
||||
approach for consistent read state across devices.
|
||||
|
||||
ALTER TABLE ADD COLUMN is safe - it preserves existing data and handles
|
||||
the "column already exists" case gracefully.
|
||||
"""
|
||||
# Add to contacts table
|
||||
try:
|
||||
await conn.execute("ALTER TABLE contacts ADD COLUMN last_read_at INTEGER")
|
||||
logger.debug("Added last_read_at to contacts table")
|
||||
except aiosqlite.OperationalError as e:
|
||||
if "duplicate column name" in str(e).lower():
|
||||
logger.debug("contacts.last_read_at already exists, skipping")
|
||||
else:
|
||||
raise
|
||||
|
||||
# Add to channels table
|
||||
try:
|
||||
await conn.execute("ALTER TABLE channels ADD COLUMN last_read_at INTEGER")
|
||||
logger.debug("Added last_read_at to channels table")
|
||||
except aiosqlite.OperationalError as e:
|
||||
if "duplicate column name" in str(e).lower():
|
||||
logger.debug("channels.last_read_at already exists, skipping")
|
||||
else:
|
||||
raise
|
||||
|
||||
await conn.commit()
|
||||
@@ -14,6 +14,7 @@ class Contact(BaseModel):
|
||||
last_seen: int | None = None
|
||||
on_radio: bool = False
|
||||
last_contacted: int | None = None # Last time we sent/received a message
|
||||
last_read_at: int | None = None # Server-side read state tracking
|
||||
|
||||
def to_radio_dict(self) -> dict:
|
||||
"""Convert to the dict format expected by meshcore radio commands.
|
||||
@@ -63,6 +64,7 @@ class Channel(BaseModel):
|
||||
name: str
|
||||
is_hashtag: bool = False
|
||||
on_radio: bool = False
|
||||
last_read_at: int | None = None # Server-side read state tracking
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
|
||||
@@ -59,6 +59,7 @@ class ContactRepository:
|
||||
last_seen=row["last_seen"],
|
||||
on_radio=bool(row["on_radio"]),
|
||||
last_contacted=row["last_contacted"],
|
||||
last_read_at=row["last_read_at"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -169,6 +170,20 @@ class ContactRepository:
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
@staticmethod
|
||||
async def update_last_read_at(public_key: str, timestamp: int | None = None) -> bool:
|
||||
"""Update the last_read_at timestamp for a contact.
|
||||
|
||||
Returns True if a row was updated, False if contact not found.
|
||||
"""
|
||||
ts = timestamp or int(time.time())
|
||||
cursor = await db.conn.execute(
|
||||
"UPDATE contacts SET last_read_at = ? WHERE public_key = ?",
|
||||
(ts, public_key),
|
||||
)
|
||||
await db.conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
|
||||
class ChannelRepository:
|
||||
@staticmethod
|
||||
@@ -191,7 +206,7 @@ class ChannelRepository:
|
||||
async def get_by_key(key: str) -> Channel | None:
|
||||
"""Get a channel by its key (32-char hex string)."""
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT key, name, is_hashtag, on_radio FROM channels WHERE key = ?",
|
||||
"SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels WHERE key = ?",
|
||||
(key.upper(),)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
@@ -201,13 +216,14 @@ class ChannelRepository:
|
||||
name=row["name"],
|
||||
is_hashtag=bool(row["is_hashtag"]),
|
||||
on_radio=bool(row["on_radio"]),
|
||||
last_read_at=row["last_read_at"],
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_all() -> list[Channel]:
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT key, name, is_hashtag, on_radio FROM channels ORDER BY name"
|
||||
"SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels ORDER BY name"
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
@@ -216,6 +232,7 @@ class ChannelRepository:
|
||||
name=row["name"],
|
||||
is_hashtag=bool(row["is_hashtag"]),
|
||||
on_radio=bool(row["on_radio"]),
|
||||
last_read_at=row["last_read_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
@@ -224,7 +241,7 @@ class ChannelRepository:
|
||||
async def get_by_name(name: str) -> Channel | None:
|
||||
"""Get a channel by name."""
|
||||
cursor = await db.conn.execute(
|
||||
"SELECT key, name, is_hashtag, on_radio FROM channels WHERE name = ?", (name,)
|
||||
"SELECT key, name, is_hashtag, on_radio, last_read_at FROM channels WHERE name = ?", (name,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
@@ -233,6 +250,7 @@ class ChannelRepository:
|
||||
name=row["name"],
|
||||
is_hashtag=bool(row["is_hashtag"]),
|
||||
on_radio=bool(row["on_radio"]),
|
||||
last_read_at=row["last_read_at"],
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -245,6 +263,20 @@ class ChannelRepository:
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
@staticmethod
|
||||
async def update_last_read_at(key: str, timestamp: int | None = None) -> bool:
|
||||
"""Update the last_read_at timestamp for a channel.
|
||||
|
||||
Returns True if a row was updated, False if channel not found.
|
||||
"""
|
||||
ts = timestamp or int(time.time())
|
||||
cursor = await db.conn.execute(
|
||||
"UPDATE channels SET last_read_at = ? WHERE key = ?",
|
||||
(ts, key.upper()),
|
||||
)
|
||||
await db.conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
|
||||
class MessageRepository:
|
||||
@staticmethod
|
||||
|
||||
@@ -118,6 +118,20 @@ async def sync_channels_from_radio(
|
||||
return {"synced": count}
|
||||
|
||||
|
||||
@router.post("/{key}/mark-read")
|
||||
async def mark_channel_read(key: str) -> dict:
|
||||
"""Mark a channel as read (update last_read_at timestamp)."""
|
||||
channel = await ChannelRepository.get_by_key(key)
|
||||
if not channel:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
updated = await ChannelRepository.update_last_read_at(key)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=500, detail="Failed to update read state")
|
||||
|
||||
return {"status": "ok", "key": channel.key}
|
||||
|
||||
|
||||
@router.delete("/{key}")
|
||||
async def delete_channel(key: str) -> dict:
|
||||
"""Delete a channel from the database by key.
|
||||
|
||||
@@ -214,6 +214,20 @@ async def add_contact_to_radio(public_key: str) -> dict:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/{public_key}/mark-read")
|
||||
async def mark_contact_read(public_key: str) -> dict:
|
||||
"""Mark a contact conversation as read (update last_read_at timestamp)."""
|
||||
contact = await ContactRepository.get_by_key_or_prefix(public_key)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
|
||||
updated = await ContactRepository.update_last_read_at(contact.public_key)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=500, detail="Failed to update read state")
|
||||
|
||||
return {"status": "ok", "public_key": contact.public_key}
|
||||
|
||||
|
||||
@router.delete("/{public_key}")
|
||||
async def delete_contact(public_key: str) -> dict:
|
||||
"""Delete a contact from the database (and radio if present)."""
|
||||
|
||||
35
app/routers/read_state.py
Normal file
35
app/routers/read_state.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Read state management endpoints."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/read-state", tags=["read-state"])
|
||||
|
||||
|
||||
@router.post("/mark-all-read")
|
||||
async def mark_all_read() -> dict:
|
||||
"""Mark all contacts and channels as read.
|
||||
|
||||
Updates last_read_at to current timestamp for all contacts and channels
|
||||
in a single database transaction.
|
||||
"""
|
||||
now = int(time.time())
|
||||
|
||||
# Update all contacts and channels in one transaction
|
||||
await db.conn.execute(
|
||||
"UPDATE contacts SET last_read_at = ?",
|
||||
(now,)
|
||||
)
|
||||
await db.conn.execute(
|
||||
"UPDATE channels SET last_read_at = ?",
|
||||
(now,)
|
||||
)
|
||||
await db.conn.commit()
|
||||
|
||||
logger.info("Marked all contacts and channels as read at %d", now)
|
||||
return {"status": "ok", "timestamp": now}
|
||||
@@ -26,7 +26,7 @@ frontend/
|
||||
│ ├── styles.css # Dark theme CSS
|
||||
│ ├── utils/
|
||||
│ │ ├── messageParser.ts # Text parsing utilities
|
||||
│ │ ├── conversationState.ts # localStorage for unread tracking
|
||||
│ │ ├── conversationState.ts # localStorage for message times (sidebar sorting)
|
||||
│ │ ├── pubkey.ts # Public key utilities (prefix matching, display names)
|
||||
│ │ └── contactAvatar.ts # Avatar generation (colors, initials/emoji)
|
||||
│ ├── components/
|
||||
@@ -342,7 +342,7 @@ const activeConv = activeConversationRef.current;
|
||||
|
||||
### State Tracking Keys
|
||||
|
||||
State tracking keys (for unread counts and message times) are generated by `getStateKey()`:
|
||||
State tracking keys (for message times used in sidebar sorting) are generated by `getStateKey()`:
|
||||
|
||||
```typescript
|
||||
import { getStateKey } from './utils/conversationState';
|
||||
@@ -355,7 +355,26 @@ getStateKey('contact', publicKey) // e.g., "contact-abc123def456"
|
||||
```
|
||||
|
||||
**Note:** `getStateKey()` is NOT the same as `Message.conversation_key`. The state key is prefixed
|
||||
for localStorage tracking, while `conversation_key` is the raw database field.
|
||||
for local state tracking, while `conversation_key` is the raw database field.
|
||||
|
||||
### Read State (Server-Side)
|
||||
|
||||
Unread tracking uses server-side `last_read_at` timestamps for cross-device consistency:
|
||||
|
||||
```typescript
|
||||
// Contacts and channels include last_read_at from server
|
||||
interface Contact {
|
||||
// ...
|
||||
last_read_at: number | null; // Unix timestamp when conversation was last read
|
||||
}
|
||||
|
||||
// Mark as read via API (called automatically when viewing conversation)
|
||||
await api.markContactRead(publicKey);
|
||||
await api.markChannelRead(channelKey);
|
||||
await api.markAllRead(); // Bulk mark all as read
|
||||
```
|
||||
|
||||
Unread count = messages where `received_at > last_read_at`.
|
||||
|
||||
## Utility Functions
|
||||
|
||||
@@ -389,17 +408,19 @@ getContactDisplayName(name, publicKey) // name or first 12 chars of key
|
||||
### Conversation State (`utils/conversationState.ts`)
|
||||
|
||||
```typescript
|
||||
import { getStateKey, setLastMessageTime, setLastReadTime } from './utils/conversationState';
|
||||
import { getStateKey, setLastMessageTime, getLastMessageTimes } from './utils/conversationState';
|
||||
|
||||
// Generate state tracking key (NOT the same as Message.conversation_key)
|
||||
getStateKey('channel', channelKey)
|
||||
getStateKey('contact', publicKey)
|
||||
|
||||
// Track message times for unread detection
|
||||
// Track message times for sidebar sorting (stored in localStorage)
|
||||
setLastMessageTime(stateKey, timestamp)
|
||||
setLastReadTime(stateKey, timestamp)
|
||||
getLastMessageTimes() // Returns all tracked message times
|
||||
```
|
||||
|
||||
**Note:** Read state (`last_read_at`) is tracked server-side, not in localStorage.
|
||||
|
||||
### Contact Avatar (`utils/contactAvatar.ts`)
|
||||
|
||||
Generates consistent profile "images" for contacts using hash-based colors:
|
||||
|
||||
File diff suppressed because one or more lines are too long
2
frontend/dist/index.html
vendored
2
frontend/dist/index.html
vendored
@@ -13,7 +13,7 @@
|
||||
<link rel="shortcut icon" href="/favicon.ico" />
|
||||
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
|
||||
<link rel="manifest" href="/site.webmanifest" />
|
||||
<script type="module" crossorigin src="/assets/index-6T32T4ZI.js"></script>
|
||||
<script type="module" crossorigin src="/assets/index-Cp9RQ4Uj.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-DaLCXB8p.css">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
@@ -393,6 +393,7 @@ export function App() {
|
||||
lon: null,
|
||||
last_seen: null,
|
||||
on_radio: false,
|
||||
last_read_at: null,
|
||||
};
|
||||
setContacts((prev) => [...prev, newContact]);
|
||||
|
||||
|
||||
@@ -82,6 +82,10 @@ export const api = {
|
||||
fetchJson<{ status: string }>(`/contacts/${publicKey}`, {
|
||||
method: 'DELETE',
|
||||
}),
|
||||
markContactRead: (publicKey: string) =>
|
||||
fetchJson<{ status: string; public_key: string }>(`/contacts/${publicKey}/mark-read`, {
|
||||
method: 'POST',
|
||||
}),
|
||||
requestTelemetry: (publicKey: string, password: string) =>
|
||||
fetchJson<TelemetryResponse>(`/contacts/${publicKey}/telemetry`, {
|
||||
method: 'POST',
|
||||
@@ -105,6 +109,10 @@ export const api = {
|
||||
fetchJson<{ synced: number }>('/channels/sync', { method: 'POST' }),
|
||||
deleteChannel: (key: string) =>
|
||||
fetchJson<{ status: string }>(`/channels/${key}`, { method: 'DELETE' }),
|
||||
markChannelRead: (key: string) =>
|
||||
fetchJson<{ status: string; key: string }>(`/channels/${key}/mark-read`, {
|
||||
method: 'POST',
|
||||
}),
|
||||
|
||||
// Messages
|
||||
getMessages: (params?: {
|
||||
@@ -157,6 +165,12 @@ export const api = {
|
||||
body: JSON.stringify(params),
|
||||
}),
|
||||
|
||||
// Read State
|
||||
markAllRead: () =>
|
||||
fetchJson<{ status: string; timestamp: number }>('/read-state/mark-all-read', {
|
||||
method: 'POST',
|
||||
}),
|
||||
|
||||
// App Settings
|
||||
getSettings: () => fetchJson<AppSettings>('/settings'),
|
||||
updateSettings: (settings: AppSettingsUpdate) =>
|
||||
|
||||
@@ -2,9 +2,7 @@ import { useState, useCallback, useEffect, useRef } from 'react';
|
||||
import { api } from '../api';
|
||||
import {
|
||||
getLastMessageTimes,
|
||||
getLastReadTimes,
|
||||
setLastMessageTime,
|
||||
setLastReadTime,
|
||||
getStateKey,
|
||||
type ConversationTimes,
|
||||
} from '../utils/conversationState';
|
||||
@@ -32,6 +30,7 @@ export function useUnreadCounts(
|
||||
const fetchedContacts = useRef<Set<string>>(new Set());
|
||||
|
||||
// Fetch messages and count unreads for new channels/contacts
|
||||
// Uses server-side last_read_at for consistent read state across devices
|
||||
useEffect(() => {
|
||||
const newChannels = channels.filter(c => !fetchedChannels.current.has(c.key));
|
||||
const newContacts = contacts.filter(c => c.public_key && !fetchedContacts.current.has(c.public_key));
|
||||
@@ -52,16 +51,16 @@ export function useUnreadCounts(
|
||||
|
||||
try {
|
||||
const bulkMessages = await api.getMessagesBulk(conversations, 100);
|
||||
const currentReadTimes = getLastReadTimes();
|
||||
const newUnreadCounts: Record<string, number> = {};
|
||||
const newLastMessageTimes: Record<string, number> = {};
|
||||
|
||||
// Process channel messages
|
||||
// Process channel messages - use server-side last_read_at
|
||||
for (const channel of newChannels) {
|
||||
const msgs = bulkMessages[`CHAN:${channel.key}`] || [];
|
||||
if (msgs.length > 0) {
|
||||
const key = getStateKey('channel', channel.key);
|
||||
const lastRead = currentReadTimes[key] || 0;
|
||||
// Use server-side last_read_at, fallback to 0 if never read
|
||||
const lastRead = channel.last_read_at || 0;
|
||||
|
||||
const unreadCount = msgs.filter(m => !m.outgoing && m.received_at > lastRead).length;
|
||||
if (unreadCount > 0) {
|
||||
@@ -74,12 +73,13 @@ export function useUnreadCounts(
|
||||
}
|
||||
}
|
||||
|
||||
// Process contact messages
|
||||
// Process contact messages - use server-side last_read_at
|
||||
for (const contact of newContacts) {
|
||||
const msgs = bulkMessages[`PRIV:${contact.public_key}`] || [];
|
||||
if (msgs.length > 0) {
|
||||
const key = getStateKey('contact', contact.public_key);
|
||||
const lastRead = currentReadTimes[key] || 0;
|
||||
// Use server-side last_read_at, fallback to 0 if never read
|
||||
const lastRead = contact.last_read_at || 0;
|
||||
|
||||
const unreadCount = msgs.filter(m => !m.outgoing && m.received_at > lastRead).length;
|
||||
if (unreadCount > 0) {
|
||||
@@ -105,15 +105,15 @@ export function useUnreadCounts(
|
||||
}, [channels, contacts]);
|
||||
|
||||
// Mark conversation as read when user views it
|
||||
// Calls server API to persist read state across devices
|
||||
useEffect(() => {
|
||||
if (activeConversation && activeConversation.type !== 'raw') {
|
||||
const key = getStateKey(
|
||||
activeConversation.type as 'channel' | 'contact',
|
||||
activeConversation.id
|
||||
);
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
setLastReadTime(key, now);
|
||||
|
||||
// Update local state immediately for responsive UI
|
||||
setUnreadCounts((prev) => {
|
||||
if (prev[key]) {
|
||||
const next = { ...prev };
|
||||
@@ -122,6 +122,17 @@ export function useUnreadCounts(
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
|
||||
// Persist to server (fire-and-forget, errors logged but not blocking)
|
||||
if (activeConversation.type === 'channel') {
|
||||
api.markChannelRead(activeConversation.id).catch((err) => {
|
||||
console.error('Failed to mark channel as read on server:', err);
|
||||
});
|
||||
} else if (activeConversation.type === 'contact') {
|
||||
api.markContactRead(activeConversation.id).catch((err) => {
|
||||
console.error('Failed to mark contact as read on server:', err);
|
||||
});
|
||||
}
|
||||
}
|
||||
}, [activeConversation]);
|
||||
|
||||
@@ -134,32 +145,25 @@ export function useUnreadCounts(
|
||||
}, []);
|
||||
|
||||
// Mark all conversations as read
|
||||
// Calls single bulk API endpoint to persist read state
|
||||
const markAllRead = useCallback(() => {
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
|
||||
for (const channel of channels) {
|
||||
const key = getStateKey('channel', channel.key);
|
||||
setLastReadTime(key, now);
|
||||
}
|
||||
|
||||
for (const contact of contacts) {
|
||||
if (contact.public_key) {
|
||||
const key = getStateKey('contact', contact.public_key);
|
||||
setLastReadTime(key, now);
|
||||
}
|
||||
}
|
||||
|
||||
// Update local state immediately
|
||||
setUnreadCounts({});
|
||||
}, [channels, contacts]);
|
||||
|
||||
// Persist to server with single bulk request
|
||||
api.markAllRead().catch((err) => {
|
||||
console.error('Failed to mark all as read on server:', err);
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Mark a specific conversation as read
|
||||
// Calls server API to persist read state across devices
|
||||
const markConversationRead = useCallback((conv: Conversation) => {
|
||||
if (conv.type === 'raw') return;
|
||||
|
||||
const key = getStateKey(conv.type as 'channel' | 'contact', conv.id);
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
setLastReadTime(key, now);
|
||||
|
||||
// Update local state immediately
|
||||
setUnreadCounts((prev) => {
|
||||
if (prev[key]) {
|
||||
const next = { ...prev };
|
||||
@@ -168,6 +172,17 @@ export function useUnreadCounts(
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
|
||||
// Persist to server (fire-and-forget)
|
||||
if (conv.type === 'channel') {
|
||||
api.markChannelRead(conv.id).catch((err) => {
|
||||
console.error('Failed to mark channel as read on server:', err);
|
||||
});
|
||||
} else if (conv.type === 'contact') {
|
||||
api.markContactRead(conv.id).catch((err) => {
|
||||
console.error('Failed to mark contact as read on server:', err);
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Track a new incoming message for unread counts
|
||||
|
||||
@@ -55,6 +55,7 @@ export interface Contact {
|
||||
lon: number | null;
|
||||
last_seen: number | null;
|
||||
on_radio: boolean;
|
||||
last_read_at: number | null;
|
||||
}
|
||||
|
||||
export interface Channel {
|
||||
@@ -62,6 +63,7 @@ export interface Channel {
|
||||
name: string;
|
||||
is_hashtag: boolean;
|
||||
on_radio: boolean;
|
||||
last_read_at: number | null;
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
/**
|
||||
* localStorage utilities for tracking conversation read/message state.
|
||||
* localStorage utilities for tracking conversation message times.
|
||||
*
|
||||
* Stores two maps:
|
||||
* - lastMessageTime: when each conversation last received a message
|
||||
* - lastReadTime: when the user last viewed each conversation
|
||||
* Stores when each conversation last received a message, used for
|
||||
* sorting conversations by recency in the sidebar.
|
||||
*
|
||||
* A conversation has unread messages if lastMessageTime > lastReadTime.
|
||||
* Read state (last_read_at) is tracked server-side for consistency
|
||||
* across devices - see useUnreadCounts hook.
|
||||
*/
|
||||
|
||||
import { getPubkeyPrefix } from './pubkey';
|
||||
|
||||
const LAST_MESSAGE_KEY = 'remoteterm-lastMessageTime';
|
||||
const LAST_READ_KEY = 'remoteterm-lastReadTime';
|
||||
|
||||
export type ConversationTimes = Record<string, number>;
|
||||
|
||||
@@ -36,10 +35,6 @@ export function getLastMessageTimes(): ConversationTimes {
|
||||
return loadTimes(LAST_MESSAGE_KEY);
|
||||
}
|
||||
|
||||
export function getLastReadTimes(): ConversationTimes {
|
||||
return loadTimes(LAST_READ_KEY);
|
||||
}
|
||||
|
||||
export function setLastMessageTime(stateKey: string, timestamp: number): ConversationTimes {
|
||||
const times = loadTimes(LAST_MESSAGE_KEY);
|
||||
// Only update if this is a newer message
|
||||
@@ -50,15 +45,8 @@ export function setLastMessageTime(stateKey: string, timestamp: number): Convers
|
||||
return times;
|
||||
}
|
||||
|
||||
export function setLastReadTime(stateKey: string, timestamp: number): ConversationTimes {
|
||||
const times = loadTimes(LAST_READ_KEY);
|
||||
times[stateKey] = timestamp;
|
||||
saveTimes(LAST_READ_KEY, times);
|
||||
return times;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a state tracking key for unread counts and message times.
|
||||
* Generate a state tracking key for message times.
|
||||
*
|
||||
* This is NOT the same as Message.conversation_key (the database field).
|
||||
* This creates prefixed keys for localStorage/state tracking:
|
||||
|
||||
@@ -185,6 +185,257 @@ class TestPacketsEndpoint:
|
||||
assert response.json()["count"] == 42
|
||||
|
||||
|
||||
class TestReadStateEndpoints:
|
||||
"""Test read state tracking endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_contact_read_updates_timestamp(self):
|
||||
"""Marking contact as read updates last_read_at in database."""
|
||||
import aiosqlite
|
||||
import time
|
||||
from app.repository import ContactRepository
|
||||
from app.database import db
|
||||
|
||||
# Use in-memory database for testing
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
|
||||
# Create contacts table with last_read_at column
|
||||
await conn.execute("""
|
||||
CREATE TABLE contacts (
|
||||
public_key TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
type INTEGER DEFAULT 0,
|
||||
flags INTEGER DEFAULT 0,
|
||||
last_path TEXT,
|
||||
last_path_len INTEGER DEFAULT -1,
|
||||
last_advert INTEGER,
|
||||
lat REAL,
|
||||
lon REAL,
|
||||
last_seen INTEGER,
|
||||
on_radio INTEGER DEFAULT 0,
|
||||
last_contacted INTEGER,
|
||||
last_read_at INTEGER
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert a test contact
|
||||
await conn.execute(
|
||||
"INSERT INTO contacts (public_key, name) VALUES (?, ?)",
|
||||
("abc123def456789012345678901234567890123456789012345678901234", "TestContact")
|
||||
)
|
||||
await conn.commit()
|
||||
|
||||
original_conn = db._connection
|
||||
db._connection = conn
|
||||
|
||||
try:
|
||||
before_time = int(time.time())
|
||||
|
||||
# Update last_read_at
|
||||
updated = await ContactRepository.update_last_read_at(
|
||||
"abc123def456789012345678901234567890123456789012345678901234"
|
||||
)
|
||||
|
||||
assert updated is True
|
||||
|
||||
# Verify the timestamp was set
|
||||
contact = await ContactRepository.get_by_key(
|
||||
"abc123def456789012345678901234567890123456789012345678901234"
|
||||
)
|
||||
assert contact is not None
|
||||
assert contact.last_read_at is not None
|
||||
assert contact.last_read_at >= before_time
|
||||
finally:
|
||||
db._connection = original_conn
|
||||
await conn.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_channel_read_updates_timestamp(self):
|
||||
"""Marking channel as read updates last_read_at in database."""
|
||||
import aiosqlite
|
||||
import time
|
||||
from app.repository import ChannelRepository
|
||||
from app.database import db
|
||||
|
||||
# Use in-memory database for testing
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
|
||||
# Create channels table with last_read_at column
|
||||
await conn.execute("""
|
||||
CREATE TABLE channels (
|
||||
key TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
is_hashtag INTEGER DEFAULT 0,
|
||||
on_radio INTEGER DEFAULT 0,
|
||||
last_read_at INTEGER
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert a test channel
|
||||
await conn.execute(
|
||||
"INSERT INTO channels (key, name) VALUES (?, ?)",
|
||||
("0123456789ABCDEF0123456789ABCDEF", "#testchannel")
|
||||
)
|
||||
await conn.commit()
|
||||
|
||||
original_conn = db._connection
|
||||
db._connection = conn
|
||||
|
||||
try:
|
||||
before_time = int(time.time())
|
||||
|
||||
# Update last_read_at
|
||||
updated = await ChannelRepository.update_last_read_at(
|
||||
"0123456789ABCDEF0123456789ABCDEF"
|
||||
)
|
||||
|
||||
assert updated is True
|
||||
|
||||
# Verify the timestamp was set
|
||||
channel = await ChannelRepository.get_by_key(
|
||||
"0123456789ABCDEF0123456789ABCDEF"
|
||||
)
|
||||
assert channel is not None
|
||||
assert channel.last_read_at is not None
|
||||
assert channel.last_read_at >= before_time
|
||||
finally:
|
||||
db._connection = original_conn
|
||||
await conn.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_nonexistent_contact_returns_false(self):
|
||||
"""Marking nonexistent contact returns False."""
|
||||
import aiosqlite
|
||||
from app.repository import ContactRepository
|
||||
from app.database import db
|
||||
|
||||
# Use in-memory database for testing
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
|
||||
await conn.execute("""
|
||||
CREATE TABLE contacts (
|
||||
public_key TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
type INTEGER DEFAULT 0,
|
||||
flags INTEGER DEFAULT 0,
|
||||
last_path TEXT,
|
||||
last_path_len INTEGER DEFAULT -1,
|
||||
last_advert INTEGER,
|
||||
lat REAL,
|
||||
lon REAL,
|
||||
last_seen INTEGER,
|
||||
on_radio INTEGER DEFAULT 0,
|
||||
last_contacted INTEGER,
|
||||
last_read_at INTEGER
|
||||
)
|
||||
""")
|
||||
await conn.commit()
|
||||
|
||||
original_conn = db._connection
|
||||
db._connection = conn
|
||||
|
||||
try:
|
||||
updated = await ContactRepository.update_last_read_at("nonexistent")
|
||||
assert updated is False
|
||||
finally:
|
||||
db._connection = original_conn
|
||||
await conn.close()
|
||||
|
||||
def test_mark_contact_read_endpoint_returns_404_for_missing(self):
|
||||
"""Mark-read endpoint returns 404 for nonexistent contact."""
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
with patch("app.repository.ContactRepository.get_by_key_or_prefix", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = None
|
||||
|
||||
from app.main import app
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/api/contacts/nonexistent/mark-read")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
def test_mark_channel_read_endpoint_returns_404_for_missing(self):
|
||||
"""Mark-read endpoint returns 404 for nonexistent channel."""
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
with patch("app.repository.ChannelRepository.get_by_key", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = None
|
||||
|
||||
from app.main import app
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/api/channels/NONEXISTENT/mark-read")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_all_read_updates_all_conversations(self):
|
||||
"""Bulk mark-all-read updates all contacts and channels."""
|
||||
import aiosqlite
|
||||
import time
|
||||
from app.database import db
|
||||
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
|
||||
# Create tables
|
||||
await conn.execute("""
|
||||
CREATE TABLE contacts (
|
||||
public_key TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
last_read_at INTEGER
|
||||
)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE TABLE channels (
|
||||
key TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
last_read_at INTEGER
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert test data with NULL last_read_at
|
||||
await conn.execute("INSERT INTO contacts (public_key, name) VALUES (?, ?)", ("contact1", "Alice"))
|
||||
await conn.execute("INSERT INTO contacts (public_key, name) VALUES (?, ?)", ("contact2", "Bob"))
|
||||
await conn.execute("INSERT INTO channels (key, name) VALUES (?, ?)", ("CHAN1", "#test1"))
|
||||
await conn.execute("INSERT INTO channels (key, name) VALUES (?, ?)", ("CHAN2", "#test2"))
|
||||
await conn.commit()
|
||||
|
||||
original_conn = db._connection
|
||||
db._connection = conn
|
||||
|
||||
try:
|
||||
before_time = int(time.time())
|
||||
|
||||
# Call the endpoint
|
||||
from app.routers.read_state import mark_all_read
|
||||
result = await mark_all_read()
|
||||
|
||||
assert result["status"] == "ok"
|
||||
assert result["timestamp"] >= before_time
|
||||
|
||||
# Verify all contacts updated
|
||||
cursor = await conn.execute("SELECT last_read_at FROM contacts")
|
||||
rows = await cursor.fetchall()
|
||||
for row in rows:
|
||||
assert row["last_read_at"] >= before_time
|
||||
|
||||
# Verify all channels updated
|
||||
cursor = await conn.execute("SELECT last_read_at FROM channels")
|
||||
rows = await cursor.fetchall()
|
||||
for row in rows:
|
||||
assert row["last_read_at"] >= before_time
|
||||
finally:
|
||||
db._connection = original_conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
class TestRawPacketRepository:
|
||||
"""Test raw packet storage with deduplication."""
|
||||
|
||||
|
||||
221
tests/test_migrations.py
Normal file
221
tests/test_migrations.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Tests for database migrations."""
|
||||
|
||||
import pytest
|
||||
import aiosqlite
|
||||
|
||||
from app.migrations import get_version, set_version, run_migrations
|
||||
|
||||
|
||||
class TestMigrationSystem:
|
||||
"""Test the migration version tracking system."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_version_returns_zero_for_new_db(self):
|
||||
"""New database has user_version=0."""
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
try:
|
||||
version = await get_version(conn)
|
||||
assert version == 0
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_version_updates_pragma(self):
|
||||
"""Setting version updates the user_version pragma."""
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
try:
|
||||
await set_version(conn, 5)
|
||||
version = await get_version(conn)
|
||||
assert version == 5
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
class TestMigration001:
|
||||
"""Test migration 001: add last_read_at columns."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_adds_last_read_at_to_contacts(self):
|
||||
"""Migration adds last_read_at column to contacts table."""
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
try:
|
||||
# Create schema without last_read_at (simulating pre-migration state)
|
||||
await conn.execute("""
|
||||
CREATE TABLE contacts (
|
||||
public_key TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
type INTEGER DEFAULT 0,
|
||||
flags INTEGER DEFAULT 0,
|
||||
last_path TEXT,
|
||||
last_path_len INTEGER DEFAULT -1,
|
||||
last_advert INTEGER,
|
||||
lat REAL,
|
||||
lon REAL,
|
||||
last_seen INTEGER,
|
||||
on_radio INTEGER DEFAULT 0,
|
||||
last_contacted INTEGER
|
||||
)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE TABLE channels (
|
||||
key TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
is_hashtag INTEGER DEFAULT 0,
|
||||
on_radio INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
await conn.commit()
|
||||
|
||||
# Run migrations
|
||||
applied = await run_migrations(conn)
|
||||
|
||||
assert applied == 1
|
||||
assert await get_version(conn) == 1
|
||||
|
||||
# Verify columns exist by inserting and selecting
|
||||
await conn.execute(
|
||||
"INSERT INTO contacts (public_key, name, last_read_at) VALUES (?, ?, ?)",
|
||||
("abc123", "Test", 12345)
|
||||
)
|
||||
await conn.execute(
|
||||
"INSERT INTO channels (key, name, last_read_at) VALUES (?, ?, ?)",
|
||||
("KEY123", "#test", 67890)
|
||||
)
|
||||
await conn.commit()
|
||||
|
||||
cursor = await conn.execute(
|
||||
"SELECT last_read_at FROM contacts WHERE public_key = ?",
|
||||
("abc123",)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row["last_read_at"] == 12345
|
||||
|
||||
cursor = await conn.execute(
|
||||
"SELECT last_read_at FROM channels WHERE key = ?",
|
||||
("KEY123",)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row["last_read_at"] == 67890
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_is_idempotent(self):
|
||||
"""Running migration multiple times is safe."""
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
try:
|
||||
# Create schema without last_read_at
|
||||
await conn.execute("""
|
||||
CREATE TABLE contacts (
|
||||
public_key TEXT PRIMARY KEY,
|
||||
name TEXT
|
||||
)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE TABLE channels (
|
||||
key TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
await conn.commit()
|
||||
|
||||
# Run migrations twice
|
||||
applied1 = await run_migrations(conn)
|
||||
applied2 = await run_migrations(conn)
|
||||
|
||||
assert applied1 == 1
|
||||
assert applied2 == 0 # No migrations on second run
|
||||
assert await get_version(conn) == 1
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_handles_column_already_exists(self):
|
||||
"""Migration handles case where column already exists."""
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
try:
|
||||
# Create schema with last_read_at already present
|
||||
await conn.execute("""
|
||||
CREATE TABLE contacts (
|
||||
public_key TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
last_read_at INTEGER
|
||||
)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE TABLE channels (
|
||||
key TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
last_read_at INTEGER
|
||||
)
|
||||
""")
|
||||
await conn.commit()
|
||||
|
||||
# Run migrations - should not fail
|
||||
applied = await run_migrations(conn)
|
||||
|
||||
# Still counts as applied (version incremented) but no error
|
||||
assert applied == 1
|
||||
assert await get_version(conn) == 1
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_data_preserved_after_migration(self):
|
||||
"""Migration preserves existing contact and channel data."""
|
||||
conn = await aiosqlite.connect(":memory:")
|
||||
conn.row_factory = aiosqlite.Row
|
||||
try:
|
||||
# Create schema and insert data before migration
|
||||
await conn.execute("""
|
||||
CREATE TABLE contacts (
|
||||
public_key TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
type INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
await conn.execute("""
|
||||
CREATE TABLE channels (
|
||||
key TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
is_hashtag INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
await conn.execute(
|
||||
"INSERT INTO contacts (public_key, name, type) VALUES (?, ?, ?)",
|
||||
("existingkey", "ExistingContact", 1)
|
||||
)
|
||||
await conn.execute(
|
||||
"INSERT INTO channels (key, name, is_hashtag) VALUES (?, ?, ?)",
|
||||
("EXISTINGCHAN", "#existing", 1)
|
||||
)
|
||||
await conn.commit()
|
||||
|
||||
# Run migrations
|
||||
await run_migrations(conn)
|
||||
|
||||
# Verify data is preserved
|
||||
cursor = await conn.execute(
|
||||
"SELECT public_key, name, type, last_read_at FROM contacts WHERE public_key = ?",
|
||||
("existingkey",)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row["public_key"] == "existingkey"
|
||||
assert row["name"] == "ExistingContact"
|
||||
assert row["type"] == 1
|
||||
assert row["last_read_at"] is None # New column defaults to NULL
|
||||
|
||||
cursor = await conn.execute(
|
||||
"SELECT key, name, is_hashtag, last_read_at FROM channels WHERE key = ?",
|
||||
("EXISTINGCHAN",)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row["key"] == "EXISTINGCHAN"
|
||||
assert row["name"] == "#existing"
|
||||
assert row["is_hashtag"] == 1
|
||||
assert row["last_read_at"] is None
|
||||
finally:
|
||||
await conn.close()
|
||||
Reference in New Issue
Block a user