Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tgracchus committed Oct 11, 2023
1 parent b5c166c commit 3db335d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 85 deletions.
122 changes: 49 additions & 73 deletions src/main/java/com/hivemq/codec/decoder/MQTTMessageDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.util.List;

import static com.hivemq.mqtt.message.MessageType.CONNECT;
import static com.hivemq.mqtt.message.MessageType.PUBLISH;

/**
* @author Dominik Obermaier
Expand Down Expand Up @@ -123,30 +124,48 @@ protected void decode(
final int packetSize = fixedHeaderSize + remainingLength;

final MessageType messageType = getMessageType(fixedHeader);
final @Nullable Message message;

if (messageType == CONNECT) {
message = handleConnect(buf, clientConnectionContext, fixedHeader, packetSize, remainingLength);
} else {
message =
handledMessage(buf, clientConnectionContext, fixedHeader, messageType, packetSize, remainingLength);
}
final Message message = messageType == CONNECT ?
handleConnect(buf, clientConnectionContext, fixedHeader, packetSize, remainingLength) :
handleMessage(buf, clientConnectionContext, fixedHeader, messageType, packetSize, remainingLength);
if (message == null) {
buf.clear();
return;
}
globalMQTTMessageCounter.countInbound(message);
out.add(message);
}

private Message handleConnect(
private @Nullable Message handleConnect(
final @NotNull ByteBuf buf,
final @NotNull ClientConnectionContext clientConnectionContext,
final byte fixedHeader,
final int packetSize,
final int remainingLength) {

if (isAlreadyConnected(clientConnectionContext)) {
//this is the message size HiveMQ allows for incoming messages
if (packetSize > mqttConfig.maxPacketSize()) {
//connack with PACKET_TOO_LARGE for Mqtt5
final ProtocolVersion protocolVersion = connectDecoder.decodeProtocolVersion(clientConnectionContext, buf);
if (protocolVersion == ProtocolVersion.MQTTv5) {
mqttConnacker.connackError(clientConnectionContext.getChannel(),
"A client (IP: {}) connect packet exceeded the maximum permissible size.",
"Sent CONNECT exceeded the maximum permissible size",
Mqtt5ConnAckReasonCode.PACKET_TOO_LARGE,
ReasonStrings.CONNACK_PACKET_TOO_LARGE);
} else { //force channel close for Mqtt3.1 and Mqtt3.1.1
mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(),
"A client (IP: {}) sent a message, that was bigger than the maximum message size. Disconnecting client.",
"Sent a message that was bigger than the maximum size",
Mqtt5DisconnectReasonCode.PACKET_TOO_LARGE,
ReasonStrings.DISCONNECT_PACKET_TOO_LARGE_MESSAGE,
Mqtt5UserProperties.NO_USER_PROPERTIES,
false,
true);
}
return null;
}

// Check if the client is already connected
if (clientConnectionContext.getProtocolVersion() != null) {
mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(),
"A client (IP: {}) sent second CONNECT message. This is not allowed. Disconnecting client.",
"Sent second CONNECT message",
Expand All @@ -160,54 +179,15 @@ private Message handleConnect(
return null;
}

final ByteBuf messageBuffer = readRestOfMessage(buf, remainingLength);

final @Nullable ProtocolVersion protocolVersion =
connectDecoder.decodeProtocolVersion(clientConnectionContext, messageBuffer);
if (protocolVersion == null) {
return null;
}

//this is the message size HiveMQ allows for incoming messages
if (packetSize > mqttConfig.maxPacketSize()) {
//force channel close for Mqtt3.1, Mqtt3.1.1 and null (before connect)
if (protocolVersion == ProtocolVersion.MQTTv5) {
connackPacketTooLarge(clientConnectionContext);
} else {
disconnectPacketTooLarge(clientConnectionContext, true);
}
return null;
}
//We're slicing the buffer to the exact MQTT message size so we don't have to pass the actual length around
final ByteBuf messageBuffer = buf.readSlice(remainingLength);
//We mark the end of the message
buf.markReaderIndex();

globalMQTTMessageCounter.countInboundTraffic(packetSize);
return connectDecoder.decode(clientConnectionContext, messageBuffer, fixedHeader);
}

private void disconnectPacketTooLarge(ClientConnectionContext clientConnectionContext, boolean forceClose) {
mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(),
"A client (IP: {}) sent a message, that was bigger than the maximum message size. Disconnecting client.",
"Sent a message that was bigger than the maximum size",
Mqtt5DisconnectReasonCode.PACKET_TOO_LARGE,
ReasonStrings.DISCONNECT_PACKET_TOO_LARGE_MESSAGE,
Mqtt5UserProperties.NO_USER_PROPERTIES,
false,
forceClose);
}


private void connackPacketTooLarge(final @NotNull ClientConnectionContext clientConnectionContext) {
mqttConnacker.connackError(clientConnectionContext.getChannel(),
"A client (IP: {}) connect packet exceeded the maximum permissible size.",
"Sent CONNECT exceeded the maximum permissible size",
Mqtt5ConnAckReasonCode.PACKET_TOO_LARGE,
ReasonStrings.CONNACK_PACKET_TOO_LARGE);
}

private static boolean isAlreadyConnected(ClientConnectionContext clientConnectionContext) {
return clientConnectionContext.getProtocolVersion() != null;
}

private Message handledMessage(
private @Nullable Message handleMessage(
final @NotNull ByteBuf buf,
final @NotNull ClientConnectionContext clientConnectionContext,
final byte fixedHeader,
Expand All @@ -222,20 +202,29 @@ private Message handledMessage(

//force channel close for Mqtt3.1, Mqtt3.1.1 and null (before connect)
final boolean forceClose = protocolVersion != ProtocolVersion.MQTTv5;
disconnectPacketTooLarge(clientConnectionContext, forceClose);
mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(),
"A client (IP: {}) sent a message, that was bigger than the maximum message size. Disconnecting client.",
"Sent a message that was bigger than the maximum size",
Mqtt5DisconnectReasonCode.PACKET_TOO_LARGE,
ReasonStrings.DISCONNECT_PACKET_TOO_LARGE_MESSAGE,
Mqtt5UserProperties.NO_USER_PROPERTIES,
false,
forceClose);
return null;
}

if (isNotConnected(protocolVersion)) {
// Check if client is connected
if (protocolVersion == null) {
mqttServerDisconnector.logAndClose(clientConnectionContext.getChannel(),
"A client (IP: {}) sent other message before CONNECT. Disconnecting client.",
"Sent other message before CONNECT");
return null;
}

globalMQTTMessageCounter.countInboundTraffic(packetSize);

final ByteBuf messageBuffer = readRestOfMessage(buf, remainingLength);
//We're slicing the buffer to the exact MQTT message size so we don't have to pass the actual length around
final ByteBuf messageBuffer = buf.readSlice(remainingLength);
//We mark the end of the message
buf.markReaderIndex();

final MqttDecoder<?> decoder = mqttDecoders.decoder(messageType, protocolVersion);
if (decoder != null) {
Expand Down Expand Up @@ -269,7 +258,6 @@ private Message handledMessage(
"Sent a UNSUBACK message",
Mqtt5DisconnectReasonCode.PROTOCOL_ERROR,
ReasonStrings.DISCONNECT_UNSUBACK_RECEIVED);
;
return null;
case PINGRESP:
mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(),
Expand All @@ -296,18 +284,6 @@ private Message handledMessage(
}
}

private static boolean isNotConnected(ProtocolVersion protocolVersion) {
return protocolVersion == null;
}

private static ByteBuf readRestOfMessage(ByteBuf buf, int remainingLength) {
//We're slicing the buffer to the exact MQTT message size so we don't have to pass the actual length around
final ByteBuf messageBuffer = buf.readSlice(remainingLength);
//We mark the end of the message
buf.markReaderIndex();
return messageBuffer;
}

private static int getFixedHeaderSize(final int remainingLength) {

// 2 = 1 byte fixed header + 1 byte first byte of remaining length
Expand Down
12 changes: 8 additions & 4 deletions src/main/java/com/hivemq/codec/decoder/MqttConnectDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ public MqttConnectDecoder(


public @Nullable ProtocolVersion decodeProtocolVersion(
final @NotNull ClientConnectionContext clientConnectionContext,
final @NotNull ByteBuf buf) {
final @NotNull ClientConnectionContext clientConnectionContext, final @NotNull ByteBuf buf) {
/*
* It is sufficient to look at the second byte of the variable header (Length LSB) This byte
* indicates how long the following protocol name is going to be. In case of the
Expand All @@ -77,7 +76,7 @@ public MqttConnectDecoder(
// interested in the Length LSB byte
if (buf.readableBytes() < 2) {
mqttConnacker.connackError(clientConnectionContext.getChannel(),
"A client (IP: {}) connected with a packet without protocol version.",
"A client (ID: {}, IP: {}) connected with a packet without protocol version.",
"Sent CONNECT without protocol version",
Mqtt5ConnAckReasonCode.UNSUPPORTED_PROTOCOL_VERSION,
ReasonStrings.CONNACK_UNSUPPORTED_PROTOCOL_VERSION);
Expand Down Expand Up @@ -114,6 +113,7 @@ public MqttConnectDecoder(

clientConnectionContext.setProtocolVersion(protocolVersion);
clientConnectionContext.setConnectReceivedTimestamp(System.currentTimeMillis());

return protocolVersion;
}

Expand All @@ -122,7 +122,11 @@ public MqttConnectDecoder(
final @NotNull ClientConnectionContext clientConnectionContext,
final @NotNull ByteBuf buf,
final byte fixedHeader) {
final ProtocolVersion protocolVersion = clientConnectionContext.getProtocolVersion();

final ProtocolVersion protocolVersion = decodeProtocolVersion(clientConnectionContext, buf);
if (protocolVersion == null) {
return null;
}
if (protocolVersion == ProtocolVersion.MQTTv5) {
return mqtt5ConnectDecoder.decode(clientConnectionContext, buf, fixedHeader);
} else if (protocolVersion == ProtocolVersion.MQTTv3_1_1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import static org.mockito.Mockito.verify;

public class MqttConnectDecoderTest {
private static final byte FIXED_HEADER = 0b0001_0000;
private @NotNull MqttConnacker mqttConnacker;
private @NotNull Channel channel;
private @NotNull MqttConnectDecoder decoder;
Expand All @@ -61,7 +62,7 @@ public void setUp() throws Exception {
@Test
public void test_no_protocol_version() {
final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{1});
decoder.decodeProtocolVersion(clientConnection, buf);
decoder.decode(clientConnection, buf, FIXED_HEADER);
verify(mqttConnacker).connackError(eq(channel),
anyString(),
anyString(),
Expand All @@ -72,7 +73,7 @@ public void test_no_protocol_version() {
@Test
public void test_invalid_protocol_version_not_enough_readable_bytes() {
final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 1, 2, 3, 4});
decoder.decodeProtocolVersion(clientConnection, buf);
decoder.decode(clientConnection, buf, FIXED_HEADER);
verify(mqttConnacker).connackError(eq(channel),
anyString(),
anyString(),
Expand All @@ -84,7 +85,7 @@ public void test_invalid_protocol_version_not_enough_readable_bytes() {
public void test_valid_mqtt5_version() {
final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 'M', 'Q', 'T', 'T', 5});
try {
decoder.decodeProtocolVersion(clientConnection, buf);
decoder.decode(clientConnection, buf, FIXED_HEADER);
} catch (final Exception e) {
//ignore because mqtt5ConnectDecoder not tested here
}
Expand All @@ -96,23 +97,23 @@ public void test_valid_mqtt5_version() {
@Test
public void test_valid_mqtt3_1_1_version() {
final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 'M', 'Q', 'T', 'T', 4});
decoder.decodeProtocolVersion(clientConnection, buf);
decoder.decode(clientConnection, buf, FIXED_HEADER);
assertSame(ProtocolVersion.MQTTv3_1_1, clientConnection.getProtocolVersion());
assertNotNull(ClientConnection.of(channel).getConnectReceivedTimestamp());
}

@Test
public void test_valid_mqtt3_1_version() {
final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 6, 'M', 'Q', 'T', 'T', 3, 1});
decoder.decodeProtocolVersion(clientConnection, buf);
decoder.decode(clientConnection, buf, FIXED_HEADER);
assertSame(ProtocolVersion.MQTTv3_1, clientConnection.getProtocolVersion());
assertNotNull(ClientConnection.of(channel).getConnectReceivedTimestamp());
}

@Test
public void test_invalid_protocol_version_mqtt_5() {
final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 5});
decoder.decodeProtocolVersion(clientConnection, buf);
decoder.decode(clientConnection, buf, FIXED_HEADER);
verify(mqttConnacker).connackError(eq(channel),
anyString(),
anyString(),
Expand All @@ -123,7 +124,7 @@ public void test_invalid_protocol_version_mqtt_5() {
@Test
public void test_invalid_protocol_version_7() {
final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 'M', 'Q', 'T', 'T', 7});
decoder.decodeProtocolVersion(clientConnection, buf);
decoder.decode(clientConnection, buf, FIXED_HEADER);
verify(mqttConnacker).connackError(eq(channel),
anyString(),
anyString(),
Expand All @@ -134,7 +135,7 @@ public void test_invalid_protocol_version_7() {
@Test
public void test_invalid_protocol_version_length() {
final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 5, 'M', 'Q', 'T', 'T', 7});
decoder.decodeProtocolVersion(clientConnection, buf);
decoder.decode(clientConnection, buf, FIXED_HEADER);
verify(mqttConnacker).connackError(eq(channel),
anyString(),
anyString(),
Expand Down

0 comments on commit 3db335d

Please sign in to comment.