diff --git a/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs b/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs index 5578ab786..b3e13fc69 100644 --- a/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs +++ b/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs @@ -156,10 +156,11 @@ public async Task DispatchApplicationMessage( var matchingSubscribersCount = 0; try { - var clonedMessage = CloneApplicationMessage(applicationMessage); + var applicationMessageCopy = new Lazy(() => applicationMessage.Clone()); if (applicationMessage.Retain) { - await _retainedMessagesManager.UpdateMessage(senderId, clonedMessage.Value).ConfigureAwait(false); + // applicationMessage must be copied + await _retainedMessagesManager.UpdateMessage(senderId, applicationMessage, applicationMessageCopy).ConfigureAwait(false); } List subscriberSessions; @@ -202,7 +203,8 @@ public async Task DispatchApplicationMessage( } } - var publishPacketCopy = MqttPublishPacketFactory.Create(clonedMessage.Value); + // applicationMessage must be copied + var publishPacketCopy = MqttPublishPacketFactory.Create(applicationMessageCopy.Value); publishPacketCopy.QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel; publishPacketCopy.SubscriptionIdentifiers = checkSubscriptionsResult.SubscriptionIdentifiers; @@ -508,26 +510,6 @@ public Task UnsubscribeAsync(string clientId, ICollection topicFilters) return GetClientSession(clientId).Unsubscribe(fakeUnsubscribePacket, CancellationToken.None); } - static Lazy CloneApplicationMessage(MqttApplicationMessage m) - { - return new Lazy( - () => new MqttApplicationMessage - { - ContentType = m.ContentType, - CorrelationData = m.CorrelationData?.ToArray(), - Dup = m.Dup, - MessageExpiryInterval = m.MessageExpiryInterval, - Payload = m.Payload.IsEmpty ? default : new ReadOnlySequence(m.Payload.ToArray()), - PayloadFormatIndicator = m.PayloadFormatIndicator, - QualityOfServiceLevel = m.QualityOfServiceLevel, - Retain = m.Retain, - ResponseTopic = m.ResponseTopic, - Topic = m.Topic, - UserProperties = m.UserProperties?.Select(u => u.Clone()).ToList(), - SubscriptionIdentifiers = m.SubscriptionIdentifiers?.ToList(), - TopicAlias = m.TopicAlias - }); - } MqttConnectedClient CreateClient(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter, MqttSession session) { @@ -699,39 +681,39 @@ static bool ShouldPersistSession(MqttConnectedClient connectedClient) switch (connectedClient.ChannelAdapter.PacketFormatterAdapter.ProtocolVersion) { case MqttProtocolVersion.V500: - { - // MQTT 5.0 section 3.1.2.11.2 - // The Client and Server MUST store the Session State after the Network Connection is closed if the Session Expiry Interval is greater than 0 [MQTT-3.1.2-23]. - // - // A Client that only wants to process messages while connected will set the Clean Start to 1 and set the Session Expiry Interval to 0. - // It will not receive Application Messages published before it connected and has to subscribe afresh to any topics that it is interested - // in each time it connects. - - var effectiveSessionExpiryInterval = connectedClient.DisconnectPacket?.SessionExpiryInterval ?? 0U; - if (effectiveSessionExpiryInterval == 0U) { - // From RFC: If the Session Expiry Interval is absent, the Session Expiry Interval in the CONNECT packet is used. - effectiveSessionExpiryInterval = connectedClient.ConnectPacket.SessionExpiryInterval; - } + // MQTT 5.0 section 3.1.2.11.2 + // The Client and Server MUST store the Session State after the Network Connection is closed if the Session Expiry Interval is greater than 0 [MQTT-3.1.2-23]. + // + // A Client that only wants to process messages while connected will set the Clean Start to 1 and set the Session Expiry Interval to 0. + // It will not receive Application Messages published before it connected and has to subscribe afresh to any topics that it is interested + // in each time it connects. + + var effectiveSessionExpiryInterval = connectedClient.DisconnectPacket?.SessionExpiryInterval ?? 0U; + if (effectiveSessionExpiryInterval == 0U) + { + // From RFC: If the Session Expiry Interval is absent, the Session Expiry Interval in the CONNECT packet is used. + effectiveSessionExpiryInterval = connectedClient.ConnectPacket.SessionExpiryInterval; + } - return effectiveSessionExpiryInterval != 0U; - } + return effectiveSessionExpiryInterval != 0U; + } case MqttProtocolVersion.V311: - { - // MQTT 3.1.1 section 3.1.2.4: persist only if 'not CleanSession' - // - // If CleanSession is set to 1, the Client and Server MUST discard any previous Session and start a new one. - // This Session lasts as long as the Network Connection. State data associated with this Session MUST NOT be - // reused in any subsequent Session [MQTT-3.1.2-6]. + { + // MQTT 3.1.1 section 3.1.2.4: persist only if 'not CleanSession' + // + // If CleanSession is set to 1, the Client and Server MUST discard any previous Session and start a new one. + // This Session lasts as long as the Network Connection. State data associated with this Session MUST NOT be + // reused in any subsequent Session [MQTT-3.1.2-6]. - return !connectedClient.ConnectPacket.CleanSession; - } + return !connectedClient.ConnectPacket.CleanSession; + } case MqttProtocolVersion.V310: - { - return true; - } + { + return true; + } default: throw new NotSupportedException(); diff --git a/Source/MQTTnet.Server/Internal/MqttRetainedMessagesManager.cs b/Source/MQTTnet.Server/Internal/MqttRetainedMessagesManager.cs index 854696c73..57d4841d5 100644 --- a/Source/MQTTnet.Server/Internal/MqttRetainedMessagesManager.cs +++ b/Source/MQTTnet.Server/Internal/MqttRetainedMessagesManager.cs @@ -50,7 +50,7 @@ public async Task Start() } } - public async Task UpdateMessage(string clientId, MqttApplicationMessage applicationMessage) + public async Task UpdateMessage(string clientId, MqttApplicationMessage applicationMessage, Lazy applicationMessageCopy) { ArgumentNullException.ThrowIfNull(applicationMessage); @@ -61,10 +61,7 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat lock (_messages) { - var payload = applicationMessage.Payload; - var hasPayload = payload.Length > 0; - - if (!hasPayload) + if (applicationMessage.Payload.IsEmpty) { saveIsRequired = _messages.Remove(applicationMessage.Topic); _logger.Verbose("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic); @@ -73,15 +70,15 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat { if (!_messages.TryGetValue(applicationMessage.Topic, out var existingMessage)) { - _messages[applicationMessage.Topic] = applicationMessage; + _messages[applicationMessage.Topic] = applicationMessageCopy.Value; saveIsRequired = true; } else { if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || - !MqttMemoryHelper.SequenceEqual(existingMessage.Payload, payload)) + !MqttMemoryHelper.SequenceEqual(existingMessage.Payload, applicationMessage.Payload)) { - _messages[applicationMessage.Topic] = applicationMessage; + _messages[applicationMessage.Topic] = applicationMessageCopy.Value; saveIsRequired = true; } } @@ -99,7 +96,7 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat { using (await _storageAccessLock.EnterAsync().ConfigureAwait(false)) { - var eventArgs = new RetainedMessageChangedEventArgs(clientId, applicationMessage, messagesForSave); + var eventArgs = new RetainedMessageChangedEventArgs(clientId, applicationMessageCopy.Value, messagesForSave); await _eventContainer.RetainedMessageChangedEvent.InvokeAsync(eventArgs).ConfigureAwait(false); } } diff --git a/Source/MQTTnet.Server/MqttServer.cs b/Source/MQTTnet.Server/MqttServer.cs index dc30ac93f..96f8d586f 100644 --- a/Source/MQTTnet.Server/MqttServer.cs +++ b/Source/MQTTnet.Server/MqttServer.cs @@ -352,7 +352,7 @@ public Task UpdateRetainedMessageAsync(MqttApplicationMessage retainedMessage) ThrowIfDisposed(); ThrowIfNotStarted(); - return _retainedMessagesManager?.UpdateMessage(string.Empty, retainedMessage); + return _retainedMessagesManager?.UpdateMessage(string.Empty, retainedMessage, new Lazy(retainedMessage)); } protected override void Dispose(bool disposing) diff --git a/Source/MQTTnet.Tests/Mockups/MqttApplicationMessageReceived.cs b/Source/MQTTnet.Tests/Mockups/MqttApplicationMessageReceived.cs new file mode 100644 index 000000000..17221a771 --- /dev/null +++ b/Source/MQTTnet.Tests/Mockups/MqttApplicationMessageReceived.cs @@ -0,0 +1,4 @@ +namespace MQTTnet.Tests.Mockups +{ + public record MqttApplicationMessageReceived(string ClientId, MqttApplicationMessage ApplicationMessage); +} diff --git a/Source/MQTTnet.Tests/Mockups/TestApplicationMessageReceivedHandler.cs b/Source/MQTTnet.Tests/Mockups/TestApplicationMessageReceivedHandler.cs index 66b44bd16..c145aec03 100644 --- a/Source/MQTTnet.Tests/Mockups/TestApplicationMessageReceivedHandler.cs +++ b/Source/MQTTnet.Tests/Mockups/TestApplicationMessageReceivedHandler.cs @@ -14,13 +14,16 @@ namespace MQTTnet.Tests.Mockups { public sealed class TestApplicationMessageReceivedHandler { - readonly List _receivedEventArgs = new List(); + readonly IMqttClient _mqttClient; + readonly List _receivedEventArgs = new(); + public TestApplicationMessageReceivedHandler(IMqttClient mqttClient) { ArgumentNullException.ThrowIfNull(mqttClient); mqttClient.ApplicationMessageReceivedAsync += OnApplicationMessageReceivedAsync; + _mqttClient = mqttClient; } public int Count @@ -34,7 +37,7 @@ public int Count } } - public List ReceivedEventArgs + public List ReceivedEventArgs { get { @@ -72,7 +75,11 @@ Task OnApplicationMessageReceivedAsync(MqttApplicationMessageReceivedEventArgs e { lock (_receivedEventArgs) { - _receivedEventArgs.Add(eventArgs); + var applicationMessage = _mqttClient.Options.ReceivedApplicationMessageQueueable + ? eventArgs.ApplicationMessage + : eventArgs.ApplicationMessage.Clone(); + + _receivedEventArgs.Add(new MqttApplicationMessageReceived(eventArgs.ClientId, applicationMessage)); } return CompletedTask.Instance; diff --git a/Source/MQTTnet.Tests/Server/QoS_Tests.cs b/Source/MQTTnet.Tests/Server/QoS_Tests.cs index b48a1a99c..1262c6d9f 100644 --- a/Source/MQTTnet.Tests/Server/QoS_Tests.cs +++ b/Source/MQTTnet.Tests/Server/QoS_Tests.cs @@ -113,13 +113,24 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_2() [TestMethod] public async Task Preserve_Message_Order_For_Queued_Messages() + { + await Preserve_Message_Order_For_Queued_Messages(receivedPublishPacketQueueable: true); + } + + [TestMethod] + public async Task Preserve_Message_Order_For_Queued_Messages_NoQueue() + { + await Preserve_Message_Order_For_Queued_Messages(receivedPublishPacketQueueable: false); + } + + private async Task Preserve_Message_Order_For_Queued_Messages(bool receivedPublishPacketQueueable) { using (var testEnvironment = CreateTestEnvironment()) { var server = await testEnvironment.StartServer(o => o.WithPersistentSessions()); // Create a session which will contain the messages. - var dummyClient = await testEnvironment.ConnectClient(o => o.WithClientId("A").WithCleanSession(false)); + var dummyClient = await testEnvironment.ConnectClient(o => o.WithClientId("A").WithCleanSession(false).WithReceivedApplicationMessageQueueable(receivedPublishPacketQueueable)); await dummyClient.SubscribeAsync("#", MqttQualityOfServiceLevel.AtLeastOnce); dummyClient.Dispose(); diff --git a/Source/MQTTnet/MqttApplicationMessage.cs b/Source/MQTTnet/MqttApplicationMessage.cs index eb402e155..9d7c72051 100644 --- a/Source/MQTTnet/MqttApplicationMessage.cs +++ b/Source/MQTTnet/MqttApplicationMessage.cs @@ -8,6 +8,7 @@ using System; using System.Buffers; using System.Collections.Generic; +using System.Linq; namespace MQTTnet { @@ -142,5 +143,29 @@ public ArraySegment PayloadSegment /// Hint: MQTT 5 feature only. /// public List UserProperties { get; set; } + + /// + /// Deep clone all fields. + /// + /// + public MqttApplicationMessage Clone() + { + return new MqttApplicationMessage + { + ContentType = ContentType, + CorrelationData = CorrelationData == default ? default : CorrelationData.ToArray(), + Dup = Dup, + MessageExpiryInterval = MessageExpiryInterval, + Payload = Payload.IsEmpty ? default : new ReadOnlySequence(Payload.ToArray()), + PayloadFormatIndicator = PayloadFormatIndicator, + QualityOfServiceLevel = QualityOfServiceLevel, + Retain = Retain, + ResponseTopic = ResponseTopic, + Topic = Topic, + UserProperties = UserProperties?.Select(u => u.Clone()).ToList(), + SubscriptionIdentifiers = SubscriptionIdentifiers?.ToList(), + TopicAlias = TopicAlias + }; + } } } \ No newline at end of file diff --git a/Source/MQTTnet/MqttClient.cs b/Source/MQTTnet/MqttClient.cs index 9d19ce574..3ec118952 100644 --- a/Source/MQTTnet/MqttClient.cs +++ b/Source/MQTTnet/MqttClient.cs @@ -2,10 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Diagnostics.PacketInspection; @@ -15,6 +11,10 @@ using MQTTnet.PacketDispatcher; using MQTTnet.Packets; using MQTTnet.Protocol; +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet; @@ -270,21 +270,21 @@ public Task PublishAsync(MqttApplicationMessage applica switch (applicationMessage.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: - { - return PublishAtMostOnce(publishPacket, cancellationToken); - } + { + return PublishAtMostOnce(publishPacket, cancellationToken); + } case MqttQualityOfServiceLevel.AtLeastOnce: - { - return PublishAtLeastOnce(publishPacket, cancellationToken); - } + { + return PublishAtLeastOnce(publishPacket, cancellationToken); + } case MqttQualityOfServiceLevel.ExactlyOnce: - { - return PublishExactlyOnce(publishPacket, cancellationToken); - } + { + return PublishExactlyOnce(publishPacket, cancellationToken); + } default: - { - throw new NotSupportedException(); - } + { + throw new NotSupportedException(); + } } } @@ -395,34 +395,34 @@ Task AcknowledgeReceivedPublishPacket(MqttApplicationMessageReceivedEventArgs ev switch (eventArgs.PublishPacket.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: - { - // no response required - break; - } - case MqttQualityOfServiceLevel.AtLeastOnce: - { - if (!eventArgs.ProcessingFailed) { - var pubAckPacket = MqttPubAckPacketFactory.Create(eventArgs); - return Send(pubAckPacket, cancellationToken); + // no response required + break; } + case MqttQualityOfServiceLevel.AtLeastOnce: + { + if (!eventArgs.ProcessingFailed) + { + var pubAckPacket = MqttPubAckPacketFactory.Create(eventArgs); + return Send(pubAckPacket, cancellationToken); + } - break; - } + break; + } case MqttQualityOfServiceLevel.ExactlyOnce: - { - if (!eventArgs.ProcessingFailed) { - var pubRecPacket = MqttPubRecPacketFactory.Create(eventArgs); - return Send(pubRecPacket, cancellationToken); - } + if (!eventArgs.ProcessingFailed) + { + var pubRecPacket = MqttPubRecPacketFactory.Create(eventArgs); + return Send(pubRecPacket, cancellationToken); + } - break; - } + break; + } default: - { - throw new MqttProtocolViolationException("Received a not supported QoS level."); - } + { + throw new MqttProtocolViolationException("Received a not supported QoS level."); + } } return CompletedTask.Instance; @@ -442,22 +442,22 @@ async Task Authenticate(IMqttChannelAdapter channelAdap switch (receivedPacket) { case MqttConnAckPacket connAckPacket: - { - result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion); - break; - } + { + result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion); + break; + } case MqttAuthPacket _: - { - throw new NotSupportedException("Extended authentication handler is not yet supported"); - } + { + throw new NotSupportedException("Extended authentication handler is not yet supported"); + } case null: - { - throw new MqttCommunicationException("Connection closed."); - } + { + throw new MqttCommunicationException("Connection closed."); + } default: - { - throw new InvalidOperationException($"Received an unexpected MQTT packet ({receivedPacket})."); - } + { + throw new InvalidOperationException($"Received an unexpected MQTT packet ({receivedPacket})."); + } } } catch (Exception exception) @@ -513,7 +513,9 @@ async Task ConnectInternal(IMqttChannelAdapter channelA var connectResult = await Authenticate(channelAdapter, Options, effectiveCancellationToken.Token).ConfigureAwait(false); if (connectResult.ResultCode == MqttClientConnectResultCode.Success) { - _publishPacketReceiverTask = Task.Run(() => ProcessReceivedPublishPackets(backgroundCancellationToken), backgroundCancellationToken); + _publishPacketReceiverTask = Options.ReceivedApplicationMessageQueueable + ? Task.Run(() => ProcessReceivedPublishPackets(backgroundCancellationToken), backgroundCancellationToken) + : Task.CompletedTask; _packetReceiverTask = Task.Run(() => ReceivePacketsLoop(backgroundCancellationToken), backgroundCancellationToken); } @@ -700,6 +702,28 @@ async Task ProcessReceivedPublishPackets(CancellationToken cancellationToken) } } + async Task ProcessReceivedPublishPacket(MqttPublishPacket publishPacket, CancellationToken cancellationToken) + { + try + { + var eventArgs = await HandleReceivedApplicationMessage(publishPacket).ConfigureAwait(false); + if (eventArgs.AutoAcknowledge) + { + await eventArgs.AcknowledgeAsync(cancellationToken).ConfigureAwait(false); + } + } + catch (ObjectDisposedException) + { + } + catch (OperationCanceledException) + { + } + catch (Exception exception) + { + _logger.Error(exception, "Error while handling application message"); + } + } + Task ProcessReceivedPubRecPacket(MqttPubRecPacket pubRecPacket, CancellationToken cancellationToken) { if (_packetDispatcher.TryDispatch(pubRecPacket)) @@ -947,7 +971,15 @@ async Task TryProcessReceivedPacket(MqttPacket packet, CancellationToken cancell switch (packet) { case MqttPublishPacket publishPacket: - EnqueueReceivedPublishPacket(publishPacket); + if (Options.ReceivedApplicationMessageQueueable) + { + // publishPacket must be copied + EnqueueReceivedPublishPacket(publishPacket.Clone()); + } + else + { + await ProcessReceivedPublishPacket(publishPacket, cancellationToken).ConfigureAwait(false); + } break; case MqttPubRecPacket pubRecPacket: await ProcessReceivedPubRecPacket(pubRecPacket, cancellationToken).ConfigureAwait(false); @@ -967,14 +999,14 @@ async Task TryProcessReceivedPacket(MqttPacket packet, CancellationToken cancell case MqttPingReqPacket _: throw new MqttProtocolViolationException("The PINGREQ Packet is sent from a client to the server only."); default: - { - if (!_packetDispatcher.TryDispatch(packet)) { - throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); - } + if (!_packetDispatcher.TryDispatch(packet)) + { + throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); + } - break; - } + break; + } } } catch (Exception exception) diff --git a/Source/MQTTnet/Options/MqttClientOptions.cs b/Source/MQTTnet/Options/MqttClientOptions.cs index d337ea0e8..04d162b81 100644 --- a/Source/MQTTnet/Options/MqttClientOptions.cs +++ b/Source/MQTTnet/Options/MqttClientOptions.cs @@ -225,4 +225,9 @@ public sealed class MqttClientOptions /// Do not change this value when no memory issues are experienced. /// public int WriterBufferSizeMax { get; set; } = 65535; + + /// + /// When enabled, received ApplicationMessage will be deep cloned and enqueued. + /// + public bool ReceivedApplicationMessageQueueable { get; set; } = true; } \ No newline at end of file diff --git a/Source/MQTTnet/Options/MqttClientOptionsBuilder.cs b/Source/MQTTnet/Options/MqttClientOptionsBuilder.cs index a4dedff97..8faa8e068 100644 --- a/Source/MQTTnet/Options/MqttClientOptionsBuilder.cs +++ b/Source/MQTTnet/Options/MqttClientOptionsBuilder.cs @@ -482,4 +482,10 @@ public MqttClientOptionsBuilder WithWillUserProperty(string name, string value) _options.WillUserProperties.Add(new MqttUserProperty(name, value)); return this; } + + public MqttClientOptionsBuilder WithReceivedApplicationMessageQueueable(bool value) + { + _options.ReceivedApplicationMessageQueueable = value; + return this; + } } \ No newline at end of file diff --git a/Source/MQTTnet/Packets/MqttPublishPacket.cs b/Source/MQTTnet/Packets/MqttPublishPacket.cs index 9edc789e2..2d9261c9a 100644 --- a/Source/MQTTnet/Packets/MqttPublishPacket.cs +++ b/Source/MQTTnet/Packets/MqttPublishPacket.cs @@ -5,6 +5,7 @@ using System; using System.Buffers; using System.Collections.Generic; +using System.Linq; using MQTTnet.Protocol; namespace MQTTnet.Packets; @@ -39,6 +40,31 @@ public sealed class MqttPublishPacket : MqttPacketWithIdentifier public List UserProperties { get; set; } + /// + /// Deep clone all fields. + /// + /// + public MqttPublishPacket Clone() + { + return new MqttPublishPacket + { + PacketIdentifier = PacketIdentifier, + ContentType = ContentType, + CorrelationData = CorrelationData == default ? default : CorrelationData.ToArray(), + Dup = Dup, + MessageExpiryInterval = MessageExpiryInterval, + Payload = Payload.IsEmpty ? default : new ReadOnlySequence(Payload.ToArray()), + PayloadFormatIndicator = PayloadFormatIndicator, + QualityOfServiceLevel = QualityOfServiceLevel, + Retain = Retain, + ResponseTopic = ResponseTopic, + Topic = Topic, + UserProperties = UserProperties?.Select(u => u.Clone()).ToList(), + SubscriptionIdentifiers = SubscriptionIdentifiers?.ToList(), + TopicAlias = TopicAlias + }; + } + public override string ToString() { return diff --git a/Source/MQTTnet/Packets/MqttUserProperty.cs b/Source/MQTTnet/Packets/MqttUserProperty.cs index 7370707b4..242077415 100644 --- a/Source/MQTTnet/Packets/MqttUserProperty.cs +++ b/Source/MQTTnet/Packets/MqttUserProperty.cs @@ -48,6 +48,10 @@ public override string ToString() return $"{Name} = {Value}"; } + /// + /// Deep clone all fields. + /// + /// public MqttUserProperty Clone() { return new MqttUserProperty(Name, Value);