Skip to content

Commit

Permalink
Fixed client not copying PublishPacket before enqueuing it.
Browse files Browse the repository at this point in the history
  • Loading branch information
xljiulang committed Dec 5, 2024
1 parent 64a5c5b commit 9ce7ee3
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 121 deletions.
80 changes: 31 additions & 49 deletions Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,11 @@ public async Task<DispatchApplicationMessageResult> DispatchApplicationMessage(
var matchingSubscribersCount = 0;
try
{
var clonedMessage = CloneApplicationMessage(applicationMessage);
var applicationMessageCopy = new Lazy<MqttApplicationMessage>(() => applicationMessage.Clone());
if (applicationMessage.Retain)
{
await _retainedMessagesManager.UpdateMessage(senderId, clonedMessage.Value).ConfigureAwait(false);
// applicationMessage must be copied
await _retainedMessagesManager.UpdateMessage(senderId, applicationMessage, applicationMessageCopy).ConfigureAwait(false);
}

List<MqttSession> subscriberSessions;
Expand Down Expand Up @@ -202,7 +203,8 @@ public async Task<DispatchApplicationMessageResult> DispatchApplicationMessage(
}
}

var publishPacketCopy = MqttPublishPacketFactory.Create(clonedMessage.Value);
// applicationMessage must be copied
var publishPacketCopy = MqttPublishPacketFactory.Create(applicationMessageCopy.Value);
publishPacketCopy.QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel;
publishPacketCopy.SubscriptionIdentifiers = checkSubscriptionsResult.SubscriptionIdentifiers;

Expand Down Expand Up @@ -508,26 +510,6 @@ public Task UnsubscribeAsync(string clientId, ICollection<string> topicFilters)
return GetClientSession(clientId).Unsubscribe(fakeUnsubscribePacket, CancellationToken.None);
}

static Lazy<MqttApplicationMessage> CloneApplicationMessage(MqttApplicationMessage m)
{
return new Lazy<MqttApplicationMessage>(
() => new MqttApplicationMessage
{
ContentType = m.ContentType,
CorrelationData = m.CorrelationData?.ToArray(),
Dup = m.Dup,
MessageExpiryInterval = m.MessageExpiryInterval,
Payload = m.Payload.IsEmpty ? default : new ReadOnlySequence<byte>(m.Payload.ToArray()),
PayloadFormatIndicator = m.PayloadFormatIndicator,
QualityOfServiceLevel = m.QualityOfServiceLevel,
Retain = m.Retain,
ResponseTopic = m.ResponseTopic,
Topic = m.Topic,
UserProperties = m.UserProperties?.Select(u => u.Clone()).ToList(),
SubscriptionIdentifiers = m.SubscriptionIdentifiers?.ToList(),
TopicAlias = m.TopicAlias
});
}

