Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Android-compatibile support for SNI in direct connections #109

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public boolean verify(String host, SSLSession session) {
}
}

public boolean verifyWithExc(String host, SSLSession session) throws SSLException {
Certificate[] certificates = session.getPeerCertificates();
return verify(host, (X509Certificate) certificates[0]);
}

public boolean verify(String host, X509Certificate certificate) {
return verifyAsIpAddress(host)
? verifyIpAddress(host, certificate)
Expand Down
18 changes: 8 additions & 10 deletions src/main/java/com/neovisionaries/ws/client/ProxyHandshaker.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,30 @@
class ProxyHandshaker
{
private static final String RN = "\r\n";
private final Socket mSocket;
private final String mHost;
private final int mPort;
private final ProxySettings mSettings;


public ProxyHandshaker(Socket socket, String host, int port, ProxySettings settings)
public ProxyHandshaker(String host, int port, ProxySettings settings)
{
mSocket = socket;
mHost = host;
mPort = port;
mSettings = settings;
}


public void perform() throws IOException
public void perform(Socket socket) throws IOException
{
// Send a CONNECT request to the proxy server.
sendRequest();
sendRequest(socket);

// Receive a response.
receiveResponse();
receiveResponse(socket);
}


private void sendRequest() throws IOException
private void sendRequest(Socket socket) throws IOException
{
// Build a CONNECT request.
String request = buildRequest();
Expand All @@ -62,7 +60,7 @@ private void sendRequest() throws IOException
byte[] requestBytes = Misc.getBytesUTF8(request);

// Get the stream to send data to the proxy server.
OutputStream output = mSocket.getOutputStream();
OutputStream output = socket.getOutputStream();

// Send the request to the proxy server.
output.write(requestBytes);
Expand Down Expand Up @@ -140,10 +138,10 @@ private void addProxyAuthorization(StringBuilder builder)
}


private void receiveResponse() throws IOException
private void receiveResponse(Socket socket) throws IOException
{
// Get the stream to read data from the proxy server.
InputStream input = mSocket.getInputStream();
InputStream input = socket.getInputStream();

// Read the status line.
readStatusLine(input);
Expand Down
51 changes: 34 additions & 17 deletions src/main/java/com/neovisionaries/ws/client/SocketConnector.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

import javax.net.ssl.SSLParameters;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SNIHostName;
import java.util.List;
import java.util.ArrayList;
import javax.net.ssl.SSLException;
import javax.net.SocketFactory;

/**
* A class to connect to the server.
Expand All @@ -32,6 +38,7 @@
*/
class SocketConnector
{
private SocketFactory mSocketFactory;
private Socket mSocket;
private final Address mAddress;
private final int mConnectionTimeout;
Expand All @@ -41,18 +48,18 @@ class SocketConnector
private final int mPort;


SocketConnector(Socket socket, Address address, int timeout)
SocketConnector(SocketFactory socketFactory, Address address, int timeout)
{
this(socket, address, timeout, null, null, null, 0);
this(socketFactory, address, timeout, null, null, null, 0);
}


SocketConnector(
Socket socket, Address address, int timeout,
SocketFactory socketFactory, Address address, int timeout,
ProxyHandshaker handshaker, SSLSocketFactory sslSocketFactory,
String host, int port)
{
mSocket = socket;
mSocketFactory = socketFactory;
mAddress = address;
mConnectionTimeout = timeout;
mProxyHandshaker = handshaker;
Expand Down Expand Up @@ -87,8 +94,11 @@ public void connect() throws WebSocketException

try
{
// Close the socket.
mSocket.close();
if (mSocket != null)
{
// Close the socket.
mSocket.close();
}
}
catch (IOException ioe)
{
Expand All @@ -107,8 +117,10 @@ private void doConnect() throws WebSocketException

try
{
// Connect to the server (either a proxy or a WebSocket endpoint).
mSocket.connect(mAddress.toInetSocketAddress(), mConnectionTimeout);
// Connect to the server (either a proxy or a WebSocket endpoint); overlay mSocket
// Ignore mConnectionTimeout
mSocket = mSocketFactory.createSocket(mAddress.getHostname(), mAddress.toInetSocketAddress().getPort());
// mSocket.connect(mAddress.toInetSocketAddress(), mConnectionTimeout);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not add here mSocket.setSoTimeout(mConnectionTimeout);. It should solve the problem with timeout not being set.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, even better,

mSocket = mSocketFactory.createSocket(); // This will create an unconnected socket
mSocket.setSoTimeout(mConnectionTimeout); // Set timeout
mSocket = mSocketFactory.createSocket(mSocket, mAddress.getHostname(), mAddress.toInetSocketAddress().getPort(), true); // Connect.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if (mSocket instanceof SSLSocket)
{
Expand Down Expand Up @@ -137,26 +149,31 @@ private void doConnect() throws WebSocketException
}


private void verifyHostname(SSLSocket socket, String hostname) throws HostnameUnverifiedException
private void verifyHostname(SSLSocket socket, String hostname) throws HostnameUnverifiedException, WebSocketException
{
// Hostname verifier.
OkHostnameVerifier verifier = OkHostnameVerifier.INSTANCE;

// The SSL session.
SSLSession session = socket.getSession();

// Verify the hostname.
if (verifier.verify(hostname, session))
{
// Verified. No problem.
return;
try {
// Verify the hostname.
if (verifier.verifyWithExc(hostname, session))
{
// Verified. No problem.
return;
}
} catch (SSLException e) {
String message = String.format(
"Handshake with the server (%s) failed: %s", hostname, e.getMessage());
throw new WebSocketException(WebSocketError.SSL_HANDSHAKE_ERROR, message, e);
}

// The certificate of the peer does not match the expected hostname.
throw new HostnameUnverifiedException(socket, hostname);
}


/**
* Perform proxy handshake and optionally SSL handshake.
*/
Expand All @@ -165,7 +182,7 @@ private void handshake() throws WebSocketException
try
{
// Perform handshake with the proxy server.
mProxyHandshaker.perform();
mProxyHandshaker.perform(mSocket);
}
catch (IOException e)
{
Expand Down
12 changes: 3 additions & 9 deletions src/main/java/com/neovisionaries/ws/client/WebSocketFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -588,22 +588,19 @@ private SocketConnector createProxiedRawSocket(
// Select a socket factory.
SocketFactory socketFactory = mProxySettings.selectSocketFactory();

// Let the socket factory create a socket.
Socket socket = socketFactory.createSocket();

// The address to connect to.
Address address = new Address(mProxySettings.getHost(), proxyPort);

// The delegatee for the handshake with the proxy.
ProxyHandshaker handshaker = new ProxyHandshaker(socket, host, port, mProxySettings);
ProxyHandshaker handshaker = new ProxyHandshaker(host, port, mProxySettings);

// SSLSocketFactory for SSL handshake with the WebSocket endpoint.
SSLSocketFactory sslSocketFactory = secure ?
(SSLSocketFactory)mSocketFactorySettings.selectSocketFactory(secure) : null;

// Create an instance that will execute the task to connect to the server later.
return new SocketConnector(
socket, address, timeout, handshaker, sslSocketFactory, host, port);
socketFactory, address, timeout, handshaker, sslSocketFactory, host, port);
}


Expand All @@ -612,14 +609,11 @@ private SocketConnector createDirectRawSocket(String host, int port, boolean sec
// Select a socket factory.
SocketFactory factory = mSocketFactorySettings.selectSocketFactory(secure);

// Let the socket factory create a socket.
Socket socket = factory.createSocket();

// The address to connect to.
Address address = new Address(host, port);

// Create an instance that will execute the task to connect to the server later.
return new SocketConnector(socket, address, timeout);
return new SocketConnector(factory, address, timeout);
}


Expand Down