diff --git a/meshview/store.py b/meshview/store.py index e5278db..97c13d2 100644 --- a/meshview/store.py +++ b/meshview/store.py @@ -12,30 +12,6 @@ from meshview.models import Packet, PacketSeen, Node, Traceroute from meshview import notify - -# We count the total amount of packages -# This is to be used by /stats in web.py -async def get_total_packet_count(): - async with database.async_session() as session: - q = select(func.count(Packet.id)) # Use SQLAlchemy's func to count packets - result = await session.execute(q) - return result.scalar() # Return the total count of packets - -# We count the total amount of nodes -async def get_total_node_count(): - async with database.async_session() as session: - q = select(func.count(Node.id)) # Use SQLAlchemy's func to count nodes - result = await session.execute(q) - return result.scalar() # Return the total count of nodes - -# We count the total amount of seen packets -async def get_total_packet_seen_count(): - async with database.async_session() as session: - q = select(func.count(PacketSeen.node_id)) # Use SQLAlchemy's func to count nodes - result = await session.execute(q) - return result.scalar() # Return the total count of seen packets - - async def process_envelope(topic, env): if not env.packet.id: return @@ -54,6 +30,7 @@ async def process_envelope(topic, env): payload=env.packet.SerializeToString(), # p.r. Here seems to be where the packet is imported on the Database and import time is set. import_time=datetime.datetime.now(), + channel=env.channel_id, ) session.add(packet) @@ -111,7 +88,7 @@ async def process_envelope(topic, env): node.short_name = user.short_name node.hw_model = hw_model node.role = role - # if need to update time of last update it may be here + # if need to update time of last update it may be here else: node = Node( @@ -121,7 +98,8 @@ async def process_envelope(topic, env): short_name=user.short_name, hw_model=hw_model, role=role, - # if need to update time of last update it may be here + channel=env.channel_id, + # if need to update time of last update it may be here ) session.add(node) @@ -293,6 +271,233 @@ async def get_mqtt_neighbors(since): ) return result +# In order to provide separate network graphs for LongFast and MediumSlow, I am duplicating the procedures. +# 3 procedures are needed. These would have to be replicated for any other network that we may need to use graphs. +# +# get_traceroutes_longfast +# get_packets_longfast +# get_mqtt_neighbors_longfast +# +# p.r. +# +# Get Traceroute for LongFast only +async def get_traceroutes_longfast(since): + async with database.async_session() as session: + result = await session.execute( + select(Traceroute) + .join(Packet) + .where( + (Traceroute.import_time > (datetime.datetime.now() - since)) + & (Packet.channel == "LongFast") + ) + .order_by(Traceroute.import_time) + ) + return result.scalars() + +# Get MQTT Neighbors for LongFast only +# p.r. +async def get_mqtt_neighbors_longfast(since): + async with database.async_session() as session: + result = await session.execute(select(PacketSeen, Packet) + .join(Packet) + .where( + (PacketSeen.hop_limit == PacketSeen.hop_start) + & (PacketSeen.hop_start != 0) + & (Packet.channel == "LongFast") + ) + + .options( + lazyload(Packet.from_node), + lazyload(Packet.to_node), + ) + ) + return result + +# Get Packets for LongFast only +# p.r. +async def get_packets_longfast(node_id=None, portnum=None, since=None, limit=500, before=None, after=None): + async with database.async_session() as session: + q = select(Packet) + + # Add condition for channel being "LongFast" + q = q.where(Packet.channel == "LongFast") + + if node_id: + q = q.where( + (Packet.from_node_id == node_id) | (Packet.to_node_id == node_id) + ) + if portnum: + q = q.where(Packet.portnum == portnum) + if since: + q = q.where(Packet.import_time > (datetime.datetime.now() - since)) + if before: + q = q.where(Packet.import_time < before) + if after: + q = q.where(Packet.import_time > after) + if limit is not None: + q = q.limit(limit) + + result = await session.execute(q.order_by(Packet.import_time.desc())) + return result.scalars() + +# Get Traceroute for mediumslow only +# p.r. +async def get_traceroutes_mediumslow(since): + async with database.async_session() as session: + result = await session.execute( + select(Traceroute) + .join(Packet) + .where( + (Traceroute.import_time > (datetime.datetime.now() - since)) + & (Packet.channel == "MediumSlow") + ) + .order_by(Traceroute.import_time) + ) + return result.scalars() + +# Get MQTT Neighbors for mediumslow only +# p.r. +async def get_mqtt_neighbors_mediumslow(since): + async with database.async_session() as session: + result = await session.execute(select(PacketSeen, Packet) + .join(Packet) + .where( + (PacketSeen.hop_limit == PacketSeen.hop_start) + & (PacketSeen.hop_start != 0) + & (Packet.channel == "MediumSlow") + ) + + .options( + lazyload(Packet.from_node), + lazyload(Packet.to_node), + ) + ) + return result + +# Get Packets for MediumSlow only +# p.r. +async def get_packets_mediumslow(node_id=None, portnum=None, since=None, limit=500, before=None, after=None): + async with database.async_session() as session: + q = select(Packet) + + # Add condition for channel being "MediumSlow" + q = q.where(Packet.channel == "MediumSlow") + + if node_id: + q = q.where( + (Packet.from_node_id == node_id) | (Packet.to_node_id == node_id) + ) + if portnum: + q = q.where(Packet.portnum == portnum) + if since: + q = q.where(Packet.import_time > (datetime.datetime.now() - since)) + if before: + q = q.where(Packet.import_time < before) + if after: + q = q.where(Packet.import_time > after) + if limit is not None: + q = q.limit(limit) + + result = await session.execute(q.order_by(Packet.import_time.desc())) + return result.scalars() +# We count the total amount of packages +# This is to be used by /stats in web.py +async def get_total_packet_count(): + async with database.async_session() as session: + q = select(func.count(Packet.id)) # Use SQLAlchemy's func to count packets + result = await session.execute(q) + return result.scalar() # Return the total count of packets + +# We count the total amount of nodes +async def get_total_node_count(): + async with database.async_session() as session: + q = select(func.count(Node.id)) # Use SQLAlchemy's func to count nodes + result = await session.execute(q) + return result.scalar() # Return the total count of nodes + +# We count the total amount of seen packets +async def get_total_packet_seen_count(): + async with database.async_session() as session: + q = select(func.count(PacketSeen.node_id)) # Use SQLAlchemy's func to count nodes + result = await session.execute(q) + return result.scalar() # Return the total count of seen packets + + +async def get_total_node_count_longfast() -> int: + """ + Retrieves the total count of nodes where the channel is equal to 'LongFast'. + + This function queries the database asynchronously to count the number of nodes + in the `Node` table that meet the condition `channel == 'LongFast'`. It uses + SQLAlchemy's asynchronous session management and query construction. + + Returns: + int: The total count of nodes with `channel == 'LongFast'`. + + Raises: + Exception: If an error occurs during the database query execution. + """ + try: + # Open an asynchronous session with the database + async with database.async_session() as session: + # Build the query to count nodes where channel == 'LongFast' + q = select(func.count(Node.id)).filter(Node.channel == 'LongFast') + + # Execute the query asynchronously and fetch the result + result = await session.execute(q) + + # Return the scalar value (the count of nodes) + return result.scalar() + except Exception as e: + # Log or handle the exception if needed (optional, replace with logging if necessary) + print(f"An error occurred: {e}") + return 0 # Return 0 or an appropriate fallback value in case of an error + + +async def get_total_node_count_mediumslow() -> int: + """ + Retrieves the total count of nodes where the channel is equal to 'MediumSlow'. + + This function queries the database asynchronously to count the number of nodes + in the `Node` table that meet the condition `channel == 'MediumSlow'`. It uses + SQLAlchemy's asynchronous session management and query construction. + + Returns: + int: The total count of nodes with `channel == 'MediumSlow'`. + + Raises: + Exception: If an error occurs during the database query execution. + """ + try: + # Open an asynchronous session with the database + async with database.async_session() as session: + # Build the query to count nodes where channel == 'LongFast' + q = select(func.count(Node.id)).filter(Node.channel == 'MediumSlow') + + # Execute the query asynchronously and fetch the result + result = await session.execute(q) + + # Return the scalar value (the count of nodes) + return result.scalar() + except Exception as e: + # Log or handle the exception if needed (optional, replace with logging if necessary) + print(f"An error occurred: {e}") + return 0 # Return 0 or an appropriate fallback value in case of an error + + +# Get Nodes for mediumslow only +# p.r. +async def get_nodes_mediumslow(): + async with database.async_session() as session: + result = await session.execute( + select(Node) + .where( + (Node.channel == "MediumSlow") + ) + ) + return result.scalars() + +