diff --git a/src/main/java/com/hivemq/codec/decoder/MQTTMessageDecoder.java b/src/main/java/com/hivemq/codec/decoder/MQTTMessageDecoder.java index 292657849..c55e16ef5 100644 --- a/src/main/java/com/hivemq/codec/decoder/MQTTMessageDecoder.java +++ b/src/main/java/com/hivemq/codec/decoder/MQTTMessageDecoder.java @@ -37,6 +37,7 @@ import java.util.List; import static com.hivemq.mqtt.message.MessageType.CONNECT; +import static com.hivemq.mqtt.message.MessageType.PUBLISH; /** * @author Dominik Obermaier @@ -123,30 +124,48 @@ protected void decode( final int packetSize = fixedHeaderSize + remainingLength; final MessageType messageType = getMessageType(fixedHeader); - final @Nullable Message message; - - if (messageType == CONNECT) { - message = handleConnect(buf, clientConnectionContext, fixedHeader, packetSize, remainingLength); - } else { - message = - handledMessage(buf, clientConnectionContext, fixedHeader, messageType, packetSize, remainingLength); - } + final Message message = messageType == CONNECT ? + handleConnect(buf, clientConnectionContext, fixedHeader, packetSize, remainingLength) : + handleMessage(buf, clientConnectionContext, fixedHeader, messageType, packetSize, remainingLength); if (message == null) { buf.clear(); return; } - globalMQTTMessageCounter.countInbound(message); out.add(message); } - private Message handleConnect( + private @Nullable Message handleConnect( final @NotNull ByteBuf buf, final @NotNull ClientConnectionContext clientConnectionContext, final byte fixedHeader, final int packetSize, final int remainingLength) { - if (isAlreadyConnected(clientConnectionContext)) { + //this is the message size HiveMQ allows for incoming messages + if (packetSize > mqttConfig.maxPacketSize()) { + //connack with PACKET_TOO_LARGE for Mqtt5 + final ProtocolVersion protocolVersion = connectDecoder.decodeProtocolVersion(clientConnectionContext, buf); + if (protocolVersion == ProtocolVersion.MQTTv5) { + mqttConnacker.connackError(clientConnectionContext.getChannel(), + "A client (IP: {}) connect packet exceeded the maximum permissible size.", + "Sent CONNECT exceeded the maximum permissible size", + Mqtt5ConnAckReasonCode.PACKET_TOO_LARGE, + ReasonStrings.CONNACK_PACKET_TOO_LARGE); + } else { //force channel close for Mqtt3.1 and Mqtt3.1.1 + mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(), + "A client (IP: {}) sent a message, that was bigger than the maximum message size. Disconnecting client.", + "Sent a message that was bigger than the maximum size", + Mqtt5DisconnectReasonCode.PACKET_TOO_LARGE, + ReasonStrings.DISCONNECT_PACKET_TOO_LARGE_MESSAGE, + Mqtt5UserProperties.NO_USER_PROPERTIES, + false, + true); + } + return null; + } + + // Check if the client is already connected + if (clientConnectionContext.getProtocolVersion() != null) { mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(), "A client (IP: {}) sent second CONNECT message. This is not allowed. Disconnecting client.", "Sent second CONNECT message", @@ -160,54 +179,15 @@ private Message handleConnect( return null; } - final ByteBuf messageBuffer = readRestOfMessage(buf, remainingLength); - - final @Nullable ProtocolVersion protocolVersion = - connectDecoder.decodeProtocolVersion(clientConnectionContext, messageBuffer); - if (protocolVersion == null) { - return null; - } - - //this is the message size HiveMQ allows for incoming messages - if (packetSize > mqttConfig.maxPacketSize()) { - //force channel close for Mqtt3.1, Mqtt3.1.1 and null (before connect) - if (protocolVersion == ProtocolVersion.MQTTv5) { - connackPacketTooLarge(clientConnectionContext); - } else { - disconnectPacketTooLarge(clientConnectionContext, true); - } - return null; - } + //We're slicing the buffer to the exact MQTT message size so we don't have to pass the actual length around + final ByteBuf messageBuffer = buf.readSlice(remainingLength); + //We mark the end of the message + buf.markReaderIndex(); - globalMQTTMessageCounter.countInboundTraffic(packetSize); return connectDecoder.decode(clientConnectionContext, messageBuffer, fixedHeader); } - private void disconnectPacketTooLarge(ClientConnectionContext clientConnectionContext, boolean forceClose) { - mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(), - "A client (IP: {}) sent a message, that was bigger than the maximum message size. Disconnecting client.", - "Sent a message that was bigger than the maximum size", - Mqtt5DisconnectReasonCode.PACKET_TOO_LARGE, - ReasonStrings.DISCONNECT_PACKET_TOO_LARGE_MESSAGE, - Mqtt5UserProperties.NO_USER_PROPERTIES, - false, - forceClose); - } - - - private void connackPacketTooLarge(final @NotNull ClientConnectionContext clientConnectionContext) { - mqttConnacker.connackError(clientConnectionContext.getChannel(), - "A client (IP: {}) connect packet exceeded the maximum permissible size.", - "Sent CONNECT exceeded the maximum permissible size", - Mqtt5ConnAckReasonCode.PACKET_TOO_LARGE, - ReasonStrings.CONNACK_PACKET_TOO_LARGE); - } - - private static boolean isAlreadyConnected(ClientConnectionContext clientConnectionContext) { - return clientConnectionContext.getProtocolVersion() != null; - } - - private Message handledMessage( + private @Nullable Message handleMessage( final @NotNull ByteBuf buf, final @NotNull ClientConnectionContext clientConnectionContext, final byte fixedHeader, @@ -222,20 +202,29 @@ private Message handledMessage( //force channel close for Mqtt3.1, Mqtt3.1.1 and null (before connect) final boolean forceClose = protocolVersion != ProtocolVersion.MQTTv5; - disconnectPacketTooLarge(clientConnectionContext, forceClose); + mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(), + "A client (IP: {}) sent a message, that was bigger than the maximum message size. Disconnecting client.", + "Sent a message that was bigger than the maximum size", + Mqtt5DisconnectReasonCode.PACKET_TOO_LARGE, + ReasonStrings.DISCONNECT_PACKET_TOO_LARGE_MESSAGE, + Mqtt5UserProperties.NO_USER_PROPERTIES, + false, + forceClose); return null; } - if (isNotConnected(protocolVersion)) { + // Check if client is connected + if (protocolVersion == null) { mqttServerDisconnector.logAndClose(clientConnectionContext.getChannel(), "A client (IP: {}) sent other message before CONNECT. Disconnecting client.", "Sent other message before CONNECT"); return null; } - globalMQTTMessageCounter.countInboundTraffic(packetSize); - - final ByteBuf messageBuffer = readRestOfMessage(buf, remainingLength); + //We're slicing the buffer to the exact MQTT message size so we don't have to pass the actual length around + final ByteBuf messageBuffer = buf.readSlice(remainingLength); + //We mark the end of the message + buf.markReaderIndex(); final MqttDecoder decoder = mqttDecoders.decoder(messageType, protocolVersion); if (decoder != null) { @@ -269,7 +258,6 @@ private Message handledMessage( "Sent a UNSUBACK message", Mqtt5DisconnectReasonCode.PROTOCOL_ERROR, ReasonStrings.DISCONNECT_UNSUBACK_RECEIVED); - ; return null; case PINGRESP: mqttServerDisconnector.disconnect(clientConnectionContext.getChannel(), @@ -296,18 +284,6 @@ private Message handledMessage( } } - private static boolean isNotConnected(ProtocolVersion protocolVersion) { - return protocolVersion == null; - } - - private static ByteBuf readRestOfMessage(ByteBuf buf, int remainingLength) { - //We're slicing the buffer to the exact MQTT message size so we don't have to pass the actual length around - final ByteBuf messageBuffer = buf.readSlice(remainingLength); - //We mark the end of the message - buf.markReaderIndex(); - return messageBuffer; - } - private static int getFixedHeaderSize(final int remainingLength) { // 2 = 1 byte fixed header + 1 byte first byte of remaining length diff --git a/src/main/java/com/hivemq/codec/decoder/MqttConnectDecoder.java b/src/main/java/com/hivemq/codec/decoder/MqttConnectDecoder.java index 05b5a50c5..dc1871a2e 100644 --- a/src/main/java/com/hivemq/codec/decoder/MqttConnectDecoder.java +++ b/src/main/java/com/hivemq/codec/decoder/MqttConnectDecoder.java @@ -60,8 +60,7 @@ public MqttConnectDecoder( public @Nullable ProtocolVersion decodeProtocolVersion( - final @NotNull ClientConnectionContext clientConnectionContext, - final @NotNull ByteBuf buf) { + final @NotNull ClientConnectionContext clientConnectionContext, final @NotNull ByteBuf buf) { /* * It is sufficient to look at the second byte of the variable header (Length LSB) This byte * indicates how long the following protocol name is going to be. In case of the @@ -77,7 +76,7 @@ public MqttConnectDecoder( // interested in the Length LSB byte if (buf.readableBytes() < 2) { mqttConnacker.connackError(clientConnectionContext.getChannel(), - "A client (IP: {}) connected with a packet without protocol version.", + "A client (ID: {}, IP: {}) connected with a packet without protocol version.", "Sent CONNECT without protocol version", Mqtt5ConnAckReasonCode.UNSUPPORTED_PROTOCOL_VERSION, ReasonStrings.CONNACK_UNSUPPORTED_PROTOCOL_VERSION); @@ -114,6 +113,7 @@ public MqttConnectDecoder( clientConnectionContext.setProtocolVersion(protocolVersion); clientConnectionContext.setConnectReceivedTimestamp(System.currentTimeMillis()); + return protocolVersion; } @@ -122,7 +122,11 @@ public MqttConnectDecoder( final @NotNull ClientConnectionContext clientConnectionContext, final @NotNull ByteBuf buf, final byte fixedHeader) { - final ProtocolVersion protocolVersion = clientConnectionContext.getProtocolVersion(); + + final ProtocolVersion protocolVersion = decodeProtocolVersion(clientConnectionContext, buf); + if (protocolVersion == null) { + return null; + } if (protocolVersion == ProtocolVersion.MQTTv5) { return mqtt5ConnectDecoder.decode(clientConnectionContext, buf, fixedHeader); } else if (protocolVersion == ProtocolVersion.MQTTv3_1_1) { diff --git a/src/test/java/com/hivemq/codec/decoder/MqttConnectDecoderTest.java b/src/test/java/com/hivemq/codec/decoder/MqttConnectDecoderTest.java index 44616a752..6ed593f13 100644 --- a/src/test/java/com/hivemq/codec/decoder/MqttConnectDecoderTest.java +++ b/src/test/java/com/hivemq/codec/decoder/MqttConnectDecoderTest.java @@ -40,6 +40,7 @@ import static org.mockito.Mockito.verify; public class MqttConnectDecoderTest { + private static final byte FIXED_HEADER = 0b0001_0000; private @NotNull MqttConnacker mqttConnacker; private @NotNull Channel channel; private @NotNull MqttConnectDecoder decoder; @@ -61,7 +62,7 @@ public void setUp() throws Exception { @Test public void test_no_protocol_version() { final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{1}); - decoder.decodeProtocolVersion(clientConnection, buf); + decoder.decode(clientConnection, buf, FIXED_HEADER); verify(mqttConnacker).connackError(eq(channel), anyString(), anyString(), @@ -72,7 +73,7 @@ public void test_no_protocol_version() { @Test public void test_invalid_protocol_version_not_enough_readable_bytes() { final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 1, 2, 3, 4}); - decoder.decodeProtocolVersion(clientConnection, buf); + decoder.decode(clientConnection, buf, FIXED_HEADER); verify(mqttConnacker).connackError(eq(channel), anyString(), anyString(), @@ -84,7 +85,7 @@ public void test_invalid_protocol_version_not_enough_readable_bytes() { public void test_valid_mqtt5_version() { final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 'M', 'Q', 'T', 'T', 5}); try { - decoder.decodeProtocolVersion(clientConnection, buf); + decoder.decode(clientConnection, buf, FIXED_HEADER); } catch (final Exception e) { //ignore because mqtt5ConnectDecoder not tested here } @@ -96,7 +97,7 @@ public void test_valid_mqtt5_version() { @Test public void test_valid_mqtt3_1_1_version() { final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 'M', 'Q', 'T', 'T', 4}); - decoder.decodeProtocolVersion(clientConnection, buf); + decoder.decode(clientConnection, buf, FIXED_HEADER); assertSame(ProtocolVersion.MQTTv3_1_1, clientConnection.getProtocolVersion()); assertNotNull(ClientConnection.of(channel).getConnectReceivedTimestamp()); } @@ -104,7 +105,7 @@ public void test_valid_mqtt3_1_1_version() { @Test public void test_valid_mqtt3_1_version() { final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 6, 'M', 'Q', 'T', 'T', 3, 1}); - decoder.decodeProtocolVersion(clientConnection, buf); + decoder.decode(clientConnection, buf, FIXED_HEADER); assertSame(ProtocolVersion.MQTTv3_1, clientConnection.getProtocolVersion()); assertNotNull(ClientConnection.of(channel).getConnectReceivedTimestamp()); } @@ -112,7 +113,7 @@ public void test_valid_mqtt3_1_version() { @Test public void test_invalid_protocol_version_mqtt_5() { final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 5}); - decoder.decodeProtocolVersion(clientConnection, buf); + decoder.decode(clientConnection, buf, FIXED_HEADER); verify(mqttConnacker).connackError(eq(channel), anyString(), anyString(), @@ -123,7 +124,7 @@ public void test_invalid_protocol_version_mqtt_5() { @Test public void test_invalid_protocol_version_7() { final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 4, 'M', 'Q', 'T', 'T', 7}); - decoder.decodeProtocolVersion(clientConnection, buf); + decoder.decode(clientConnection, buf, FIXED_HEADER); verify(mqttConnacker).connackError(eq(channel), anyString(), anyString(), @@ -134,7 +135,7 @@ public void test_invalid_protocol_version_7() { @Test public void test_invalid_protocol_version_length() { final ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{0, 5, 'M', 'Q', 'T', 'T', 7}); - decoder.decodeProtocolVersion(clientConnection, buf); + decoder.decode(clientConnection, buf, FIXED_HEADER); verify(mqttConnacker).connackError(eq(channel), anyString(), anyString(),