diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java index d6b98a60f..879ca04be 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/common/Constants.java @@ -179,6 +179,7 @@ public final class Constants { public static final String WEBSOCKET_FRAME_BLOCKING_HANDLER = "WEBSOCKET_FRAME_BLOCKING_HANDLER"; public static final int WEBSOCKET_STATUS_CODE_NORMAL_CLOSURE = 1000; public static final int WEBSOCKET_STATUS_CODE_GOING_AWAY = 1001; + public static final int WEBSOCKET_STATUS_CODE_PROTOCOL_ERROR = 1002; public static final int WEBSOCKET_STATUS_CODE_ABNORMAL_CLOSURE = 1006; public static final int WEBSOCKET_STATUS_CODE_UNEXPECTED_CONDITION = 1011; diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/HttpWsConnectorFactory.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/HttpWsConnectorFactory.java index 27380c1f3..8ee03e78f 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/HttpWsConnectorFactory.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/HttpWsConnectorFactory.java @@ -22,7 +22,7 @@ import org.wso2.transport.http.netty.config.ListenerConfiguration; import org.wso2.transport.http.netty.config.SenderConfiguration; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; -import org.wso2.transport.http.netty.contract.websocket.WsClientConnectorConfig; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.listener.ServerBootstrapConfiguration; import java.util.Map; @@ -57,7 +57,7 @@ HttpClientConnector createHttpClientConnector(Map transportPrope * @param clientConnectorConfig Properties to create a client connector. * @return WebSocketClientConnector. */ - WebSocketClientConnector createWsClientConnector(WsClientConnectorConfig clientConnectorConfig); + WebSocketClientConnector createWsClientConnector(WebSocketClientConnectorConfig clientConnectorConfig); /** * Shutdown all the server channels and the accepted channels. It also shutdown all the eventloop groups. diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WsClientConnectorConfig.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java similarity index 84% rename from components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WsClientConnectorConfig.java rename to components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java index a353f5dd2..6d4406ab8 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WsClientConnectorConfig.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketClientConnectorConfig.java @@ -19,35 +19,35 @@ package org.wso2.transport.http.netty.contract.websocket; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; + import java.util.Arrays; -import java.util.HashMap; import java.util.List; import java.util.Map; /** - * Sender configuration for WebSocket client connector. + * Configuration for WebSocket client connector. */ -public class WsClientConnectorConfig { +public class WebSocketClientConnectorConfig { private final String remoteAddress; private List subProtocols; private int idleTimeoutInSeconds; private boolean autoRead; - private final Map headers = new HashMap<>(); - - public WsClientConnectorConfig(String remoteAddress) { - this.remoteAddress = remoteAddress; - this.idleTimeoutInSeconds = -1; - this.autoRead = true; + private final HttpHeaders headers; + public WebSocketClientConnectorConfig(String remoteAddress) { + this(remoteAddress, null, -1, true); } - public WsClientConnectorConfig(String remoteAddress, List subProtocols, - int idleTimeoutInSeconds, boolean autoRead) { + public WebSocketClientConnectorConfig(String remoteAddress, List subProtocols, + int idleTimeoutInSeconds, boolean autoRead) { this.remoteAddress = remoteAddress; this.subProtocols = subProtocols; this.idleTimeoutInSeconds = idleTimeoutInSeconds; this.autoRead = autoRead; + this.headers = new DefaultHttpHeaders(); } /** @@ -126,7 +126,7 @@ public String getRemoteAddress() { * @param headers Headers map. */ public void addHeaders(Map headers) { - this.headers.putAll(headers); + headers.forEach(this.headers::add); } /** @@ -136,7 +136,7 @@ public void addHeaders(Map headers) { * @param value Value of the header. */ public void addHeader(String key, String value) { - this.headers.put(key, value); + this.headers.add(key, value); } /** @@ -144,7 +144,7 @@ public void addHeader(String key, String value) { * * @return all the headers as a map. */ - public Map getHeaders() { + public HttpHeaders getHeaders() { return headers; } @@ -155,7 +155,7 @@ public Map getHeaders() { * @return true of the header is present. */ public boolean containsHeader(String key) { - return headers.containsKey(key); + return headers.contains(key); } /** diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java index ca86152c8..3863f89bb 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contract/websocket/WebSocketControlMessage.java @@ -19,12 +19,10 @@ package org.wso2.transport.http.netty.contract.websocket; -import java.nio.ByteBuffer; - /** * This message contains the details of WebSocket bong message. */ -public interface WebSocketControlMessage extends WebSocketMessage { +public interface WebSocketControlMessage extends WebSocketBinaryMessage { /** * Get the control signal. @@ -32,18 +30,4 @@ public interface WebSocketControlMessage extends WebSocketMessage { * @return the control signal as a {@link WebSocketControlSignal}. */ WebSocketControlSignal getControlSignal(); - - /** - * Get the payload of the control signal. - * - * @return the payload of the control signal. - */ - ByteBuffer getPayload(); - - /** - * Get the binary data as a byte array. - * - * @return the binary data as a byte array. - */ - byte[] getByteArray(); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpWsConnectorFactory.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpWsConnectorFactory.java index e58076bcd..8f3495cab 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpWsConnectorFactory.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/DefaultHttpWsConnectorFactory.java @@ -31,7 +31,7 @@ import org.wso2.transport.http.netty.contract.HttpWsConnectorFactory; import org.wso2.transport.http.netty.contract.ServerConnector; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; -import org.wso2.transport.http.netty.contract.websocket.WsClientConnectorConfig; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.contractimpl.websocket.DefaultWebSocketClientConnector; import org.wso2.transport.http.netty.listener.ServerBootstrapConfiguration; import org.wso2.transport.http.netty.listener.ServerConnectorBootstrap; @@ -96,7 +96,7 @@ public HttpClientConnector createHttpClientConnector( } @Override - public WebSocketClientConnector createWsClientConnector(WsClientConnectorConfig clientConnectorConfig) { + public WebSocketClientConnector createWsClientConnector(WebSocketClientConnectorConfig clientConnectorConfig) { return new DefaultWebSocketClientConnector(clientConnectorConfig, clientGroup); } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java index 47145c329..7304c093f 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketClientConnector.java @@ -20,13 +20,12 @@ package org.wso2.transport.http.netty.contractimpl.websocket; import io.netty.channel.EventLoopGroup; +import io.netty.handler.codec.http.HttpHeaders; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; -import org.wso2.transport.http.netty.contract.websocket.WsClientConnectorConfig; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.sender.websocket.WebSocketClient; -import java.util.Map; - /** * Implementation of WebSocket client connector. */ @@ -35,11 +34,11 @@ public class DefaultWebSocketClientConnector implements WebSocketClientConnector private final String remoteUrl; private final String subProtocols; private final int idleTimeout; - private final Map customHeaders; + private final HttpHeaders customHeaders; private final EventLoopGroup wsClientEventLoopGroup; private final boolean autoRead; - public DefaultWebSocketClientConnector(WsClientConnectorConfig clientConnectorConfig, + public DefaultWebSocketClientConnector(WebSocketClientConnectorConfig clientConnectorConfig, EventLoopGroup wsClientEventLoopGroup) { this.remoteUrl = clientConnectorConfig.getRemoteAddress(); this.subProtocols = clientConnectorConfig.getSubProtocolsAsCSV(); diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketMessage.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketMessage.java index dd70f4b1a..7b4d3c143 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketMessage.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/DefaultWebSocketMessage.java @@ -36,7 +36,7 @@ public class DefaultWebSocketMessage implements WebSocketMessage { protected boolean secureConnection; protected boolean isServerMessage; protected WebSocketConnection webSocketConnection; - protected String sessionlID; + protected String sessionID; public void setProperty(String key, Object value) { properties.put(key, value); @@ -54,8 +54,8 @@ public Map getProperties() { return properties; } - public void setSessionlID(String sessionlID) { - this.sessionlID = sessionlID; + public void setSessionID(String sessionID) { + this.sessionID = sessionID; } public void setTarget(String target) { @@ -76,7 +76,7 @@ public String getListenerInterface() { return listenerInterface; } - public void setIsConnectionSecured(boolean isConnectionSecured) { + public void setIsSecureConnection(boolean isConnectionSecured) { this.secureConnection = isConnectionSecured; } @@ -105,6 +105,6 @@ public WebSocketConnection getWebSocketConnection() { @Override public String getSessionID() { - return sessionlID; + return sessionID; } } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java index a275f84ac..58fe62c82 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/WebSocketInboundFrameHandler.java @@ -267,9 +267,9 @@ private void notifyIdleTimeout() throws WebSocketConnectorException { private void setupCommonProperties(DefaultWebSocketMessage webSocketMessage) { webSocketMessage.setTarget(target); webSocketMessage.setListenerInterface(interfaceId); - webSocketMessage.setIsConnectionSecured(securedConnection); + webSocketMessage.setIsSecureConnection(securedConnection); webSocketMessage.setWebSocketConnection(webSocketConnection); - webSocketMessage.setSessionlID(webSocketConnection.getId()); + webSocketMessage.setSessionID(webSocketConnection.getId()); webSocketMessage.setIsServerMessage(isServer); webSocketMessage.setProperty(Constants.LISTENER_PORT, ((InetSocketAddress) ctx.channel().localAddress()).getPort()); diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketControlMessage.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketControlMessage.java index 0b38b6343..41180a71e 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketControlMessage.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketControlMessage.java @@ -21,45 +21,23 @@ import org.wso2.transport.http.netty.contract.websocket.WebSocketControlMessage; import org.wso2.transport.http.netty.contract.websocket.WebSocketControlSignal; -import org.wso2.transport.http.netty.contractimpl.websocket.DefaultWebSocketMessage; import java.nio.ByteBuffer; /** * Implementation of WebSocket control message. */ -public class DefaultWebSocketControlMessage extends DefaultWebSocketMessage implements WebSocketControlMessage { +public class DefaultWebSocketControlMessage extends DefaultWebSocketBinaryMessage implements WebSocketControlMessage { private final WebSocketControlSignal controlSignal; - private final ByteBuffer buffer; public DefaultWebSocketControlMessage(WebSocketControlSignal controlSignal, ByteBuffer buffer) { + super(buffer, true); this.controlSignal = controlSignal; - this.buffer = buffer; } @Override public WebSocketControlSignal getControlSignal() { return controlSignal; } - - @Override - public byte[] getByteArray() { - byte[] bytes; - if (buffer.hasArray()) { - bytes = buffer.array(); - } else { - int remaining = buffer.remaining(); - bytes = new byte[remaining]; - for (int i = 0; i < remaining; i++) { - bytes[i] = buffer.get(); - } - } - return bytes; - } - - @Override - public ByteBuffer getPayload() { - return buffer; - } } diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java index eb872307a..6eba145bf 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/contractimpl/websocket/message/DefaultWebSocketInitMessage.java @@ -30,7 +30,6 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory; import io.netty.handler.timeout.IdleStateHandler; @@ -67,7 +66,7 @@ public DefaultWebSocketInitMessage(ChannelHandlerContext ctx, ServerConnectorFut this.connectorFuture = connectorFuture; this.secureConnection = ctx.channel().pipeline().get(Constants.SSL_HANDLER) != null; this.httpRequest = httpRequest; - this.sessionlID = WebSocketUtil.getSessionID(ctx); + this.sessionID = WebSocketUtil.getSessionID(ctx); } @Override @@ -118,31 +117,28 @@ public ServerHandshakeFuture handshake(String[] subProtocols, boolean allowExten @Override public ChannelFuture cancelHandshake(int statusCode, String closeReason) { - if (!cancelled && !handshakeStarted) { - try { - int responseStatusCode = statusCode >= 400 && statusCode < 500 ? statusCode : 400; - ChannelFuture responseFuture; - if (closeReason != null) { - ByteBuf content = Unpooled.wrappedBuffer(closeReason.getBytes(StandardCharsets.UTF_8)); - responseFuture = ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, - HttpResponseStatus - .valueOf(responseStatusCode), - content)); - } else { - responseFuture = ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, - HttpResponseStatus.valueOf( - responseStatusCode))); - } - return responseFuture; - } finally { - cancelled = true; - } - } else { - if (cancelled) { - throw new IllegalStateException("Cannot cancel the handshake: handshake already cancelled"); + if (cancelled) { + throw new IllegalStateException("Cannot cancel the handshake: handshake already cancelled"); + } + + if (handshakeStarted) { + throw new IllegalStateException("Cannot cancel the handshake: handshake already started"); + } + + try { + int responseStatusCode = statusCode >= 400 && statusCode < 500 ? statusCode : 400; + ChannelFuture responseFuture; + if (closeReason != null) { + ByteBuf content = Unpooled.wrappedBuffer(closeReason.getBytes(StandardCharsets.UTF_8)); + responseFuture = ctx.writeAndFlush(new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(responseStatusCode), content)); } else { - throw new IllegalStateException("Cannot cancel the handshake: handshake already started"); + responseFuture = ctx.writeAndFlush(new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(responseStatusCode))); } + return responseFuture; + } finally { + cancelled = true; } } @@ -159,53 +155,45 @@ public boolean isHandshakeStarted() { private ServerHandshakeFuture handleHandshake(WebSocketServerHandshaker handshaker, int idleTimeout, HttpHeaders headers) { DefaultServerHandshakeFuture handshakeFuture = new DefaultServerHandshakeFuture(); - if (cancelled) { Throwable e = new IllegalAccessException("Handshake is already cancelled!"); handshakeFuture.notifyError(e); return handshakeFuture; } - - try { - ChannelFuture channelFuture = handshaker.handshake(ctx.channel(), httpRequest, headers, - ctx.channel().newPromise()); - channelFuture.addListener(future -> { + ChannelFuture channelFuture = + handshaker.handshake(ctx.channel(), httpRequest, headers, ctx.channel().newPromise()); + channelFuture.addListener(future -> { + if (future.isSuccess() && future.cause() == null) { String selectedSubProtocol = handshaker.selectedSubprotocol(); WebSocketFramesBlockingHandler blockingHandler = new WebSocketFramesBlockingHandler(); WebSocketInboundFrameHandler frameHandler = new WebSocketInboundFrameHandler(connectorFuture, blockingHandler, true, secureConnection, target, listenerInterface); - - //Replace HTTP handlers with new Handlers for WebSocket in the pipeline - ChannelPipeline pipeline = ctx.pipeline(); - if (idleTimeout > 0) { - pipeline.replace(Constants.IDLE_STATE_HANDLER, Constants.IDLE_STATE_HANDLER, - new IdleStateHandler(idleTimeout, idleTimeout, idleTimeout, - TimeUnit.MILLISECONDS)); - } else { - pipeline.remove(Constants.IDLE_STATE_HANDLER); - } - pipeline.addLast(Constants.WEBSOCKET_FRAME_BLOCKING_HANDLER, blockingHandler); - pipeline.addLast(Constants.WEBSOCKET_FRAME_HANDLER, frameHandler); - pipeline.remove(Constants.HTTP_SOURCE_HANDLER); - pipeline.fireChannelActive(); - // Make sure to get WebSocket connection after fireChannelActive + configureFrameHandlingPipeline(idleTimeout, blockingHandler, frameHandler); DefaultWebSocketConnection webSocketConnection = frameHandler.getWebSocketConnection(); webSocketConnection.getDefaultWebSocketSession().setNegotiatedSubProtocol(selectedSubProtocol); handshakeFuture.notifySuccess(frameHandler.getWebSocketConnection()); - }); - handshakeStarted = true; - return handshakeFuture; - } catch (Exception e) { - /* - Code 1002 : indicates that an endpoint is terminating the connection - due to a protocol error. - */ - handshaker.close(ctx.channel(), - new CloseWebSocketFrame(1002, - "Terminating the connection due to a protocol error.")); - handshakeFuture.notifyError(e); - return handshakeFuture; + } else { + handshakeFuture.notifyError(future.cause()); + } + }); + handshakeStarted = true; + return handshakeFuture; + } + + private void configureFrameHandlingPipeline(int idleTimeout, WebSocketFramesBlockingHandler blockingHandler, + WebSocketInboundFrameHandler frameHandler) { + ChannelPipeline pipeline = ctx.pipeline(); + if (idleTimeout > 0) { + pipeline.replace(Constants.IDLE_STATE_HANDLER, Constants.IDLE_STATE_HANDLER, + new IdleStateHandler(idleTimeout, idleTimeout, idleTimeout, + TimeUnit.MILLISECONDS)); + } else { + pipeline.remove(Constants.IDLE_STATE_HANDLER); } + pipeline.addLast(Constants.WEBSOCKET_FRAME_BLOCKING_HANDLER, blockingHandler); + pipeline.addLast(Constants.WEBSOCKET_FRAME_HANDLER, frameHandler); + pipeline.remove(Constants.HTTP_SOURCE_HANDLER); + pipeline.fireChannelActive(); } /* Get the URL of the given connection */ diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketServerHandshakeHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketServerHandshakeHandler.java index fbd61d756..2e51eaec6 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketServerHandshakeHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/listener/WebSocketServerHandshakeHandler.java @@ -140,7 +140,7 @@ private void handleWebSocketHandshake(FullHttpRequest fullHttpRequest, ChannelHa initMessage.setIsServerMessage(true); initMessage.setTarget(fullHttpRequest.uri()); initMessage.setListenerInterface(interfaceId); - initMessage.setIsConnectionSecured(ctx.channel().pipeline().get(Constants.SSL_HANDLER) != null); + initMessage.setIsSecureConnection(ctx.channel().pipeline().get(Constants.SSL_HANDLER) != null); initMessage.setHttpCarbonRequest(setupHttpCarbonRequest(fullHttpRequest, ctx)); diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java index c8a3cba6d..4147bc4db 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClient.java @@ -20,14 +20,11 @@ package org.wso2.transport.http.netty.sender.websocket; import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; @@ -48,7 +45,7 @@ import org.wso2.transport.http.netty.listener.WebSocketFramesBlockingHandler; import java.net.URI; -import java.util.Map; +import java.net.URISyntaxException; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; @@ -64,10 +61,9 @@ public class WebSocketClient { private final String url; private final String subProtocols; private final int idleTimeout; - private final Map headers; + private final HttpHeaders headers; private final EventLoopGroup wsClientEventLoopGroup; private final boolean autoRead; - private Channel channel = null; /** * @param url url of the remote endpoint @@ -78,7 +74,7 @@ public class WebSocketClient { * @param autoRead sets the read interest */ public WebSocketClient(String url, String subProtocols, int idleTimeout, EventLoopGroup wsClientEventLoopGroup, - Map headers, boolean autoRead) { + HttpHeaders headers, boolean autoRead) { this.url = url; this.subProtocols = subProtocols; this.idleTimeout = idleTimeout; @@ -102,51 +98,19 @@ public ClientHandshakeFuture handshake() { if (!"ws".equalsIgnoreCase(scheme) && !"wss".equalsIgnoreCase(scheme)) { log.error("Only WS(S) is supported."); - throw new SSLException(""); + throw new URISyntaxException(url, "WebSocket client supports only WS(S) scheme"); } - final boolean ssl = "wss".equalsIgnoreCase(scheme); - final SslContext sslCtx = getSslContext(ssl); - HttpHeaders httpHeaders = new DefaultHttpHeaders(); - - // Adding custom headers to the handshake request. - if (headers != null) { - headers.forEach(httpHeaders::add); - } - WebSocketClientHandshaker webSocketHandshaker = WebSocketClientHandshakerFactory.newHandshaker( - uri, WebSocketVersion.V13, subProtocols, true, httpHeaders); + uri, WebSocketVersion.V13, subProtocols, true, headers); WebSocketFramesBlockingHandler blockingHandler = new WebSocketFramesBlockingHandler(); clientHandshakeHandler = new WebSocketClientHandshakeHandler(webSocketHandshaker, blockingHandler, ssl, autoRead, url, handshakeFuture); - - Bootstrap clientBootstrap = new Bootstrap(); - clientBootstrap.group(wsClientEventLoopGroup).channel(NioSocketChannel.class).handler( - new ChannelInitializer() { - @Override - protected void initChannel(SocketChannel ch) { - ChannelPipeline pipeline = ch.pipeline(); - if (sslCtx != null) { - pipeline.addLast(sslCtx.newHandler(ch.alloc(), host, port)); - } - pipeline.addLast(new HttpClientCodec()); - // Assuming that WebSocket Handshake messages will not be large than 8KB - pipeline.addLast(new HttpObjectAggregator(8192)); - pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE); - if (idleTimeout > 0) { - pipeline.addLast(new IdleStateHandler(idleTimeout, idleTimeout, - idleTimeout, TimeUnit.MILLISECONDS)); - } - pipeline.addLast(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER, clientHandshakeHandler); - } - }); - - channel = clientBootstrap.connect(uri.getHost(), port).sync().channel(); - clientHandshakeHandler - .handshakeFuture().addListener((ChannelFutureListener) clientHandshakeFuture -> { + Bootstrap clientBootstrap = initClientBootstrap(host, port, getSslContext(ssl)); + clientBootstrap.connect(uri.getHost(), port).sync().channel(); + clientHandshakeHandler.handshakeFuture().addListener(clientHandshakeFuture -> { Throwable cause = clientHandshakeFuture.cause(); if (clientHandshakeFuture.isSuccess() && cause == null) { - channel.config().setAutoRead(autoRead); DefaultWebSocketConnection webSocketConnection = clientHandshakeHandler.getInboundFrameHandler().getWebSocketConnection(); String actualSubProtocol = webSocketHandshaker.actualSubprotocol(); @@ -156,16 +120,40 @@ protected void initChannel(SocketChannel ch) { handshakeFuture.notifyError(cause, clientHandshakeHandler.getHttpCarbonResponse()); } }); - } catch (Throwable t) { + } catch (Throwable throwable) { if (clientHandshakeHandler != null) { - handshakeFuture.notifyError(t, clientHandshakeHandler.getHttpCarbonResponse()); + handshakeFuture.notifyError(throwable, clientHandshakeHandler.getHttpCarbonResponse()); } else { - handshakeFuture.notifyError(t, null); + handshakeFuture.notifyError(throwable, null); } } return handshakeFuture; } + private Bootstrap initClientBootstrap(String host, int port, SslContext sslCtx) { + Bootstrap clientBootstrap = new Bootstrap(); + clientBootstrap.group(wsClientEventLoopGroup).channel(NioSocketChannel.class).handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline pipeline = ch.pipeline(); + if (sslCtx != null) { + pipeline.addLast(sslCtx.newHandler(ch.alloc(), host, port)); + } + pipeline.addLast(new HttpClientCodec()); + // Assuming that WebSocket Handshake messages will not be large than 8KB + pipeline.addLast(new HttpObjectAggregator(8192)); + pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE); + if (idleTimeout > 0) { + pipeline.addLast(new IdleStateHandler(idleTimeout, idleTimeout, + idleTimeout, TimeUnit.MILLISECONDS)); + } + pipeline.addLast(Constants.WEBSOCKET_CLIENT_HANDSHAKE_HANDLER, clientHandshakeHandler); + } + }); + return clientBootstrap; + } + private int getPort(URI uri) { String scheme = uri.getScheme(); if (uri.getPort() == -1) { diff --git a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java index 008652edc..0eaa8e17c 100644 --- a/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java +++ b/components/org.wso2.transport.http.netty/src/main/java/org/wso2/transport/http/netty/sender/websocket/WebSocketClientHandshakeHandler.java @@ -94,6 +94,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception FullHttpResponse fullHttpResponse = (FullHttpResponse) msg; httpCarbonResponse = setUpCarbonMessage(ctx, fullHttpResponse); log.debug("WebSocket Client connected!"); + ctx.channel().config().setAutoRead(autoRead); if (!autoRead) { ctx.channel().pipeline().addLast(Constants.WEBSOCKET_FRAME_BLOCKING_HANDLER, blockingHandler); } diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java index f1fed1201..3e5746498 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientFunctionalityTestCase.java @@ -30,10 +30,10 @@ import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeListener; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.contract.websocket.WebSocketCloseMessage; import org.wso2.transport.http.netty.contract.websocket.WebSocketConnection; import org.wso2.transport.http.netty.contract.websocket.WebSocketConnectorListener; -import org.wso2.transport.http.netty.contract.websocket.WsClientConnectorConfig; import org.wso2.transport.http.netty.contractimpl.DefaultHttpWsConnectorFactory; import org.wso2.transport.http.netty.message.HttpCarbonResponse; import org.wso2.transport.http.netty.util.server.websocket.WebSocketRemoteServer; @@ -62,7 +62,7 @@ public class WebSocketClientFunctionalityTestCase { public void setup() throws InterruptedException { remoteServer = new WebSocketRemoteServer(WEBSOCKET_REMOTE_SERVER_PORT, "xml, json"); remoteServer.run(); - WsClientConnectorConfig configuration = new WsClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); clientConnector = httpConnectorFactory.createWsClientConnector(configuration); } diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java index 87d4cf39b..4228dc65b 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/client/WebSocketClientHandshakeFunctionalityTestCase.java @@ -28,9 +28,9 @@ import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeFuture; import org.wso2.transport.http.netty.contract.websocket.ClientHandshakeListener; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.contract.websocket.WebSocketConnection; import org.wso2.transport.http.netty.contract.websocket.WebSocketConnectorListener; -import org.wso2.transport.http.netty.contract.websocket.WsClientConnectorConfig; import org.wso2.transport.http.netty.contractimpl.DefaultHttpWsConnectorFactory; import org.wso2.transport.http.netty.message.HttpCarbonResponse; import org.wso2.transport.http.netty.util.server.websocket.WebSocketRemoteServer; @@ -59,13 +59,13 @@ public class WebSocketClientHandshakeFunctionalityTestCase { public void setup() throws InterruptedException { remoteServer = new WebSocketRemoteServer(WEBSOCKET_REMOTE_SERVER_PORT, "xml, json"); remoteServer.run(); - WsClientConnectorConfig configuration = new WsClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); clientConnector = httpConnectorFactory.createWsClientConnector(configuration); } @Test(description = "Test the idle timeout for WebSocket") public void testIdleTimeout() throws Throwable { - WsClientConnectorConfig configuration = new WsClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); configuration.setIdleTimeoutInMillis(3000); HandshakeResult result = connectAndGetHandshakeResult(configuration); @@ -79,7 +79,7 @@ public void testIdleTimeout() throws Throwable { @Test(description = "Test the sub protocol negotiation with the remote server") public void testSubProtocolNegotiationSuccessful() throws InterruptedException { - WsClientConnectorConfig configuration = new WsClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); String[] subProtocolsSuccess = {"xmlx", "json"}; configuration.setSubProtocols(subProtocolsSuccess); HandshakeResult result = connectAndGetHandshakeResult(configuration); @@ -93,7 +93,7 @@ public void testSubProtocolNegotiationSuccessful() throws InterruptedException { @Test(description = "Test the sub protocol negotiation with the remote server") public void testSubProtocolNegotiationFail() throws InterruptedException { - WsClientConnectorConfig configuration = new WsClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); String[] subProtocolsFail = {"xmlx", "jsonx"}; configuration.setSubProtocols(subProtocolsFail); HandshakeResult result = connectAndGetHandshakeResult(configuration); @@ -106,7 +106,7 @@ public void testSubProtocolNegotiationFail() throws InterruptedException { @Test(description = "Test whether client can send custom headers and receive.") public void testSendAndReceiveCustomHeaders() throws InterruptedException { - WsClientConnectorConfig configuration = new WsClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); configuration.addHeader("x-ack-custom-header", "true"); HandshakeResult result = connectAndGetHandshakeResult(configuration); HttpCarbonResponse response = result.getHandshakeResponse(); @@ -119,7 +119,7 @@ public void testSendAndReceiveCustomHeaders() throws InterruptedException { @Test(description = "Test the behavior of client connector when auto read is false.") public void testReadNextFrame() throws Throwable { - WsClientConnectorConfig configuration = new WsClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); + WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(WEBSOCKET_REMOTE_SERVER_URL); configuration.setAutoRead(false); HandshakeResult result = connectAndGetHandshakeResult(configuration); @@ -164,7 +164,7 @@ private String[] sendTextMessages(WebSocketConnection webSocketConnection, int n return testMsgArray; } - private HandshakeResult connectAndGetHandshakeResult(WsClientConnectorConfig configuration) + private HandshakeResult connectAndGetHandshakeResult(WebSocketClientConnectorConfig configuration) throws InterruptedException { clientConnector = httpConnectorFactory.createWsClientConnector(configuration); WebSocketTestClientConnectorListener connectorListener = new WebSocketTestClientConnectorListener(); diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/passthrough/WebSocketPassThroughServerConnectorListener.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/passthrough/WebSocketPassThroughServerConnectorListener.java index ac07bc936..8ce6a81e7 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/passthrough/WebSocketPassThroughServerConnectorListener.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/passthrough/WebSocketPassThroughServerConnectorListener.java @@ -28,13 +28,13 @@ import org.wso2.transport.http.netty.contract.websocket.ServerHandshakeListener; import org.wso2.transport.http.netty.contract.websocket.WebSocketBinaryMessage; import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnector; +import org.wso2.transport.http.netty.contract.websocket.WebSocketClientConnectorConfig; import org.wso2.transport.http.netty.contract.websocket.WebSocketCloseMessage; import org.wso2.transport.http.netty.contract.websocket.WebSocketConnection; import org.wso2.transport.http.netty.contract.websocket.WebSocketConnectorListener; import org.wso2.transport.http.netty.contract.websocket.WebSocketControlMessage; import org.wso2.transport.http.netty.contract.websocket.WebSocketInitMessage; import org.wso2.transport.http.netty.contract.websocket.WebSocketTextMessage; -import org.wso2.transport.http.netty.contract.websocket.WsClientConnectorConfig; import org.wso2.transport.http.netty.contractimpl.DefaultHttpWsConnectorFactory; import org.wso2.transport.http.netty.message.HttpCarbonResponse; import org.wso2.transport.http.netty.util.TestUtil; @@ -52,7 +52,7 @@ public class WebSocketPassThroughServerConnectorListener implements WebSocketCon public void onMessage(WebSocketInitMessage initMessage) { String remoteUrl = String.format("ws://%s:%d/%s", "localhost", TestUtil.WEBSOCKET_REMOTE_SERVER_PORT, "websocket"); - WsClientConnectorConfig configuration = new WsClientConnectorConfig(remoteUrl); + WebSocketClientConnectorConfig configuration = new WebSocketClientConnectorConfig(remoteUrl); configuration.setAutoRead(false); WebSocketClientConnector clientConnector = connectorFactory.createWsClientConnector(configuration); WebSocketConnectorListener clientConnectorListener = new WebSocketPassThroughClientConnectorListener(); diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/passthrough/WebSocketPassThroughTestCase.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/passthrough/WebSocketPassThroughTestCase.java index c18aa43a6..344d39124 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/passthrough/WebSocketPassThroughTestCase.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/passthrough/WebSocketPassThroughTestCase.java @@ -35,7 +35,9 @@ import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; + +import static org.wso2.transport.http.netty.util.TestUtil.WEBSOCKET_TEST_IDLE_TIMEOUT; +import static java.util.concurrent.TimeUnit.SECONDS; /** * Test cases for WebSocket pass-through scenarios. @@ -44,19 +46,18 @@ public class WebSocketPassThroughTestCase { private static final Logger log = LoggerFactory.getLogger(WebSocketPassThroughTestCase.class); - private final int latchCountDownInSecs = 10; - - private DefaultHttpWsConnectorFactory httpConnectorFactory = new DefaultHttpWsConnectorFactory(); - private WebSocketRemoteServer remoteServer = new WebSocketRemoteServer(TestUtil.WEBSOCKET_REMOTE_SERVER_PORT); - + private DefaultHttpWsConnectorFactory httpConnectorFactory; + private WebSocketRemoteServer remoteServer; private ServerConnector serverConnector; @BeforeClass public void setup() throws InterruptedException { + remoteServer = new WebSocketRemoteServer(TestUtil.WEBSOCKET_REMOTE_SERVER_PORT); remoteServer.run(); ListenerConfiguration listenerConfiguration = new ListenerConfiguration(); listenerConfiguration.setHost("localhost"); listenerConfiguration.setPort(TestUtil.SERVER_CONNECTOR_PORT); + httpConnectorFactory = new DefaultHttpWsConnectorFactory(); serverConnector = httpConnectorFactory.createServerConnector(TestUtil.getDefaultServerBootstrapConfig(), listenerConfiguration); ServerConnectorFuture connectorFuture = serverConnector.start(); @@ -72,7 +73,7 @@ public void testTextPassThrough() throws InterruptedException, URISyntaxExceptio webSocketClient.setCountDownLatch(latch); String text = "hello-pass-through"; webSocketClient.sendText(text); - latch.await(latchCountDownInSecs, TimeUnit.SECONDS); + latch.await(WEBSOCKET_TEST_IDLE_TIMEOUT, SECONDS); Assert.assertEquals(webSocketClient.getTextReceived(), text); @@ -87,7 +88,7 @@ public void testBinaryPassThrough() throws InterruptedException, URISyntaxExcept webSocketClient.setCountDownLatch(latch); ByteBuffer sentBuffer = ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5}); webSocketClient.sendBinary(sentBuffer); - latch.await(latchCountDownInSecs, TimeUnit.SECONDS); + latch.await(WEBSOCKET_TEST_IDLE_TIMEOUT, SECONDS); Assert.assertEquals(webSocketClient.getBufferReceived(), sentBuffer); diff --git a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/server/WebSocketTestServerConnectorListener.java b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/server/WebSocketTestServerConnectorListener.java index 72e11348e..220b7d4b9 100644 --- a/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/server/WebSocketTestServerConnectorListener.java +++ b/components/org.wso2.transport.http.netty/src/test/java/org/wso2/transport/http/netty/websocket/server/WebSocketTestServerConnectorListener.java @@ -125,7 +125,7 @@ public void onMessage(WebSocketBinaryMessage binaryMessage) { public void onMessage(WebSocketControlMessage controlMessage) { if (controlMessage.getControlSignal() == WebSocketControlSignal.PING) { WebSocketConnection webSocketConnection = controlMessage.getWebSocketConnection(); - webSocketConnection.pong(controlMessage.getPayload()).addListener(future -> { + webSocketConnection.pong(controlMessage.getByteBuffer()).addListener(future -> { if (!future.isSuccess()) { Assert.fail("Could not send the message. " + future.cause().getMessage());