MqttConnectedClient CreateClient(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter, MqttSession session)
{
Expand Down Expand Up @@ -699,39 +681,39 @@ static bool ShouldPersistSession(MqttConnectedClient connectedClient)
switch (connectedClient.ChannelAdapter.PacketFormatterAdapter.ProtocolVersion)
{
case MqttProtocolVersion.V500:
{
// MQTT 5.0 section 3.1.2.11.2
// The Client and Server MUST store the Session State after the Network Connection is closed if the Session Expiry Interval is greater than 0 [MQTT-3.1.2-23].
//
// A Client that only wants to process messages while connected will set the Clean Start to 1 and set the Session Expiry Interval to 0.
// It will not receive Application Messages published before it connected and has to subscribe afresh to any topics that it is interested
// in each time it connects.

var effectiveSessionExpiryInterval = connectedClient.DisconnectPacket?.SessionExpiryInterval ?? 0U;
if (effectiveSessionExpiryInterval == 0U)
{
// From RFC: If the Session Expiry Interval is absent, the Session Expiry Interval in the CONNECT packet is used.
effectiveSessionExpiryInterval = connectedClient.ConnectPacket.SessionExpiryInterval;
}
// MQTT 5.0 section 3.1.2.11.2
// The Client and Server MUST store the Session State after the Network Connection is closed if the Session Expiry Interval is greater than 0 [MQTT-3.1.2-23].
//
// A Client that only wants to process messages while connected will set the Clean Start to 1 and set the Session Expiry Interval to 0.
// It will not receive Application Messages published before it connected and has to subscribe afresh to any topics that it is interested
// in each time it connects.

var effectiveSessionExpiryInterval = connectedClient.DisconnectPacket?.SessionExpiryInterval ?? 0U;
if (effectiveSessionExpiryInterval == 0U)
{
// From RFC: If the Session Expiry Interval is absent, the Session Expiry Interval in the CONNECT packet is used.
effectiveSessionExpiryInterval = connectedClient.ConnectPacket.SessionExpiryInterval;
}

return effectiveSessionExpiryInterval != 0U;
}
return effectiveSessionExpiryInterval != 0U;
}

case MqttProtocolVersion.V311:
{
// MQTT 3.1.1 section 3.1.2.4: persist only if 'not CleanSession'
//
// If CleanSession is set to 1, the Client and Server MUST discard any previous Session and start a new one.
// This Session lasts as long as the Network Connection. State data associated with this Session MUST NOT be
// reused in any subsequent Session [MQTT-3.1.2-6].
{
// MQTT 3.1.1 section 3.1.2.4: persist only if 'not CleanSession'
//
// If CleanSession is set to 1, the Client and Server MUST discard any previous Session and start a new one.
// This Session lasts as long as the Network Connection. State data associated with this Session MUST NOT be
// reused in any subsequent Session [MQTT-3.1.2-6].

return !connectedClient.ConnectPacket.CleanSession;
}
return !connectedClient.ConnectPacket.CleanSession;
}

case MqttProtocolVersion.V310:
{
return true;
}
{
return true;
}

default:
throw new NotSupportedException();
Expand Down
15 changes: 6 additions & 9 deletions Source/MQTTnet.Server/Internal/MqttRetainedMessagesManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public async Task Start()
}
}

