diff --git a/database.py b/database.py index 25f69fb..4a9e053 100644 --- a/database.py +++ b/database.py @@ -1,38 +1,150 @@ import sqlite3 import globals +from utilities.utils import get_nodeNum, get_node_list, get_name_from_number + + +def init_nodedb(): + """Initialize the node database and update it with nodes from the interface.""" + try: + with sqlite3.connect(globals.db_file_path) as db_connection: + db_cursor = db_connection.cursor() + + # Table name construction + table_name = f"{str(get_nodeNum())}_nodedb" + nodeinfo_table = f'"{table_name}"' # Quote the table name because it might begin with numerics + + # Step 1: Create the table if it doesn't exist + create_table_query = f''' + CREATE TABLE IF NOT EXISTS {nodeinfo_table} ( + user_id TEXT PRIMARY KEY, + long_name TEXT, + short_name TEXT + ) + ''' + db_cursor.execute(create_table_query) + + # Step 2: Get the list of nodes from the interface + node_list = get_node_list() + + # Step 3: Insert nodes into the database if they don't already exist + for node in node_list: + insert_query = f''' + INSERT OR IGNORE INTO {nodeinfo_table} (user_id, long_name, short_name) + VALUES (?, ?, ?) + ''' + # Replace placeholders with actual data for the node + db_cursor.execute(insert_query, (node, get_name_from_number(node, "long"), get_name_from_number(node, "short"))) + + db_connection.commit() + + except sqlite3.Error as e: + print(f"SQLite error in init_and_update_nodedb: {e}") + except Exception as e: + print(f"Unexpected error in init_and_update_nodedb: {e}") -nodeinfo_table = "nodeinfo" -def initialize_database(): - conn = sqlite3.connect(globals.db_file_path) - db_cursor = conn.cursor() - # Create the nodeinfo table for storing nodeinfos - query = f'CREATE TABLE IF NOT EXISTS {nodeinfo_table} (user_id TEXT, long_name TEXT, short_name TEXT)' - db_cursor.execute(query) - # Create the messages table for storing messages - table_name = globals.channel_list[globals.selected_channel] + "_messages" - query = f'CREATE TABLE IF NOT EXISTS {table_name} (user_id TEXT,message_text TEXT)' - db_cursor.execute(query) def save_message_to_db(channel, user_id, message_text): + """Save messages to the database, ensuring the table exists.""" + try: + with sqlite3.connect(globals.db_file_path) as db_connection: + db_cursor = db_connection.cursor() + + # Construct the table name + table_name = f"{str(get_nodeNum())}_{channel}_messages" + quoted_table_name = f'"{table_name}"' # Quote the table name becuase we might begin with numerics + + # Ensure the table exists + create_table_query = f''' + CREATE TABLE IF NOT EXISTS {quoted_table_name} ( + user_id TEXT, + message_text TEXT + ) + ''' + db_cursor.execute(create_table_query) + + # Insert the message + insert_query = f''' + INSERT INTO {quoted_table_name} (user_id, message_text) + VALUES (?, ?) + ''' + db_cursor.execute(insert_query, (user_id, message_text)) + + db_connection.commit() + + except sqlite3.Error as e: + print(f"SQLite error in save_message_to_db: {e}") + + except Exception as e: + print(f"Unexpected error in save_message_to_db: {e}") + + +def load_messages_from_db(): + """Load messages from the database for all channels and update globals.all_messages and globals.channel_list.""" + try: + with sqlite3.connect(globals.db_file_path) as db_connection: + db_cursor = db_connection.cursor() + + # Retrieve all table names that match the pattern + query = "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE ?" + db_cursor.execute(query, (f"{str(get_nodeNum())}_%_messages",)) + tables = [row[0] for row in db_cursor.fetchall()] + + # Reset the channel list before populating + globals.channel_list = [] + + # Iterate through each table and fetch its messages + for table_name in tables: + query = f'SELECT user_id, message_text FROM "{table_name}"' + + try: + # Fetch all messages from the table + db_cursor.execute(query) + db_messages = [(row[0], row[1]) for row in db_cursor.fetchall()] # Save as tuples + + # Extract the channel name from the table name + channel = table_name.split("_")[1] + + # Determine the correct channel name + if channel.isdigit(): + friendly_channel_name = get_name_from_number(int(channel)) + else: + friendly_channel_name = channel + + # Add the channel to globals.channel_list if not already present + if friendly_channel_name not in globals.channel_list: + globals.channel_list.append(friendly_channel_name) + + # Ensure the channel exists in globals.all_messages + if friendly_channel_name not in globals.all_messages: + globals.all_messages[friendly_channel_name] = [] + + # Add messages to globals.all_messages in tuple format + for user_id, message in db_messages: + + formatted_message = (f"{globals.message_prefix} {get_name_from_number(int(user_id), 'short')}: ", message) + if formatted_message not in globals.all_messages[friendly_channel_name]: + globals.all_messages[friendly_channel_name].append(formatted_message) + + + except sqlite3.Error as e: + print(f"SQLite error while loading messages from table '{table_name}': {e}") + + + except sqlite3.Error as e: + print(f"SQLite error in load_messages_from_db: {e}") + + - conn = sqlite3.connect(globals.db_file_path) - table_name = f"{channel}_messages" - db_cursor = conn.cursor() - query = f''' - INSERT INTO {table_name} (user_id, message_text) - VALUES (?, ?) - ''' - db_cursor.execute(query, (user_id, message_text)) - conn.commit() - conn.close() def maybe_store_nodeinfo_in_db(packet): """Save nodeinfo unless that record is already there.""" - try: with sqlite3.connect(globals.db_file_path) as db_connection: + + table_name = f"{str(get_nodeNum())}_nodedb" + nodeinfo_table = f'"{table_name}"' # Quote the table name becuase we might begin with numerics db_cursor = db_connection.cursor() # Check if a record with the same user_id already exists @@ -59,7 +171,8 @@ def maybe_store_nodeinfo_in_db(packet): db_connection.commit() # Fetch the updated record - updated_record = db_cursor.execute(f'SELECT * FROM {nodeinfo_table} WHERE user_id=?', (packet['from'],)).fetchone() + # TODO display new node name in nodelist + # updated_record = db_cursor.execute(f'SELECT * FROM {nodeinfo_table} WHERE user_id=?', (packet['from'],)).fetchone() except sqlite3.Error as e: diff --git a/globals.py b/globals.py index 139b97e..6471a44 100644 --- a/globals.py +++ b/globals.py @@ -7,4 +7,5 @@ selected_node = 0 direct_message = False interface = None display_log = False -db_file_path = "client.db" \ No newline at end of file +db_file_path = "client.db" +message_prefix = ">> " \ No newline at end of file diff --git a/main.py b/main.py index 1a573e3..2728751 100644 --- a/main.py +++ b/main.py @@ -14,7 +14,7 @@ from utilities.interfaces import initialize_interface from message_handlers.rx_handler import on_receive from ui.curses_ui import main_ui, draw_splash from utilities.utils import get_channels -from database import initialize_database +from database import init_nodedb, load_messages_from_db import globals def main(stdscr): @@ -23,9 +23,11 @@ def main(stdscr): args = parser.parse_args() globals.interface = initialize_interface(args) globals.channel_list = get_channels() - initialize_database() pub.subscribe(on_receive, 'meshtastic.receive') + init_nodedb() + load_messages_from_db() main_ui(stdscr) + if __name__ == "__main__": curses.wrapper(main) \ No newline at end of file diff --git a/message_handlers/rx_handler.py b/message_handlers/rx_handler.py index a629d55..2ec02f4 100644 --- a/message_handlers/rx_handler.py +++ b/message_handlers/rx_handler.py @@ -2,7 +2,7 @@ from meshtastic import BROADCAST_NUM from utilities.utils import get_node_list, decimal_to_hex, get_nodeNum import globals from ui.curses_ui import update_packetlog_win, draw_node_list, update_messages_window, draw_channel_list, add_notification -from database import save_message_to_db, maybe_store_nodeinfo_in_db +from database import init_nodedb, save_message_to_db, maybe_store_nodeinfo_in_db @@ -54,9 +54,9 @@ def on_receive(packet): message_from_string = str(decimal_to_hex(message_from_id)) # If long name not found, use the ID as string if globals.channel_list[channel_number] in globals.all_messages: - globals.all_messages[globals.channel_list[channel_number]].append((f">> {message_from_string} ", message_string)) + globals.all_messages[globals.channel_list[channel_number]].append((f"{globals.message_prefix} {message_from_string} ", message_string)) else: - globals.all_messages[globals.channel_list[channel_number]] = [(f">> {message_from_string} ", message_string)] + globals.all_messages[globals.channel_list[channel_number]] = [(f"{globals.message_prefix} {message_from_string} ", message_string)] draw_channel_list() update_messages_window() diff --git a/message_handlers/tx_handler.py b/message_handlers/tx_handler.py index 7a45ec6..e94bc31 100644 --- a/message_handlers/tx_handler.py +++ b/message_handlers/tx_handler.py @@ -1,7 +1,12 @@ from meshtastic import BROADCAST_NUM +from database import save_message_to_db +from utilities.utils import get_nodeNum import globals + def send_message(message, destination=BROADCAST_NUM, channel=0): + + myid = get_nodeNum() send_on_channel = 0 if isinstance(globals.channel_list[channel], int): send_on_channel = 0 @@ -22,4 +27,6 @@ def send_message(message, destination=BROADCAST_NUM, channel=0): if globals.channel_list[channel] in globals.all_messages: globals.all_messages[globals.channel_list[channel]].append((">> Sent: ", message)) else: - globals.all_messages[globals.channel_list[channel]] = [(">> Sent: ", message)] \ No newline at end of file + globals.all_messages[globals.channel_list[channel]] = [(">> Sent: ", message)] + + save_message_to_db(globals.channel_list[channel], myid, message) \ No newline at end of file diff --git a/ui/curses_ui.py b/ui/curses_ui.py index 1047b8f..4c0be49 100644 --- a/ui/curses_ui.py +++ b/ui/curses_ui.py @@ -106,7 +106,6 @@ def update_packetlog_win(): packetlog_win.refresh() def draw_text_field(win, text): - # win.clear() win.border() win.addstr(1, 1, text) @@ -241,6 +240,7 @@ def main_ui(stdscr): channel_win.refresh() draw_channel_list() draw_node_list() + update_messages_window() # Draw boxes around windows channel_win.box() diff --git a/utilities/utils.py b/utilities/utils.py index daf7ad6..8feb73a 100644 --- a/utilities/utils.py +++ b/utilities/utils.py @@ -1,5 +1,6 @@ import globals from meshtastic.protobuf import config_pb2 +import re def get_channels(): node = globals.interface.getNode('^local') @@ -9,16 +10,21 @@ def get_channels(): for device_channel in device_channels: if device_channel.role: if device_channel.settings.name: - channel_output.append(device_channel.settings.name) - globals.all_messages[device_channel.settings.name] = [] - + # Use the channel name + channel_name = device_channel.settings.name else: # If channel name is blank, use the modem preset lora_config = node.localConfig.lora modem_preset_enum = lora_config.modem_preset modem_preset_string = config_pb2._CONFIG_LORACONFIG_MODEMPRESET.values_by_number[modem_preset_enum].name - channel_output.append(convert_to_camel_case(modem_preset_string)) - globals.all_messages[convert_to_camel_case(modem_preset_string)] = [] + channel_name = convert_to_camel_case(modem_preset_string) + + # Add channel to output + channel_output.append(channel_name) + + # Only initialize globals.all_messages[channel_name] if it doesn't already exist + if channel_name not in globals.all_messages: + globals.all_messages[channel_name] = [] return list(globals.all_messages.keys()) @@ -57,4 +63,14 @@ def get_name_from_number(number, type='long'): else: name = str(decimal_to_hex(number)) # If long name not found, use the ID as string return name - \ No newline at end of file + +def sanitize_string(input_str: str) -> str: + """Check if the string starts with a letter (a-z, A-Z) or an underscore (_), and replace all non-alpha/numeric/underscore characters with underscores.""" + + if not re.match(r'^[a-zA-Z_]', input_str): + # If not, add "_" + input_str = '_' + input_str + + # Replace special characters with underscores (for database tables) + sanitized_str: str = re.sub(r'[^a-zA-Z0-9_]', '_', input_str) + return sanitized_str \ No newline at end of file