diff --git a/src/main/java/com/neovisionaries/ws/client/OkHostnameVerifier.java b/src/main/java/com/neovisionaries/ws/client/OkHostnameVerifier.java index 798adcf..b8be006 100644 --- a/src/main/java/com/neovisionaries/ws/client/OkHostnameVerifier.java +++ b/src/main/java/com/neovisionaries/ws/client/OkHostnameVerifier.java @@ -68,6 +68,11 @@ public boolean verify(String host, SSLSession session) { } } + public boolean verifyWithExc(String host, SSLSession session) throws SSLException { + Certificate[] certificates = session.getPeerCertificates(); + return verify(host, (X509Certificate) certificates[0]); + } + public boolean verify(String host, X509Certificate certificate) { return verifyAsIpAddress(host) ? verifyIpAddress(host, certificate) diff --git a/src/main/java/com/neovisionaries/ws/client/ProxyHandshaker.java b/src/main/java/com/neovisionaries/ws/client/ProxyHandshaker.java index cbaddef..3dcca09 100644 --- a/src/main/java/com/neovisionaries/ws/client/ProxyHandshaker.java +++ b/src/main/java/com/neovisionaries/ws/client/ProxyHandshaker.java @@ -28,32 +28,30 @@ class ProxyHandshaker { private static final String RN = "\r\n"; - private final Socket mSocket; private final String mHost; private final int mPort; private final ProxySettings mSettings; - public ProxyHandshaker(Socket socket, String host, int port, ProxySettings settings) + public ProxyHandshaker(String host, int port, ProxySettings settings) { - mSocket = socket; mHost = host; mPort = port; mSettings = settings; } - public void perform() throws IOException + public void perform(Socket socket) throws IOException { // Send a CONNECT request to the proxy server. - sendRequest(); + sendRequest(socket); // Receive a response. - receiveResponse(); + receiveResponse(socket); } - private void sendRequest() throws IOException + private void sendRequest(Socket socket) throws IOException { // Build a CONNECT request. String request = buildRequest(); @@ -62,7 +60,7 @@ private void sendRequest() throws IOException byte[] requestBytes = Misc.getBytesUTF8(request); // Get the stream to send data to the proxy server. - OutputStream output = mSocket.getOutputStream(); + OutputStream output = socket.getOutputStream(); // Send the request to the proxy server. output.write(requestBytes); @@ -140,10 +138,10 @@ private void addProxyAuthorization(StringBuilder builder) } - private void receiveResponse() throws IOException + private void receiveResponse(Socket socket) throws IOException { // Get the stream to read data from the proxy server. - InputStream input = mSocket.getInputStream(); + InputStream input = socket.getInputStream(); // Read the status line. readStatusLine(input); diff --git a/src/main/java/com/neovisionaries/ws/client/SocketConnector.java b/src/main/java/com/neovisionaries/ws/client/SocketConnector.java index 2c78e20..dbac6c4 100644 --- a/src/main/java/com/neovisionaries/ws/client/SocketConnector.java +++ b/src/main/java/com/neovisionaries/ws/client/SocketConnector.java @@ -21,7 +21,13 @@ import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; - +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SNIHostName; +import java.util.List; +import java.util.ArrayList; +import javax.net.ssl.SSLException; +import javax.net.SocketFactory; /** * A class to connect to the server. @@ -32,6 +38,7 @@ */ class SocketConnector { + private SocketFactory mSocketFactory; private Socket mSocket; private final Address mAddress; private final int mConnectionTimeout; @@ -41,18 +48,18 @@ class SocketConnector private final int mPort; - SocketConnector(Socket socket, Address address, int timeout) + SocketConnector(SocketFactory socketFactory, Address address, int timeout) { - this(socket, address, timeout, null, null, null, 0); + this(socketFactory, address, timeout, null, null, null, 0); } SocketConnector( - Socket socket, Address address, int timeout, + SocketFactory socketFactory, Address address, int timeout, ProxyHandshaker handshaker, SSLSocketFactory sslSocketFactory, String host, int port) { - mSocket = socket; + mSocketFactory = socketFactory; mAddress = address; mConnectionTimeout = timeout; mProxyHandshaker = handshaker; @@ -87,8 +94,11 @@ public void connect() throws WebSocketException try { - // Close the socket. - mSocket.close(); + if (mSocket != null) + { + // Close the socket. + mSocket.close(); + } } catch (IOException ioe) { @@ -107,8 +117,10 @@ private void doConnect() throws WebSocketException try { - // Connect to the server (either a proxy or a WebSocket endpoint). - mSocket.connect(mAddress.toInetSocketAddress(), mConnectionTimeout); + // Connect to the server (either a proxy or a WebSocket endpoint); overlay mSocket + // Ignore mConnectionTimeout + mSocket = mSocketFactory.createSocket(mAddress.getHostname(), mAddress.toInetSocketAddress().getPort()); + // mSocket.connect(mAddress.toInetSocketAddress(), mConnectionTimeout); if (mSocket instanceof SSLSocket) { @@ -137,7 +149,7 @@ private void doConnect() throws WebSocketException } - private void verifyHostname(SSLSocket socket, String hostname) throws HostnameUnverifiedException + private void verifyHostname(SSLSocket socket, String hostname) throws HostnameUnverifiedException, WebSocketException { // Hostname verifier. OkHostnameVerifier verifier = OkHostnameVerifier.INSTANCE; @@ -145,18 +157,23 @@ private void verifyHostname(SSLSocket socket, String hostname) throws HostnameUn // The SSL session. SSLSession session = socket.getSession(); - // Verify the hostname. - if (verifier.verify(hostname, session)) - { - // Verified. No problem. - return; + try { + // Verify the hostname. + if (verifier.verifyWithExc(hostname, session)) + { + // Verified. No problem. + return; + } + } catch (SSLException e) { + String message = String.format( + "Handshake with the server (%s) failed: %s", hostname, e.getMessage()); + throw new WebSocketException(WebSocketError.SSL_HANDSHAKE_ERROR, message, e); } // The certificate of the peer does not match the expected hostname. throw new HostnameUnverifiedException(socket, hostname); } - /** * Perform proxy handshake and optionally SSL handshake. */ @@ -165,7 +182,7 @@ private void handshake() throws WebSocketException try { // Perform handshake with the proxy server. - mProxyHandshaker.perform(); + mProxyHandshaker.perform(mSocket); } catch (IOException e) { diff --git a/src/main/java/com/neovisionaries/ws/client/WebSocketFactory.java b/src/main/java/com/neovisionaries/ws/client/WebSocketFactory.java index 1b9aa74..05c598b 100644 --- a/src/main/java/com/neovisionaries/ws/client/WebSocketFactory.java +++ b/src/main/java/com/neovisionaries/ws/client/WebSocketFactory.java @@ -588,14 +588,11 @@ private SocketConnector createProxiedRawSocket( // Select a socket factory. SocketFactory socketFactory = mProxySettings.selectSocketFactory(); - // Let the socket factory create a socket. - Socket socket = socketFactory.createSocket(); - // The address to connect to. Address address = new Address(mProxySettings.getHost(), proxyPort); // The delegatee for the handshake with the proxy. - ProxyHandshaker handshaker = new ProxyHandshaker(socket, host, port, mProxySettings); + ProxyHandshaker handshaker = new ProxyHandshaker(host, port, mProxySettings); // SSLSocketFactory for SSL handshake with the WebSocket endpoint. SSLSocketFactory sslSocketFactory = secure ? @@ -603,7 +600,7 @@ private SocketConnector createProxiedRawSocket( // Create an instance that will execute the task to connect to the server later. return new SocketConnector( - socket, address, timeout, handshaker, sslSocketFactory, host, port); + socketFactory, address, timeout, handshaker, sslSocketFactory, host, port); } @@ -612,14 +609,11 @@ private SocketConnector createDirectRawSocket(String host, int port, boolean sec // Select a socket factory. SocketFactory factory = mSocketFactorySettings.selectSocketFactory(secure); - // Let the socket factory create a socket. - Socket socket = factory.createSocket(); - // The address to connect to. Address address = new Address(host, port); // Create an instance that will execute the task to connect to the server later. - return new SocketConnector(socket, address, timeout); + return new SocketConnector(factory, address, timeout); }