mirror of
https://github.com/dpup/meshstream.git
synced 2026-03-28 17:42:37 +01:00
"Add message broker for distributing packets to multiple consumers
This commit is contained in:
31
CLAUDE.md
31
CLAUDE.md
@@ -1,8 +1,18 @@
|
||||
# Meshtastic MQTT Protocol Structure
|
||||
# Meshstream development guide
|
||||
|
||||
## Topic Structure
|
||||
## Dev commands
|
||||
|
||||
- `make build`
|
||||
- `make gen-proto`
|
||||
- `make clean`
|
||||
- `make run` <-- Do not execute the program yourself.
|
||||
|
||||
## Meshtastic MQTT Protocol Structure
|
||||
|
||||
### Topic Structure
|
||||
|
||||
The Meshtastic MQTT topic structure follows this pattern:
|
||||
|
||||
```
|
||||
msh/REGION_PATH/2/e/CHANNELNAME/USERID
|
||||
```
|
||||
@@ -16,18 +26,22 @@ msh/REGION_PATH/2/e/CHANNELNAME/USERID
|
||||
- `!` followed by hex characters for MAC address based IDs
|
||||
- `+` followed by phone number for Signal-based IDs
|
||||
|
||||
## Message Types
|
||||
### Message Types
|
||||
|
||||
#### Encoded Messages (ServiceEnvelope)
|
||||
|
||||
### Encoded Messages (ServiceEnvelope)
|
||||
Topic pattern: `msh/REGION_PATH/2/e/CHANNELNAME/USERID`
|
||||
|
||||
- ServiceEnvelope protobuf messages
|
||||
- Contains:
|
||||
- A MeshPacket (can be encrypted or unencrypted)
|
||||
- channel_id: The global channel ID it was sent on
|
||||
- gateway_id: Node ID of the gateway that relayed the message
|
||||
|
||||
### JSON Messages
|
||||
#### JSON Messages
|
||||
|
||||
Topic pattern: `msh/REGION_PATH/2/json/CHANNELNAME/USERID`
|
||||
|
||||
- Structured JSON payloads with fields like:
|
||||
- `id`: Message ID
|
||||
- `from`: Node ID of sender
|
||||
@@ -37,11 +51,12 @@ Topic pattern: `msh/REGION_PATH/2/json/CHANNELNAME/USERID`
|
||||
|
||||
Note: JSON format is not supported on nRF52 platform devices.
|
||||
|
||||
### Special Topics
|
||||
#### Special Topics
|
||||
|
||||
- MQTT Downlink: `msh/REGION_PATH/2/json/mqtt/`
|
||||
- Used for sending instructions to gateway nodes
|
||||
|
||||
## Important Notes
|
||||
### Important Notes
|
||||
|
||||
- The public MQTT server implements a zero-hop policy (only direct messages)
|
||||
- JSON messages may include specific data types like:
|
||||
@@ -52,4 +67,4 @@ Note: JSON format is not supported on nRF52 platform devices.
|
||||
- Position data on public servers has reduced precision for privacy
|
||||
- Binary messages use the protocol buffers format defined in the Meshtastic codebase
|
||||
|
||||
This corrects the previous assumption that the topic structure was `msh/REGION/STATE/NAME`, which was incorrect.
|
||||
This corrects the previous assumption that the topic structure was `msh/REGION/STATE/NAME`, which was incorrect.
|
||||
|
||||
243
main.go
243
main.go
@@ -5,11 +5,16 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"meshstream/decoder"
|
||||
"meshstream/mqtt"
|
||||
|
||||
pb "meshstream/proto/generated/meshtastic"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -17,8 +22,165 @@ const (
|
||||
mqttUsername = "meshdev"
|
||||
mqttPassword = "large4cats"
|
||||
mqttTopicPrefix = "msh/US/bayarea"
|
||||
logsDir = "./logs"
|
||||
)
|
||||
|
||||
// MessageStats tracks statistics about received messages.
// The embedded mutex guards every field; RecordMessage and PrintStats
// take the lock internally, so they are safe to call from multiple goroutines.
type MessageStats struct {
	sync.Mutex
	TotalMessages    int                // messages seen since the last PrintStats reset
	ByNode           map[uint32]int     // per-sender counts, keyed by node ID
	ByPortType       map[pb.PortNum]int // per-application (port) counts
	LastStatsPrinted time.Time          // when counters were last printed/reset
}
|
||||
|
||||
// NewMessageStats creates a new MessageStats instance
|
||||
func NewMessageStats() *MessageStats {
|
||||
return &MessageStats{
|
||||
ByNode: make(map[uint32]int),
|
||||
ByPortType: make(map[pb.PortNum]int),
|
||||
LastStatsPrinted: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordMessage records a message in the statistics
// It bumps the total counter plus the per-node and per-port counters,
// holding the embedded mutex for the duration of the update.
func (s *MessageStats) RecordMessage(packet *mqtt.Packet) {
	s.Lock()
	defer s.Unlock()

	s.TotalMessages++

	// Count by source node
	s.ByNode[packet.From]++

	// Count by port type
	s.ByPortType[packet.PortNum]++
}
|
||||
|
||||
// PrintStats prints current statistics
// The msg/sec rate covers the interval since the previous call, and all
// counters are reset at the end, so each report describes only the most
// recent window. Map iteration order is unspecified in Go, so the per-node
// and per-port listings print in arbitrary order.
func (s *MessageStats) PrintStats() {
	s.Lock()
	defer s.Unlock()

	now := time.Now()
	duration := now.Sub(s.LastStatsPrinted)
	msgPerSec := float64(s.TotalMessages) / duration.Seconds()

	fmt.Println("\n==== Message Statistics ====")
	fmt.Printf("Total messages: %d (%.2f msg/sec)\n", s.TotalMessages, msgPerSec)

	// Print node statistics
	fmt.Println("\nMessages by Node:")
	for nodeID, count := range s.ByNode {
		fmt.Printf("  Node %d: %d messages\n", nodeID, count)
	}

	// Print port type statistics
	fmt.Println("\nMessages by Port Type:")
	for portType, count := range s.ByPortType {
		fmt.Printf("  %s: %d messages\n", portType, count)
	}
	fmt.Println(strings.Repeat("=", 30))

	// Reset counters for rate calculation
	s.TotalMessages = 0
	s.ByNode = make(map[uint32]int)
	s.ByPortType = make(map[pb.PortNum]int)
	s.LastStatsPrinted = now
}
|
||||
|
||||
// MessageLogger logs messages of specific types to separate files
// One log file per port type is created lazily under logDir on first use;
// the mutex guards the loggers and files maps.
type MessageLogger struct {
	logDir  string                     // directory the per-port log files live in
	loggers map[pb.PortNum]*log.Logger // lazily created logger per port type
	files   map[pb.PortNum]*os.File    // open file handles, closed by Close
	mutex   sync.Mutex
}
|
||||
|
||||
// NewMessageLogger creates a new message logger
|
||||
func NewMessageLogger(logDir string) (*MessageLogger, error) {
|
||||
// Ensure log directory exists
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log directory: %v", err)
|
||||
}
|
||||
|
||||
return &MessageLogger{
|
||||
logDir: logDir,
|
||||
loggers: make(map[pb.PortNum]*log.Logger),
|
||||
files: make(map[pb.PortNum]*os.File),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getLogger returns a logger for the specified port type
|
||||
func (ml *MessageLogger) getLogger(portNum pb.PortNum) *log.Logger {
|
||||
ml.mutex.Lock()
|
||||
defer ml.mutex.Unlock()
|
||||
|
||||
// Check if we already have a logger for this port type
|
||||
if logger, ok := ml.loggers[portNum]; ok {
|
||||
return logger
|
||||
}
|
||||
|
||||
// Create a new log file for this port type
|
||||
filename := fmt.Sprintf("%s.log", strings.ToLower(portNum.String()))
|
||||
filepath := filepath.Join(ml.logDir, filename)
|
||||
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
log.Printf("Error opening log file %s: %v", filepath, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create a new logger
|
||||
logger := log.New(file, "", log.LstdFlags)
|
||||
|
||||
// Store the logger and file handle
|
||||
ml.loggers[portNum] = logger
|
||||
ml.files[portNum] = file
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// LogMessage logs a message to the appropriate file based on its port type
// Only the listed telemetry-style port types are logged; all other packets
// are ignored. If the log file for a port type cannot be opened, getLogger
// returns nil (having already logged the error) and the message is skipped.
func (ml *MessageLogger) LogMessage(packet *mqtt.Packet) {
	// We only log specific message types
	switch packet.PortNum {
	case pb.PortNum_POSITION_APP,
		pb.PortNum_TELEMETRY_APP,
		pb.PortNum_NODEINFO_APP,
		pb.PortNum_MAP_REPORT_APP,
		pb.PortNum_TRACEROUTE_APP,
		pb.PortNum_NEIGHBORINFO_APP:

		// Get the logger for this port type
		logger := ml.getLogger(packet.PortNum)
		if logger != nil {
			// Format the message
			formattedOutput := decoder.FormatTopicAndPacket(packet.TopicInfo, packet.DecodedPacket)

			// Add a timestamp and node info
			logEntry := fmt.Sprintf("[Node %d] %s\n%s\n",
				packet.From,
				time.Now().Format(time.RFC3339),
				formattedOutput)

			// Write to the log
			logger.Println(logEntry)
		}
	}
}
|
||||
|
||||
// Close closes all log files
|
||||
func (ml *MessageLogger) Close() {
|
||||
ml.mutex.Lock()
|
||||
defer ml.mutex.Unlock()
|
||||
|
||||
for portNum, file := range ml.files {
|
||||
log.Printf("Closing log file for %s", portNum)
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Set up logging
|
||||
log.SetOutput(os.Stdout)
|
||||
@@ -52,17 +214,85 @@ func main() {
|
||||
// Get the messages channel to receive decoded messages
|
||||
messagesChan := mqttClient.Messages()
|
||||
|
||||
// Create a message broker to distribute messages to multiple consumers
|
||||
broker := mqtt.NewBroker(messagesChan)
|
||||
|
||||
// Create a consumer channel for display with buffer size 10
|
||||
displayChan := broker.Subscribe(10)
|
||||
|
||||
// Create a consumer channel for statistics with buffer size 50
|
||||
statsChan := broker.Subscribe(50)
|
||||
|
||||
// Create a consumer channel for logging with buffer size 100
|
||||
loggerChan := broker.Subscribe(100)
|
||||
|
||||
// Create a stats tracker
|
||||
stats := NewMessageStats()
|
||||
|
||||
// Create a message logger
|
||||
messageLogger, err := NewMessageLogger(logsDir)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to initialize message logger: %v", err)
|
||||
}
|
||||
|
||||
// Create a ticker for periodically printing stats
|
||||
statsTicker := time.NewTicker(30 * time.Second)
|
||||
|
||||
// Setup signal handling for graceful shutdown
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
// Start a goroutine for processing statistics
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case packet, ok := <-statsChan:
|
||||
if !ok {
|
||||
// Channel closed
|
||||
return
|
||||
}
|
||||
|
||||
if packet != nil {
|
||||
stats.RecordMessage(packet)
|
||||
}
|
||||
|
||||
case <-statsTicker.C:
|
||||
stats.PrintStats()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start a goroutine for logging specific message types
|
||||
go func() {
|
||||
if messageLogger != nil {
|
||||
for {
|
||||
packet, ok := <-loggerChan
|
||||
if !ok {
|
||||
// Channel closed
|
||||
return
|
||||
}
|
||||
|
||||
if packet != nil {
|
||||
messageLogger.LogMessage(packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Process messages until interrupt received
|
||||
fmt.Println("Waiting for messages... Press Ctrl+C to exit")
|
||||
fmt.Println("Statistics will be printed every 30 seconds")
|
||||
fmt.Println("Specific message types will be logged to files in the ./logs directory")
|
||||
|
||||
// Main event loop
|
||||
// Main event loop for display
|
||||
for {
|
||||
select {
|
||||
case packet := <-messagesChan:
|
||||
case packet := <-displayChan:
|
||||
if packet == nil {
|
||||
log.Println("Received nil packet, subscriber channel may be closed")
|
||||
continue
|
||||
}
|
||||
|
||||
// Format and print the decoded message
|
||||
formattedOutput := decoder.FormatTopicAndPacket(packet.TopicInfo, packet.DecodedPacket)
|
||||
fmt.Println(formattedOutput)
|
||||
@@ -71,6 +301,15 @@ func main() {
|
||||
case <-sig:
|
||||
// Got an interrupt signal, shutting down
|
||||
fmt.Println("Shutting down...")
|
||||
// Stop the ticker
|
||||
statsTicker.Stop()
|
||||
// Close the message logger
|
||||
if messageLogger != nil {
|
||||
messageLogger.Close()
|
||||
}
|
||||
// Close the broker first (which will close all subscriber channels)
|
||||
broker.Close()
|
||||
// Then disconnect the MQTT client
|
||||
mqttClient.Disconnect()
|
||||
return
|
||||
}
|
||||
|
||||
139
mqtt/broker.go
Normal file
139
mqtt/broker.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Broker distributes messages from a source channel to multiple subscriber channels
|
||||
type Broker struct {
|
||||
sourceChan <-chan *Packet // Source of packets (e.g., from MQTT client)
|
||||
subscribers map[chan *Packet]struct{} // Active subscribers
|
||||
subscriberMutex sync.RWMutex // Lock for modifying the subscribers map
|
||||
done chan struct{} // Signal to stop the dispatch loop
|
||||
wg sync.WaitGroup // Wait group to ensure clean shutdown
|
||||
}
|
||||
|
||||
// NewBroker creates a new broker that distributes messages from sourceChannel to subscribers
|
||||
func NewBroker(sourceChannel <-chan *Packet) *Broker {
|
||||
broker := &Broker{
|
||||
sourceChan: sourceChannel,
|
||||
subscribers: make(map[chan *Packet]struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start the dispatch loop
|
||||
broker.wg.Add(1)
|
||||
go broker.dispatchLoop()
|
||||
|
||||
return broker
|
||||
}
|
||||
|
||||
// Subscribe creates and returns a new subscriber channel
|
||||
// The bufferSize parameter controls how many messages can be buffered in the channel
|
||||
func (b *Broker) Subscribe(bufferSize int) <-chan *Packet {
|
||||
// Create a new channel for this subscriber
|
||||
subscriberChan := make(chan *Packet, bufferSize)
|
||||
|
||||
// Register the new subscriber
|
||||
b.subscriberMutex.Lock()
|
||||
b.subscribers[subscriberChan] = struct{}{}
|
||||
b.subscriberMutex.Unlock()
|
||||
|
||||
// Return the channel
|
||||
return subscriberChan
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscriber and closes its channel
|
||||
func (b *Broker) Unsubscribe(ch <-chan *Packet) {
|
||||
|
||||
b.subscriberMutex.Lock()
|
||||
defer b.subscriberMutex.Unlock()
|
||||
|
||||
// Find the channel in our subscribers map
|
||||
for subCh := range b.subscribers {
|
||||
if subCh == ch {
|
||||
delete(b.subscribers, subCh)
|
||||
close(subCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, the channel wasn't found
|
||||
log.Println("Warning: Subscriber channel not found - cannot unsubscribe")
|
||||
}
|
||||
|
||||
// Close shuts down the broker and closes all subscriber channels
|
||||
func (b *Broker) Close() {
|
||||
// Signal the dispatch loop to stop
|
||||
close(b.done)
|
||||
|
||||
// Wait for the dispatch loop to exit
|
||||
b.wg.Wait()
|
||||
|
||||
// Close all subscriber channels
|
||||
b.subscriberMutex.Lock()
|
||||
defer b.subscriberMutex.Unlock()
|
||||
|
||||
for ch := range b.subscribers {
|
||||
close(ch)
|
||||
}
|
||||
b.subscribers = make(map[chan *Packet]struct{})
|
||||
}
|
||||
|
||||
// dispatchLoop continuously reads from the source channel and distributes to subscribers
|
||||
func (b *Broker) dispatchLoop() {
|
||||
defer b.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.done:
|
||||
// Broker is shutting down
|
||||
return
|
||||
|
||||
case packet, ok := <-b.sourceChan:
|
||||
if !ok {
|
||||
// Source channel has been closed, shut down the broker
|
||||
log.Println("Source channel closed, shutting down broker")
|
||||
b.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Distribute the packet to all subscribers
|
||||
b.broadcast(packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// broadcast sends a packet to all active subscribers without blocking
|
||||
func (b *Broker) broadcast(packet *Packet) {
|
||||
// Take a read lock to get a snapshot of the subscribers
|
||||
b.subscriberMutex.RLock()
|
||||
subscribers := make([]chan *Packet, 0, len(b.subscribers))
|
||||
for ch := range b.subscribers {
|
||||
subscribers = append(subscribers, ch)
|
||||
}
|
||||
b.subscriberMutex.RUnlock()
|
||||
|
||||
// Distribute to all subscribers
|
||||
for _, ch := range subscribers {
|
||||
// Use a goroutine and recover to ensure sending to a closed channel doesn't panic
|
||||
go func(ch chan *Packet) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// This can happen if the channel was closed after we took a snapshot
|
||||
log.Println("Warning: Recovered from panic in broadcast, channel likely closed")
|
||||
}
|
||||
}()
|
||||
|
||||
// Try to send without blocking
|
||||
select {
|
||||
case ch <- packet:
|
||||
// Message delivered successfully
|
||||
default:
|
||||
// Channel buffer is full, log warning and drop the message
|
||||
log.Println("Warning: Subscriber buffer full, dropping message")
|
||||
}
|
||||
}(ch)
|
||||
}
|
||||
}
|
||||
245
mqtt/broker_test.go
Normal file
245
mqtt/broker_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"meshstream/decoder"
|
||||
)
|
||||
|
||||
// TestBrokerSubscribeUnsubscribe tests the basic subscribe and unsubscribe functionality
// Both subscribers must receive a broadcast packet; after Unsubscribe, only
// the remaining subscriber receives further packets.
func TestBrokerSubscribeUnsubscribe(t *testing.T) {
	// Create a test source channel
	sourceChan := make(chan *Packet, 10)

	// Create a broker with the source channel
	broker := NewBroker(sourceChan)
	defer broker.Close()

	// Subscribe to the broker
	subscriber1 := broker.Subscribe(5)
	subscriber2 := broker.Subscribe(5)

	// Keep track of the internal broker state for testing
	// (white-box access: same package, so the mutex and map are visible)
	broker.subscriberMutex.RLock()
	subscriberCount := len(broker.subscribers)
	broker.subscriberMutex.RUnlock()

	if subscriberCount != 2 {
		t.Errorf("Expected 2 subscribers, got %d", subscriberCount)
	}

	// We need to use sequential packets because our implementation is asynchronous
	// and exact packet matching may not work reliably

	// First packet with ID 1
	packet1 := &Packet{
		DecodedPacket: &decoder.DecodedPacket{ID: 1},
		TopicInfo:     &decoder.TopicInfo{},
	}

	// Send the packet
	sourceChan <- packet1

	// Both subscribers should receive the packet
	select {
	case received := <-subscriber1:
		if received.ID != 1 {
			t.Errorf("Expected subscriber1 to receive packet with ID 1, got %d", received.ID)
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("subscriber1 didn't receive packet within timeout")
	}

	select {
	case received := <-subscriber2:
		if received.ID != 1 {
			t.Errorf("Expected subscriber2 to receive packet with ID 1, got %d", received.ID)
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("subscriber2 didn't receive packet within timeout")
	}

	// Unsubscribe the first subscriber
	broker.Unsubscribe(subscriber1)

	// Verify the subscriber was removed
	broker.subscriberMutex.RLock()
	subscriberCount = len(broker.subscribers)
	broker.subscriberMutex.RUnlock()

	if subscriberCount != 1 {
		t.Errorf("Expected 1 subscriber after unsubscribe, got %d", subscriberCount)
	}

	// Second packet with ID 2
	packet2 := &Packet{
		DecodedPacket: &decoder.DecodedPacket{ID: 2},
		TopicInfo:     &decoder.TopicInfo{},
	}

	// Send the second packet
	sourceChan <- packet2

	// The second subscriber should receive the packet
	select {
	case received := <-subscriber2:
		if received.ID != 2 {
			t.Errorf("Expected subscriber2 to receive packet with ID 2, got %d", received.ID)
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("subscriber2 didn't receive second packet within timeout")
	}
}
|
||||
|
||||
// TestBrokerMultipleSubscribers tests broadcasting to many subscribers
// Every one of the ten subscribers must receive the single broadcast packet.
func TestBrokerMultipleSubscribers(t *testing.T) {
	// Create a test source channel
	sourceChan := make(chan *Packet, 10)

	// Create a broker with the source channel
	broker := NewBroker(sourceChan)
	defer broker.Close()

	// Create multiple subscribers
	const numSubscribers = 10
	subscribers := make([]<-chan *Packet, numSubscribers)
	for i := 0; i < numSubscribers; i++ {
		subscribers[i] = broker.Subscribe(5)
	}

	// Send a test packet with ID 42
	testPacket := &Packet{
		DecodedPacket: &decoder.DecodedPacket{ID: 42},
		TopicInfo:     &decoder.TopicInfo{},
	}
	sourceChan <- testPacket

	// All subscribers should receive the packet
	var wg sync.WaitGroup
	wg.Add(numSubscribers)

	// t.Errorf (unlike t.Fatalf) is documented as safe to call from
	// goroutines other than the test goroutine.
	for i, subscriber := range subscribers {
		go func(idx int, ch <-chan *Packet) {
			defer wg.Done()
			select {
			case received := <-ch:
				if received.ID != 42 {
					t.Errorf("subscriber %d expected packet ID 42, got %d", idx, received.ID)
				}
			case <-time.After(100 * time.Millisecond):
				t.Errorf("subscriber %d didn't receive packet within timeout", idx)
			}
		}(i, subscriber)
	}

	// Wait for all goroutines to complete
	wg.Wait()
}
|
||||
|
||||
// TestBrokerSlowSubscriber tests that a slow subscriber doesn't block others
// A subscriber with a buffer of 1 may drop the second packet; the normally
// buffered subscriber must still receive both. NOTE(review): this test is
// timing-sensitive (it relies on a 10ms sleep for distribution) and could
// flake on a heavily loaded machine.
func TestBrokerSlowSubscriber(t *testing.T) {
	// Create a test source channel
	sourceChan := make(chan *Packet, 10)

	// Create a broker with the source channel
	broker := NewBroker(sourceChan)
	defer broker.Close()

	// Create a slow subscriber with buffer size 1
	slowSubscriber := broker.Subscribe(1)

	// And a normal subscriber
	normalSubscriber := broker.Subscribe(5)

	// Verify we have two subscribers
	broker.subscriberMutex.RLock()
	subscriberCount := len(broker.subscribers)
	broker.subscriberMutex.RUnlock()

	if subscriberCount != 2 {
		t.Errorf("Expected 2 subscribers, got %d", subscriberCount)
	}

	// Send two packets quickly to fill the slow subscriber's buffer
	testPacket1 := &Packet{
		DecodedPacket: &decoder.DecodedPacket{ID: 101},
		TopicInfo:     &decoder.TopicInfo{},
	}
	testPacket2 := &Packet{
		DecodedPacket: &decoder.DecodedPacket{ID: 102},
		TopicInfo:     &decoder.TopicInfo{},
	}

	sourceChan <- testPacket1

	// Give the broker time to distribute the first packet
	time.Sleep(10 * time.Millisecond)

	sourceChan <- testPacket2

	// The normal subscriber should receive both packets
	select {
	case received := <-normalSubscriber:
		if received.ID != 101 {
			t.Errorf("normalSubscriber expected packet ID 101, got %d", received.ID)
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("normalSubscriber didn't receive first packet within timeout")
	}

	select {
	case received := <-normalSubscriber:
		if received.ID != 102 {
			t.Errorf("normalSubscriber expected packet ID 102, got %d", received.ID)
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("normalSubscriber didn't receive second packet within timeout")
	}

	// The slow subscriber should receive at least the first packet
	select {
	case received := <-slowSubscriber:
		if received.ID != 101 {
			t.Errorf("slowSubscriber expected packet ID 101, got %d", received.ID)
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("slowSubscriber didn't receive first packet within timeout")
	}
}
|
||||
|
||||
// TestBrokerCloseWithSubscribers tests closing the broker with active subscribers
// After Close, every subscriber channel must be closed (a receive completes
// immediately with ok == false).
func TestBrokerCloseWithSubscribers(t *testing.T) {
	// Create a test source channel
	sourceChan := make(chan *Packet, 10)

	// Create a broker with the source channel
	broker := NewBroker(sourceChan)

	// Subscribe to the broker
	subscriber := broker.Subscribe(5)

	// Verify we have one subscriber
	broker.subscriberMutex.RLock()
	subscriberCount := len(broker.subscribers)
	broker.subscriberMutex.RUnlock()

	if subscriberCount != 1 {
		t.Errorf("Expected 1 subscriber, got %d", subscriberCount)
	}

	// Close the broker - this should close all subscriber channels
	broker.Close()

	// Trying to read from the subscriber channel should not block
	// since it should be closed
	select {
	case _, ok := <-subscriber:
		if ok {
			t.Error("Expected subscriber channel to be closed")
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("Subscriber channel should be closed but isn't")
	}
}
|
||||
Reference in New Issue
Block a user