Skip to content

Commit

Permalink
Implement key alias fetching in StaticSSLContext. Introduce methods…
Browse files Browse the repository at this point in the history
… to retrieve the common name from X509 certificates and update tests to validate key alias selection. Ensure fallback logic for missing or incorrect key aliases is handled gracefully with appropriate warnings.
  • Loading branch information
t-burch committed Aug 26, 2024
1 parent a7fbcbe commit 96a1009
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,22 @@

import javax.annotation.Nullable;
import javax.crypto.Cipher;
import javax.naming.InvalidNameException;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
import javax.net.ssl.*;
import javax.security.auth.x500.X500Principal;
import javax.validation.constraints.NotNull;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.*;
import java.security.cert.*;
import java.security.cert.Certificate;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPublicKey;
import java.util.*;

public class StaticSSLContext extends SSLContext {
Expand Down Expand Up @@ -69,6 +70,7 @@ public class StaticSSLContext extends SSLContext {

private final SSLParser sslParser;
private List<String> dnsNames;
private String commonName;

private javax.net.ssl.SSLContext sslc;
private long validFrom, validUntil;
Expand All @@ -95,22 +97,15 @@ public StaticSSLContext(SSLParser sslParser, ResolverMap resourceResolver, Strin
kmf = KeyManagerFactory.getInstance(algorihm);
kmf.init(ks, keyPass);

Enumeration<String> aliases = ks.aliases();
while (aliases.hasMoreElements()) {
String alias = aliases.nextElement();
if (ks.isKeyEntry(alias)) {
if (sslParser.getKeyStore().getKeyAlias() != null) {
String keyAlias = sslParser.getKeyStore().getKeyAlias();
if (!alias.equals(keyAlias))
continue;
}

dnsNames = getDNSNames(ks.getCertificate(alias));
List<Certificate> certs = Arrays.asList(ks.getCertificateChain(alias));
validUntil = getMinimumValidity(certs);
validFrom = getValidFrom(certs);
break;
}
Optional<String> keyAlias = fetchKeyAlias(ks, sslParser.getKeyStore().getKeyAlias());
if (keyAlias.isPresent()) {
getCommonName(ks.getCertificate(keyAlias.get())).ifPresent(cn -> commonName = cn);
dnsNames = getDNSNames(ks.getCertificate(keyAlias.get()));
List<Certificate> certs = Arrays.asList(ks.getCertificateChain(keyAlias.get()));
validUntil = getMinimumValidity(certs);
validFrom = getValidFrom(certs);
} else {
log.warn("Specified keystore does not contain key of alias '{}'", sslParser.getKeyStore().getKeyAlias());
}
}
if (sslParser.getKey() != null) {
Expand Down Expand Up @@ -210,12 +205,40 @@ public StaticSSLContext(SSLParser sslParser, javax.net.ssl.SSLContext sslc) {
init(sslParser, sslc);
}

static Optional<String> fetchKeyAlias(KeyStore ks, String requiredAlias) throws KeyStoreException {
Enumeration<String> aliases = ks.aliases();
while (aliases.hasMoreElements()) {
String alias = aliases.nextElement();
if (ks.isKeyEntry(alias)) {
if (requiredAlias != null) {
if (alias.equals(requiredAlias)) {
return Optional.of(alias);
}
} else {
return Optional.of(alias);
}
}
}
return Optional.empty();
}

static Optional<String> getCommonName(Certificate certificate) throws InvalidNameException {
if (certificate instanceof X509Certificate cert) {
String dn = cert.getSubjectX500Principal().getName();
LdapName lddn = new LdapName(dn);
for (Rdn rdn : lddn.getRdns()) {
if (rdn.getType().equalsIgnoreCase("cn")) {
return Optional.of(rdn.getValue().toString());
}
}
}
return Optional.empty();
}

private List<String> getDNSNames(Certificate certificate) throws CertificateParsingException {
ArrayList<String> dnsNames = new ArrayList<>();
if (certificate instanceof X509Certificate) {
X509Certificate x = (X509Certificate) certificate;

Collection<List<?>> subjectAlternativeNames = x.getSubjectAlternativeNames();
if (certificate instanceof X509Certificate cert) {
Collection<List<?>> subjectAlternativeNames = cert.getSubjectAlternativeNames();
if (subjectAlternativeNames != null)
for (List<?> l : subjectAlternativeNames) {
if (l.get(0) instanceof Integer && ((Integer)l.get(0) == 2))
Expand All @@ -233,7 +256,7 @@ public boolean equals(Object obj) {
return Objects.equal(sslParser, other.sslParser);
}

private KeyStore openKeyStore(Store store, String defaultType, char[] keyPass, ResolverMap resourceResolver, String baseLocation) throws NoSuchAlgorithmException, CertificateException, FileNotFoundException, IOException, KeyStoreException, NoSuchProviderException {
static KeyStore openKeyStore(Store store, String defaultType, char[] keyPass, ResolverMap resourceResolver, String baseLocation) throws NoSuchAlgorithmException, CertificateException, FileNotFoundException, IOException, KeyStoreException, NoSuchProviderException {
String type = store.getType();
if (type == null)
type = defaultType;
Expand Down Expand Up @@ -422,4 +445,8 @@ public long getValidFrom() {
public long getValidUntil() {
return validUntil;
}

public String getCommonName() {
return commonName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,34 @@

package com.predic8.membrane.core.transport.ssl;

import com.predic8.membrane.core.*;
import com.predic8.membrane.core.config.security.*;
import org.junit.jupiter.api.*;
import com.predic8.membrane.core.HttpRouter;
import com.predic8.membrane.core.Router;
import com.predic8.membrane.core.config.security.KeyStore;
import com.predic8.membrane.core.config.security.SSLParser;
import com.predic8.membrane.core.config.security.TrustStore;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import org.junit.platform.commons.util.UnrecoverableExceptions;

import javax.net.ssl.*;
import java.io.*;
import java.net.*;

import static java.lang.String.format;
import static org.junit.jupiter.api.AssertionFailureBuilder.assertionFailure;
import javax.naming.InvalidNameException;
import javax.net.ssl.SSLHandshakeException;
import javax.security.auth.x500.X500Principal;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Optional;

import static com.predic8.membrane.core.transport.ssl.StaticSSLContext.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class SSLContextTest {

Expand All @@ -49,7 +64,7 @@ public SSLContextBuilder() {
sslParser.setEndpointIdentificationAlgorithm("");
}

public SSLContext build() {
public StaticSSLContext build() {
return new StaticSSLContext(sslParser, router.getResolverMap(), router.getBaseLocation());
}

Expand Down Expand Up @@ -83,26 +98,53 @@ private SSLContextBuilder cb() {
}

@Test
public void keyAliasSelectValid() throws Exception {
SSLContext server1 = cb().withKeyStore("classpath:/alias-keystore.p12").byKeyAlias("key1").build();
SSLContext client1 = cb().withTrustStore("classpath:/alias-truststore.p12").build();
testCombination(server1, client1);
public void keyAliasSelectPresent() throws Exception {
Optional<String> key1 = fetchKeyAlias(getAliasKeystoreByAlias("key1", router), "key1");
Optional<String> key2 = fetchKeyAlias(getAliasKeystoreByAlias("key2", router), "key2");
assertTrue(key1.isPresent());
assertTrue(key2.isPresent());
assertEquals("key1", key1.get());
assertEquals("key2", key2.get());
}

@Test
public void keyAliasSelectWrongKey() throws Exception {
SSLContext server2 = cb().withKeyStore("classpath:/alias-keystore.p12").byKeyAlias("key2").build();
SSLContext client2 = cb().withTrustStore("classpath:/alias-truststore.p12").build();
testCombination(server2, client2);
public void keyAliasDefaultFallback() throws Exception {
Optional<String> key1 = fetchKeyAlias(getAliasKeystoreByAlias("key1", router), null);
Optional<String> key2 = fetchKeyAlias(getAliasKeystoreByAlias("key2", router), null);
assertTrue(key1.isPresent());
assertTrue(key2.isPresent());
assertEquals("key1", key1.get());
assertEquals("key1", key2.get());
}

@Test
public void keyAliasInvalid() throws Exception {
assertThrows(Exception.class, () -> {
SSLContext serverInvalid = cb().withKeyStore("classpath:/alias-keystore.p12").byKeyAlias("invalid").build();
SSLContext clientInvalid = cb().withTrustStore("classpath:/alias-truststore.p12").build();
testCombination(serverInvalid, clientInvalid);
});
public void keyAliasSelectNotPresent() throws Exception {
Optional<String> key = fetchKeyAlias(getAliasKeystoreByAlias("key3", router), "key3");
assertTrue(key.isEmpty());
}

@Test
void validX509ReturnsCN() throws InvalidNameException {
X509Certificate cert = mock(X509Certificate.class);
when(cert.getSubjectX500Principal()).thenReturn(new X500Principal("CN=John Doe, O=Example Org, C=US"));
Optional<String> result = getCommonName(cert);
assertTrue(result.isPresent());
assertEquals("John Doe", result.get());
}

@Test
void X509WithoutCNReturnsEmpty() throws InvalidNameException {
X509Certificate cert = mock(X509Certificate.class);
when(cert.getSubjectX500Principal()).thenReturn(new X500Principal("O=Example Org, C=US"));
Optional<String> result = getCommonName(cert);
assertTrue(result.isEmpty());
}

@Test
void nonX509ReturnsEmpty() throws InvalidNameException {
Certificate cert = mock(Certificate.class);
Optional<String> result = getCommonName(cert);
assertTrue(result.isEmpty());
}

@Test
Expand Down Expand Up @@ -139,22 +181,6 @@ public void serverKeyOnlyWithInvalidClientTrust() {
});
}

public static <T extends Throwable, S extends Throwable> void assertThrows2(Class<T> expectedType1, Class<S> expectedType2, Executable executable) {
try {
executable.execute();
} catch (Throwable actualException) {
if (expectedType1.isInstance(actualException)) {
return;
} else if (expectedType2.isInstance(actualException)) {
return;
} else {
UnrecoverableExceptions.rethrowIfUnrecoverable(actualException);
throw new RuntimeException("Unexpected exception type thrown");
}
}
throw new RuntimeException("Expected exception to be thrown, but nothing was thrown.");
}

@Test
public void serverAndClientCertificates() throws Exception {
SSLContext server = cb().withKeyStore("classpath:/ssl-rsa.keystore").withTrustStore("classpath:/ssl-rsa-pub2.keystore").needClientAuth().build();
Expand All @@ -171,6 +197,29 @@ public void serverAndClientCertificatesWithoutServerTrust() {
});
}

private static @NotNull java.security.KeyStore getAliasKeystoreByAlias(String alias, Router router) throws Exception {
return openKeyStore(new KeyStore() {{
setLocation("classpath:/alias-keystore.p12");
setKeyPassword("secret");
setKeyAlias(alias);
}}, "PKCS12", "secret".toCharArray(), router.getResolverMap(), router.getBaseLocation());
}

public static <T extends Throwable, S extends Throwable> void assertThrows2(Class<T> expectedType1, Class<S> expectedType2, Executable executable) {
try {
executable.execute();
} catch (Throwable actualException) {
if (expectedType1.isInstance(actualException)) {
return;
} else if (expectedType2.isInstance(actualException)) {
return;
} else {
UnrecoverableExceptions.rethrowIfUnrecoverable(actualException);
throw new RuntimeException("Unexpected exception type thrown");
}
}
throw new RuntimeException("Expected exception to be thrown, but nothing was thrown.");
}

private void testCombination(SSLContext server, final SSLContext client) throws Exception {
ServerSocket ss = server.createServerSocket(3020, 50, null);
Expand Down Expand Up @@ -202,5 +251,4 @@ private void testCombination(SSLContext server, final SSLContext client) throws
if (ex[0] != null)
throw ex[0];
}

}

0 comments on commit 96a1009

Please sign in to comment.