diff --git a/Program.cs b/Program.cs index c6af7ac..7b3cfb5 100644 --- a/Program.cs +++ b/Program.cs @@ -13,132 +13,191 @@ using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Reflection; -var mqttFactory = new MqttServerFactory(); +await RunMqttServer(args); -// #if SSL -var currentPath = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!; -#pragma warning disable SYSLIB0057 // Type or member is obsolete -var certificate = new X509Certificate2(Path.Combine(currentPath, "certificate.pfx"), "large4cats", X509KeyStorageFlags.Exportable); -#pragma warning restore SYSLIB0057 // Type or member is obsolete - -var mqttServerOptions = new MqttServerOptionsBuilder() - .WithoutDefaultEndpoint() // This call disables the default unencrypted endpoint on port 1883 - .WithEncryptedEndpoint() - .WithEncryptedEndpointPort(8883) - .WithEncryptionCertificate(certificate.Export(X509ContentType.Pfx)) - .WithEncryptionSslProtocol(SslProtocols.Tls12) - .Build(); - Log.Logger.Information("Using SSL certificate for MQTT server"); - -// If you want to use a non-encrypted MQTT server, you can uncomment the following lines instead of the above -// var mqttServerOptions = new MqttServerOptionsBuilder() -// .WithDefaultEndpoint() -// .WithDefaultEndpointPort(1883) -// .Build(); -// Log.Logger.Information("Using unencrypted MQTT server"); -// #endif - -Log.Logger = new LoggerConfiguration() - .MinimumLevel.Debug() - .WriteTo.Console(new RenderedCompactJsonFormatter()) - // .WriteTo.File(new RenderedCompactJsonFormatter(), "log.json", rollingInterval: RollingInterval.Hour) // File logging can be enabled if needed - .CreateLogger(); - -using var mqttServer = mqttFactory.CreateMqttServer(mqttServerOptions); - -static Data? DecryptMeshPacket(ServiceEnvelope serviceEnvelope) +async Task RunMqttServer(string[] args) { - var nonce = new NonceGenerator(serviceEnvelope.Packet.From, serviceEnvelope.Packet.Id).Create(); + ConfigureLogging(); - var decrypted = PacketEncryption.TransformPacket(serviceEnvelope.Packet.Encrypted.ToByteArray(), nonce, Resources.DEFAULT_PSK); - var payload = Data.Parser.ParseFrom(decrypted); + // Create and configure MQTT server + using var mqttServer = new MqttServerFactory() + .CreateMqttServer(BuildMqttServerOptions()); + ConfigureMqttServer(mqttServer); - if (payload.Portnum > PortNum.UnknownApp && payload.Payload.Length > 0) - return payload; + // Set up host + using var host = CreateHostBuilder(args).Build(); + await host.StartAsync(); + var lifetime = host.Services.GetRequiredService(); - // Was not able to decrypt the payload - return null; + // Start MQTT server + await mqttServer.StartAsync(); + + // Configure graceful shutdown + await SetupGracefulShutdown(mqttServer, lifetime, host); } -mqttServer.InterceptingPublishAsync += async (args) => +void ConfigureLogging() +{ + Log.Logger = new LoggerConfiguration() + .MinimumLevel.Debug() + .WriteTo.Console(new RenderedCompactJsonFormatter()) + // .WriteTo.File(new RenderedCompactJsonFormatter(), "log.json", rollingInterval: RollingInterval.Hour) + .CreateLogger(); +} + +MqttServerOptions BuildMqttServerOptions() +{ + var currentPath = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!; + + #pragma warning disable SYSLIB0057 // Type or member is obsolete + var certificate = new X509Certificate2( + Path.Combine(currentPath, "certificate.pfx"), + "large4cats", + X509KeyStorageFlags.Exportable); + #pragma warning restore SYSLIB0057 + + var options = new MqttServerOptionsBuilder() + .WithoutDefaultEndpoint() + .WithEncryptedEndpoint() + .WithEncryptedEndpointPort(8883) + .WithEncryptionCertificate(certificate.Export(X509ContentType.Pfx)) + .WithEncryptionSslProtocol(SslProtocols.Tls12) + .Build(); + + Log.Logger.Information("Using SSL certificate for MQTT server"); + return options; +} + +void ConfigureMqttServer(MqttServer mqttServer) +{ + mqttServer.InterceptingPublishAsync += HandleInterceptingPublish; + mqttServer.InterceptingSubscriptionAsync += HandleInterceptingSubscription; + mqttServer.ValidatingConnectionAsync += HandleValidatingConnection; +} + +async Task HandleInterceptingPublish(InterceptingPublishEventArgs args) { try { if (args.ApplicationMessage.Payload.Length == 0) { Log.Logger.Warning("Received empty payload on topic {@Topic} from {@ClientId}", args.ApplicationMessage.Topic, args.ClientId); - args.ProcessPublish = false; // This will block empty packets + args.ProcessPublish = false; return; } + var serviceEnvelope = ServiceEnvelope.Parser.ParseFrom(args.ApplicationMessage.Payload); - // Block malformed service envelopes / packets - if ( - String.IsNullOrWhiteSpace(serviceEnvelope.ChannelId) || + if (!IsValidServiceEnvelope(serviceEnvelope)) + { + Log.Logger.Warning("Service envelope or packet is malformed. Blocking packet on topic {@Topic} from {@ClientId}", + args.ApplicationMessage.Topic, args.ClientId); + args.ProcessPublish = false; + return; + } + + // Spot for any async operations we might want to perform + await Task.FromResult(0); + + var data = DecryptMeshPacket(serviceEnvelope); + + // Uncomment to block unrecognized packets + // if (data == null) + // { + // Log.Logger.Warning("Service envelope does not contain a valid packet. Blocking packet"); + // args.ProcessPublish = false; + // return; + // } + + LogReceivedMessage(args.ApplicationMessage.Topic, args.ClientId, data); + args.ProcessPublish = true; + } + catch (InvalidProtocolBufferException) + { + Log.Logger.Warning("Failed to decode presumed protobuf packet. Blocking"); + args.ProcessPublish = false; + } + catch (Exception ex) + { + Log.Logger.Error("Exception occurred while processing packet on {@Topic} from {@ClientId}: {@Exception}", + args.ApplicationMessage.Topic, args.ClientId, ex.Message); + args.ProcessPublish = false; + } +} + +Task HandleInterceptingSubscription(InterceptingSubscriptionEventArgs args) +{ + args.ProcessSubscription = true; // Add filtering logic here if needed + return Task.CompletedTask; +} + +Task HandleValidatingConnection(ValidatingConnectionEventArgs args) +{ + // Add authentication logic here + args.ReasonCode = MqttConnectReasonCode.Success; + return Task.CompletedTask; +} + +bool IsValidServiceEnvelope(ServiceEnvelope serviceEnvelope) +{ + return !(String.IsNullOrWhiteSpace(serviceEnvelope.ChannelId) || String.IsNullOrWhiteSpace(serviceEnvelope.GatewayId) || serviceEnvelope.Packet == null || serviceEnvelope.Packet.Id < 1 || serviceEnvelope.Packet.From < 1 || serviceEnvelope.Packet.Encrypted == null || serviceEnvelope.Packet.Encrypted.Length < 1 || - serviceEnvelope.Packet.Decoded != null) - { - Log.Logger.Warning("Service envelope or packet is malformed. Blocking packet on topic {@Topic} from {@ClientId}", args.ApplicationMessage.Topic, args.ClientId); - args.ProcessPublish = false; - return; - } + serviceEnvelope.Packet.Decoded != null); +} - var data = DecryptMeshPacket(serviceEnvelope); - // If we were not able to decrypt the packet, it is likely encrypted with an unknown PSK - // Uncomment the following lines if you want to block these packets - // if (data == null) - // { - // Log.Logger.Warning("Service envelope does not contain a valid packet. Blocking packet"); - // args.ProcessPublish = false; // This will block packets that are not valid protobuf packets - // return; - // } - - if (data?.Portnum == PortNum.TextMessageApp) - { - Log.Logger.Information("Received text message on topic {@Topic} from {@ClientId}: {@Message}", - args.ApplicationMessage.Topic, args.ClientId, data.Payload.ToStringUtf8()); - } - else - { - Log.Logger.Information("Received packet on topic {@Topic} from {@ClientId} with port number: {@Portnum}", - args.ApplicationMessage.Topic, args.ClientId, data?.Portnum); - } - - // Any further validation logic to block a packet can be added here - args.ProcessPublish = true; - } - catch (InvalidProtocolBufferException) - { - Log.Logger.Warning("Failed to decode presumed protobuf packet. Blocking"); - args.ProcessPublish = false; // This will block packets encrypted on unknown PSKs - } - catch (Exception ex) - { - Log.Logger.Error("Exception occured while attempting to decode packet on {@Topic} from {@ClientId}: {@Exception}", args.ApplicationMessage.Topic, args.ClientId, ex.Message); - args.ProcessPublish = false; // This will block packets that caused us to encounter an exception - } -}; - -mqttServer.InterceptingSubscriptionAsync += (args) => +void LogReceivedMessage(string topic, string clientId, Data? data) { - args.ProcessSubscription = true; // Subscription filtering logic can be added here to only allow certain topics + if (data?.Portnum == PortNum.TextMessageApp) + { + Log.Logger.Information("Received text message on topic {@Topic} from {@ClientId}: {@Message}", + topic, clientId, data.Payload.ToStringUtf8()); + } + else + { + Log.Logger.Information("Received packet on topic {@Topic} from {@ClientId} with port number: {@Portnum}", + topic, clientId, data?.Portnum); + } +} - return Task.CompletedTask; -}; - -mqttServer.ValidatingConnectionAsync += (args) => +static Data? DecryptMeshPacket(ServiceEnvelope serviceEnvelope) { - args.ReasonCode = true ? // Authentication logic can be added here - MqttConnectReasonCode.Success : MqttConnectReasonCode.BadUserNameOrPassword; + var nonce = new NonceGenerator(serviceEnvelope.Packet.From, serviceEnvelope.Packet.Id).Create(); + var decrypted = PacketEncryption.TransformPacket(serviceEnvelope.Packet.Encrypted.ToByteArray(), nonce, Resources.DEFAULT_PSK); + var payload = Data.Parser.ParseFrom(decrypted); - // You can block connections based on client ID, username, ip, etc. - return Task.CompletedTask; -}; + if (payload.Portnum > PortNum.UnknownApp && payload.Payload.Length > 0) + return payload; + + return null; +} + +async Task SetupGracefulShutdown(MqttServer mqttServer, IHostApplicationLifetime lifetime, IHost host) +{ + var ended = new ManualResetEventSlim(); + var starting = new ManualResetEventSlim(); + + AssemblyLoadContext.Default.Unloading += ctx => + { + starting.Set(); + Log.Logger.Debug("Waiting for completion"); + ended.Wait(); + }; + + starting.Wait(); + + Log.Logger.Debug("Received signal gracefully shutting down"); + await mqttServer.StopAsync(); + Thread.Sleep(500); + ended.Set(); + + lifetime.StopApplication(); + host.WaitForShutdownAsync().Wait(); +} static IHostBuilder CreateHostBuilder(string[] args) { @@ -146,33 +205,6 @@ static IHostBuilder CreateHostBuilder(string[] args) .UseConsoleLifetime() .ConfigureServices((hostContext, services) => { - services - .AddSingleton(Console.Out); + services.AddSingleton(Console.Out); }); } - - -using var host = CreateHostBuilder(args).Build(); -await host.StartAsync(); -var lifetime = host.Services.GetRequiredService(); -await mqttServer.StartAsync(); - -var ended = new ManualResetEventSlim(); -var starting = new ManualResetEventSlim(); - -AssemblyLoadContext.Default.Unloading += ctx => -{ - starting.Set(); - Log.Logger.Debug("Waiting for completion"); - ended.Wait(); -}; - -starting.Wait(); - -Log.Logger.Debug("Received signal gracefully shutting down"); -await mqttServer.StopAsync(); -Thread.Sleep(1000); -ended.Set(); - -lifetime.StopApplication(); -await host.WaitForShutdownAsync();