public async Task UpdateMessage(string clientId, MqttApplicationMessage applicationMessage)
public async Task UpdateMessage(string clientId, MqttApplicationMessage applicationMessage, Lazy<MqttApplicationMessage> applicationMessageCopy)
{
ArgumentNullException.ThrowIfNull(applicationMessage);

Expand All @@ -61,10 +61,7 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat

lock (_messages)
{
var payload = applicationMessage.Payload;
var hasPayload = payload.Length > 0;

if (!hasPayload)
if (applicationMessage.Payload.IsEmpty)
{
saveIsRequired = _messages.Remove(applicationMessage.Topic);
_logger.Verbose("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic);
Expand All @@ -73,15 +70,15 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat
{
if (!_messages.TryGetValue(applicationMessage.Topic, out var existingMessage))
{
_messages[applicationMessage.Topic] = applicationMessage;
_messages[applicationMessage.Topic] = applicationMessageCopy.Value;
saveIsRequired = true;
}
else
{
if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel ||
!MqttMemoryHelper.SequenceEqual(existingMessage.Payload, payload))
!MqttMemoryHelper.SequenceEqual(existingMessage.Payload, applicationMessage.Payload))
{
_messages[applicationMessage.Topic] = applicationMessage;
_messages[applicationMessage.Topic] = applicationMessageCopy.Value;
saveIsRequired = true;
}
}
Expand All @@ -99,7 +96,7 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat
{
using (await _storageAccessLock.EnterAsync().ConfigureAwait(false))
{
var eventArgs = new RetainedMessageChangedEventArgs(clientId, applicationMessage, messagesForSave);
var eventArgs = new RetainedMessageChangedEventArgs(clientId, applicationMessageCopy.Value, messagesForSave);
await _eventContainer.RetainedMessageChangedEvent.InvokeAsync(eventArgs).ConfigureAwait(false);
}
}
Expand Down
2 changes: 1 addition & 1 deletion Source/MQTTnet.Server/MqttServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ public Task UpdateRetainedMessageAsync(MqttApplicationMessage retainedMessage)
ThrowIfDisposed();
ThrowIfNotStarted();

return _retainedMessagesManager?.UpdateMessage(string.Empty, retainedMessage);
return _retainedMessagesManager?.UpdateMessage(string.Empty, retainedMessage, new Lazy<MqttApplicationMessage>(retainedMessage));
}

protected override void Dispose(bool disposing)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
namespace MQTTnet.Tests.Mockups
{
public record MqttApplicationMessageReceived(string ClientId, MqttApplicationMessage ApplicationMessage);
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ namespace MQTTnet.Tests.Mockups
{
public sealed class TestApplicationMessageReceivedHandler
{
readonly List<MqttApplicationMessageReceivedEventArgs> _receivedEventArgs = new List<MqttApplicationMessageReceivedEventArgs>();
readonly IMqttClient _mqttClient;
readonly List<MqttApplicationMessageReceived> _receivedEventArgs = new();


public TestApplicationMessageReceivedHandler(IMqttClient mqttClient)
{
ArgumentNullException.ThrowIfNull(mqttClient);

mqttClient.ApplicationMessageReceivedAsync += OnApplicationMessageReceivedAsync;
_mqttClient = mqttClient;
}

public int Count
Expand All @@ -34,7 +37,7 @@ public int Count
}
}

public List<MqttApplicationMessageReceivedEventArgs> ReceivedEventArgs
public List<MqttApplicationMessageReceived> ReceivedEventArgs
{
get
{
Expand Down Expand Up @@ -72,7 +75,11 @@ Task OnApplicationMessageReceivedAsync(MqttApplicationMessageReceivedEventArgs e
{
lock (_receivedEventArgs)
{
_receivedEventArgs.Add(eventArgs);
var applicationMessage = _mqttClient.Options.ReceivedApplicationMessageQueueable
? eventArgs.ApplicationMessage
: eventArgs.ApplicationMessage.Clone();

_receivedEventArgs.Add(new MqttApplicationMessageReceived(eventArgs.ClientId, applicationMessage));
}

return CompletedTask.Instance;
Expand Down
13 changes: 12 additions & 1 deletion Source/MQTTnet.Tests/Server/QoS_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,24 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_2()

[TestMethod]
public async Task Preserve_Message_Order_For_Queued_Messages()
{
await Preserve_Message_Order_For_Queued_Messages(receivedPublishPacketQueueable: true);
}

[TestMethod]
public async Task Preserve_Message_Order_For_Queued_Messages_NoQueue()
{
await Preserve_Message_Order_For_Queued_Messages(receivedPublishPacketQueueable: false);
}

private async Task Preserve_Message_Order_For_Queued_Messages(bool receivedPublishPacketQueueable)
{
using (var testEnvironment = CreateTestEnvironment())
{
var server = await testEnvironment.StartServer(o => o.WithPersistentSessions());

// Create a session which will contain the messages.
var dummyClient = await testEnvironment.ConnectClient(o => o.WithClientId("A").WithCleanSession(false));
var dummyClient = await testEnvironment.ConnectClient(o => o.WithClientId("A").WithCleanSession(false).WithReceivedApplicationMessageQueueable(receivedPublishPacketQueueable));
await dummyClient.SubscribeAsync("#", MqttQualityOfServiceLevel.AtLeastOnce);
dummyClient.Dispose();

Expand Down
25 changes: 25 additions & 0 deletions Source/MQTTnet/MqttApplicationMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;

namespace MQTTnet
{
Expand Down Expand Up @@ -142,5 +143,29 @@ public ArraySegment<byte> PayloadSegment
/// Hint: MQTT 5 feature only.
/// </summary>
public List<MqttUserProperty> UserProperties { get; set; }

/// <summary>
/// Deep clone all fields.
/// </summary>
/// <returns></returns>
public MqttApplicationMessage Clone()
{
return new MqttApplicationMessage
{
ContentType = ContentType,
CorrelationData = CorrelationData == default ? default : CorrelationData.ToArray(),
Dup = Dup,
MessageExpiryInterval = MessageExpiryInterval,
Payload = Payload.IsEmpty ? default : new ReadOnlySequence<byte>(Payload.ToArray()),
PayloadFormatIndicator = PayloadFormatIndicator,
QualityOfServiceLevel = QualityOfServiceLevel,
Retain = Retain,
ResponseTopic = ResponseTopic,
Topic = Topic,
UserProperties = UserProperties?.Select(u => u.Clone()).ToList(),
SubscriptionIdentifiers = SubscriptionIdentifiers?.ToList(),
TopicAlias = TopicAlias
};
}
}
}
Loading

0 comments on commit 9ce7ee3

Please sign in to comment.