From a13d16761772f638245a449187b0dc01e0b9ae12 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Fri, 17 May 2024 17:08:11 +0800 Subject: [PATCH] [CELEBORN-1401] Add SSL support for ratis communication ### What changes were proposed in this pull request? When SSL is enabled for master, secure the Ratis communication as well with TLS ### Why are the changes needed? Currently, when TLS is enabled for RPC, Ratis comms still goes in the clear - add support for TLS. Note that currently this only supports GRPC, and not netty. ### Does this PR introduce _any_ user-facing change? Secures ratis communication when TLS is enabled at master for rpc. ### How was this patch tested? Local tests and additional unit tests added Closes #2515 from mridulm/CELEBORN-1401-add-ratis-ssl-support. Authored-by: Mridul Muralidharan Signed-off-by: Shuang --- .../common/network/TransportContext.java | 33 +--- .../common/network/ssl/SSLFactory.java | 49 +++++- .../common/network/ssl/SslSampleConfigs.java | 128 +++++++++++--- docs/security.md | 2 + master/pom.xml | 7 + .../master/clustermeta/ha/HARaftServer.java | 55 +++++- .../clustermeta/ha/MasterClusterInfo.scala | 5 +- .../master/clustermeta/ha/MasterNode.scala | 11 +- .../ha/RatisMasterStatusSystemSuiteJ.java | 43 +++-- .../ha/SSLRatisMasterStatusSystemSuiteJ.java | 160 ++++++++++++++++++ project/CelebornBuild.scala | 1 + 11 files changed, 405 insertions(+), 89 deletions(-) create mode 100644 master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java diff --git a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java index 625350dc683..869b3fafd4f 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java +++ b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java @@ -87,7 +87,7 @@ public TransportContext( this.conf = conf; this.msgHandler = msgHandler; this.closeIdleConnections = closeIdleConnections; - this.sslFactory = createSslFactory(); + this.sslFactory = SSLFactory.createSslFactory(conf); this.channelsLimiter = channelsLimiter; this.enableHeartbeat = enableHeartbeat; this.source = source; @@ -216,37 +216,6 @@ public TransportChannelHandler initializePipeline( } } - private SSLFactory createSslFactory() { - if (conf.sslEnabled()) { - - if (conf.sslEnabledAndKeysAreValid()) { - return new SSLFactory.Builder() - .requestedProtocol(conf.sslProtocol()) - .requestedCiphers(conf.sslRequestedCiphers()) - .autoSslEnabled(conf.autoSslEnabled()) - .keyStore(conf.sslKeyStore(), conf.sslKeyStorePassword()) - .trustStore( - conf.sslTrustStore(), - conf.sslTrustStorePassword(), - conf.sslTrustStoreReloadingEnabled(), - conf.sslTrustStoreReloadIntervalMs()) - .build(); - } else { - logger.error( - "SSL encryption enabled but keyStore is not configured for " - + conf.getModuleName() - + "! Please ensure the configured keys are present."); - throw new IllegalArgumentException( - conf.getModuleName() - + " SSL encryption enabled for " - + conf.getModuleName() - + " but keyStore not configured !"); - } - } else { - return null; - } - } - private TransportChannelHandler createChannelHandler( Channel channel, BaseMessageHandler msgHandler) { TransportResponseHandler responseHandler = new TransportResponseHandler(conf, channel); diff --git a/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java index 9612a77499d..a0eb106bf21 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java +++ b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java @@ -30,6 +30,7 @@ import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -47,6 +48,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.celeborn.common.network.util.TransportConf; import org.apache.celeborn.common.util.JavaUtils; /** @@ -96,6 +98,18 @@ private void initJdkSslContext(final Builder b) throws IOException, GeneralSecur this.jdkSslContext = createSSLContext(requestedProtocol, keyManagers, trustManagers); } + public List getKeyManagers() { + return null != keyManagers + ? Collections.unmodifiableList(Arrays.asList(keyManagers)) + : Collections.emptyList(); + } + + public List getTrustManagers() { + return null != trustManagers + ? Collections.unmodifiableList(Arrays.asList(trustManagers)) + : Collections.emptyList(); + } + /* * As b.trustStore is null, credulousTrustStoreManagers will be used - and so all * certs will be accepted - and hence self-signed cert from lifecycle manager will @@ -119,7 +133,7 @@ private void configureAutoSsl(Builder b) { } public boolean hasKeyManagers() { - return null != keyManagers; + return null != keyManagers && keyManagers.length > 0; } public void destroy() { @@ -327,7 +341,7 @@ private static TrustManager[] trustStoreManagers( } } - private static TrustManager[] defaultTrustManagers(File trustStore, String trustStorePassword) + public static TrustManager[] defaultTrustManagers(File trustStore, String trustStorePassword) throws IOException, KeyStoreException, CertificateException, NoSuchAlgorithmException { try (InputStream input = Files.asByteSource(trustStore).openStream()) { KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); @@ -436,4 +450,35 @@ private static List addIfSupported(String[] supported, String... names) } return enabled; } + + public static SSLFactory createSslFactory(TransportConf conf) { + if (conf.sslEnabled()) { + + if (conf.sslEnabledAndKeysAreValid()) { + return new SSLFactory.Builder() + .requestedProtocol(conf.sslProtocol()) + .requestedCiphers(conf.sslRequestedCiphers()) + .autoSslEnabled(conf.autoSslEnabled()) + .keyStore(conf.sslKeyStore(), conf.sslKeyStorePassword()) + .trustStore( + conf.sslTrustStore(), + conf.sslTrustStorePassword(), + conf.sslTrustStoreReloadingEnabled(), + conf.sslTrustStoreReloadIntervalMs()) + .build(); + } else { + logger.error( + "SSL encryption enabled but keyStore is not configured for " + + conf.getModuleName() + + "! Please ensure the configured keys are present."); + throw new IllegalArgumentException( + conf.getModuleName() + + " SSL encryption enabled for " + + conf.getModuleName() + + " but keyStore not configured !"); + } + } else { + return null; + } + } } diff --git a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java index 93007a936c9..3b634ae175f 100644 --- a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java +++ b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java @@ -26,19 +26,34 @@ import java.nio.file.StandardCopyOption; import java.security.*; import java.security.cert.Certificate; -import java.security.cert.CertificateEncodingException; import java.security.cert.X509Certificate; -import java.util.Date; -import java.util.HashMap; -import java.util.Map; - -import javax.security.auth.x500.X500Principal; +import java.util.*; +import java.util.stream.Stream; import org.apache.commons.io.FileUtils; -import org.bouncycastle.x509.X509V1CertificateGenerator; +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.cert.X509v3CertificateBuilder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509CertificateHolder; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class SslSampleConfigs { + private static final Logger LOG = LoggerFactory.getLogger(SslSampleConfigs.class); + + static { + Security.addProvider(new BouncyCastleProvider()); + } + public static final String DEFAULT_KEY_STORE_PATH = getResourceAsAbsolutePath("/ssl/server.jks"); public static final String SECOND_KEY_STORE_PATH = getResourceAsAbsolutePath("/ssl/server_another.jks"); @@ -113,34 +128,95 @@ public static void createTrustStore( * @param algorithm the signing algorithm, eg "SHA1withRSA" * @return the self-signed certificate */ + public static X509Certificate generateCertificate( + String dn, KeyPair pair, int days, String algorithm) throws Exception { + return generateCertificate(dn, pair, days, algorithm, false, null, null, null); + } + + /** + * Create a self-signed X.509 Certificate. + * + * @param dn the X.509 Distinguished Name, eg "CN=Test, L=London, C=GB" + * @param pair the KeyPair for the server + * @param days how many days from now the Certificate is valid for + * @param algorithm the signing algorithm, eg "SHA1withRSA" + * @param generateCaCert Is this request to generate a CA cert + * @param altNames Optional: Alternate names to be added to the cert - we add them as both + * hostnames and ip's. + * @param caKeyPair Optional: the KeyPair of the CA, to be used to sign this certificate. caCert + * should also be specified to use it + * @param caCert Optional: the CA cert, to be used to sign this certificate. caKeyPair should also + * be specified to use it + * @return the signed certificate (signed using ca if provided, else self-signed) + */ @SuppressWarnings("deprecation") public static X509Certificate generateCertificate( - String dn, KeyPair pair, int days, String algorithm) - throws CertificateEncodingException, InvalidKeyException, IllegalStateException, - NoSuchAlgorithmException, SignatureException { + String dn, + KeyPair pair, + int days, + String algorithm, + boolean generateCaCert, + String[] altNames, + KeyPair caKeyPair, + X509Certificate caCert) + throws Exception { Date from = new Date(); Date to = new Date(from.getTime() + days * 86400000L); BigInteger sn = new BigInteger(64, new SecureRandom()); - KeyPair keyPair = pair; - X509V1CertificateGenerator certGen = new X509V1CertificateGenerator(); - X500Principal dnName = new X500Principal(dn); - - certGen.setSerialNumber(sn); - certGen.setIssuerDN(dnName); - certGen.setNotBefore(from); - certGen.setNotAfter(to); - certGen.setSubjectDN(dnName); - certGen.setPublicKey(keyPair.getPublic()); - certGen.setSignatureAlgorithm(algorithm); - - X509Certificate cert = certGen.generate(pair.getPrivate()); - return cert; + X500Name subjectName = new X500Name(dn); + + X500Name issuerName; + KeyPair signingKeyPair; + + if (caKeyPair != null && caCert != null) { + issuerName = new JcaX509CertificateHolder(caCert).getSubject(); + signingKeyPair = caKeyPair; + } else { + issuerName = subjectName; + // self signed + signingKeyPair = pair; + } + + X509v3CertificateBuilder certBuilder = + new JcaX509v3CertificateBuilder( + issuerName, sn, from, to, new X500Name(dn), pair.getPublic()); + + if (null != altNames) { + Stream dnsStream = + Arrays.stream(altNames).map(h -> new GeneralName(GeneralName.dNSName, h)); + Stream ipStream = + Arrays.stream(altNames) + .map( + h -> { + try { + return new GeneralName(GeneralName.iPAddress, h); + } catch (Exception ex) { + return null; + } + }) + .filter(Objects::nonNull); + + GeneralName[] arr = Stream.concat(dnsStream, ipStream).toArray(GeneralName[]::new); + GeneralNames names = new GeneralNames(arr); + + certBuilder.addExtension(Extension.subjectAlternativeName, false, names); + LOG.info("Added subjectAlternativeName extension for hosts : " + Arrays.toString(altNames)); + } + + if (generateCaCert) { + certBuilder.addExtension(Extension.basicConstraints, true, new BasicConstraints(true)); + LOG.info("Added CA cert extension"); + } + + ContentSigner signer = + new JcaContentSignerBuilder(algorithm).build(signingKeyPair.getPrivate()); + return new JcaX509CertificateConverter().getCertificate(certBuilder.build(signer)); } public static KeyPair generateKeyPair(String algorithm) throws NoSuchAlgorithmException { KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm); - keyGen.initialize(1024); + keyGen.initialize(4096); return keyGen.genKeyPair(); } @@ -178,7 +254,7 @@ public static void createKeyStore( } private static KeyStore createEmptyKeyStore() throws GeneralSecurityException, IOException { - KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + KeyStore ks = KeyStore.getInstance("PKCS12"); ks.load(null, null); // initialize return ks; } diff --git a/docs/security.md b/docs/security.md index 563cfaf7b1a..8fa7b62fd10 100644 --- a/docs/security.md +++ b/docs/security.md @@ -34,6 +34,8 @@ start="" end="" !} +When SSL is enabled for `rpc_service`, Raft communication between masters are secured **only when** `celeborn.master.ha.ratis.raft.rpc.type` is set to `grpc`. + Note that `celeborn.ssl`, **without any module**, can be used to set SSL default values which applies to all modules. Also note that `data` module at application side, maps to `push` and `fetch` at worker - hence, for SSL configuration, worker configuration for `push` and `fetch` should be compatible with each other and with `data` at application side. diff --git a/master/pom.xml b/master/pom.xml index 4dd3ed11c77..0f1fe12bb55 100644 --- a/master/pom.xml +++ b/master/pom.xml @@ -104,6 +104,13 @@ jersey-test-framework-provider-jetty test + + org.apache.celeborn + celeborn-common_${scala.binary.version} + ${project.version} + test-jar + test + diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java index 5a5694d867a..f0a7560d780 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.util.*; @@ -27,14 +28,19 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.ReentrantReadWriteLock; +import javax.net.ssl.KeyManager; +import javax.net.ssl.TrustManager; + import scala.Tuple2; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.InvalidProtocolBufferException; import org.apache.ratis.RaftConfigKeys; import org.apache.ratis.client.RaftClientConfigKeys; +import org.apache.ratis.conf.Parameters; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.grpc.GrpcConfigKeys; +import org.apache.ratis.grpc.GrpcTlsConfig; import org.apache.ratis.netty.NettyConfigKeys; import org.apache.ratis.proto.RaftProtos; import org.apache.ratis.protocol.*; @@ -54,6 +60,8 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.client.MasterClient; import org.apache.celeborn.common.exception.CelebornRuntimeException; +import org.apache.celeborn.common.network.ssl.SSLFactory; +import org.apache.celeborn.common.protocol.TransportModuleConstants; import org.apache.celeborn.common.util.ThreadUtils; import org.apache.celeborn.common.util.Utils; import org.apache.celeborn.service.deploy.master.clustermeta.ResourceProtos; @@ -123,13 +131,18 @@ private HARaftServer( this.raftGroup = RaftGroup.valueOf(RAFT_GROUP_ID, raftPeers); this.masterStateMachine = getStateMachine(); this.conf = conf; - RaftProperties serverProperties = newRaftProperties(conf); + + final RpcType rpc = SupportedRpcType.valueOfIgnoreCase(conf.haMasterRatisRpcType()); + RaftProperties serverProperties = newRaftProperties(conf, rpc); + Parameters sslParameters = + localNode.sslEnabled() ? configureSsl(conf, serverProperties, rpc) : null; setDeadlineTime(Integer.MAX_VALUE, Integer.MAX_VALUE); // for default this.server = RaftServer.newBuilder() .setServerId(this.raftPeerId) .setGroup(this.raftGroup) .setProperties(serverProperties) + .setParameters(sslParameters) .setStateMachine(masterStateMachine) .build(); @@ -270,11 +283,9 @@ public void stop() { } } - private RaftProperties newRaftProperties(CelebornConf conf) { + private RaftProperties newRaftProperties(CelebornConf conf, RpcType rpc) { final RaftProperties properties = new RaftProperties(); // Set RPC type - final String rpcType = conf.haMasterRatisRpcType(); - final RpcType rpc = SupportedRpcType.valueOfIgnoreCase(rpcType); RaftConfigKeys.Rpc.setType(properties, rpc); // Set the ratis port number @@ -375,6 +386,37 @@ private RaftProperties newRaftProperties(CelebornConf conf) { return properties; } + private Parameters configureSsl(CelebornConf conf, RaftProperties properties, RpcType rpc) { + + if (rpc != SupportedRpcType.GRPC) { + LOG.error( + "SSL has been disabled for Raft communication between masters. " + + "This is only supported when ratis is configured with GRPC"); + return null; + } + + // This is used only for querying state after initialization - not actual SSL + // also why nThreads does not matter + SSLFactory factory = + SSLFactory.createSslFactory( + Utils.fromCelebornConf(conf, TransportModuleConstants.RPC_SERVICE_MODULE, 1)); + + assert (null != factory); + assert (factory.hasKeyManagers()); + assert (!factory.getTrustManagers().isEmpty()); + + TrustManager trustManager = factory.getTrustManagers().get(0); + KeyManager keyManager = factory.getKeyManagers().get(0); + + Parameters params = new Parameters(); + GrpcConfigKeys.TLS.setEnabled(properties, true); + GrpcConfigKeys.TLS.setConf(params, new GrpcTlsConfig(keyManager, trustManager, true)); + + LOG.info("SSL enabled for ratis communication between masters"); + + return params; + } + private StateMachine getStateMachine() { StateMachine stateMachine = new StateMachine(this); stateMachine.setRaftGroupId(RAFT_GROUP_ID); @@ -536,6 +578,11 @@ public GroupInfoReply getGroupInfo() throws IOException { return server.getGroupInfo(groupInfoRequest); } + // Exposed for testing + public InetAddress getRaftAddress() { + return this.ratisAddr.getAddress(); + } + public int getRaftPort() { return this.ratisAddr.getPort(); } diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala index 1042f46273b..83b6f92830d 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala @@ -26,6 +26,7 @@ import scala.util.{Failure, Success, Try} import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.CelebornConf._ import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.protocol.TransportModuleConstants case class MasterClusterInfo( localNode: MasterNode, @@ -37,6 +38,8 @@ object MasterClusterInfo extends Logging { def loadHAConfig(conf: CelebornConf): MasterClusterInfo = { val localNodeIdOpt = conf.haMasterNodeId val clusterNodeIds = conf.haMasterNodeIds + // If ssl is enabled, we enable it for ratis as well + val sslEnabled = conf.sslEnabled(TransportModuleConstants.RPC_SERVICE_MODULE) val masterNodes = clusterNodeIds.map { nodeId => val ratisHost = conf.haMasterRatisHost(nodeId) @@ -45,7 +48,7 @@ object MasterClusterInfo extends Logging { val rpcPort = conf.haMasterNodePort(nodeId) val internalPort = if (conf.internalPortEnabled) conf.haMasterNodeInternalPort(nodeId) else rpcPort - MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, internalPort) + MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, internalPort, sslEnabled) } val (localNodes, peerNodes) = localNodeIdOpt match { diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala index 0f2b09ca996..ca4ad8da293 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterNode.scala @@ -30,7 +30,8 @@ case class MasterNode( ratisPort: Int, rpcHost: String, rpcPort: Int, - internalRpcPort: Int) { + internalRpcPort: Int, + sslEnabled: Boolean) { def isRatisHostUnresolved: Boolean = ratisAddr.isUnresolved @@ -60,6 +61,7 @@ object MasterNode extends Logging { private var rpcHost: String = _ private var rpcPort = 0 private var internalRpcPort = 0 + private var sslEnabled = false def setNodeId(nodeId: String): this.type = { this.nodeId = nodeId @@ -97,8 +99,13 @@ object MasterNode extends Logging { this } + def setSslEnabled(sslEnabled: Boolean): this.type = { + this.sslEnabled = sslEnabled + this + } + def build: MasterNode = - MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, internalRpcPort) + MasterNode(nodeId, ratisHost, ratisPort, rpcHost, rpcPort, internalRpcPort, sslEnabled) } private def createSocketAddr(host: String, port: Int): InetSocketAddress = { diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java index 8f73071158e..340fb8e2759 100644 --- a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java +++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java @@ -65,8 +65,12 @@ public class RatisMasterStatusSystemSuiteJ { protected static RpcEndpointRef mockRpcEndpoint = Mockito.mock(RpcEndpointRef.class); @BeforeClass - public static void init() throws IOException, InterruptedException { - resetRaftServer(); + public static void init() throws Exception { + resetRaftServer( + configureServerConf(new CelebornConf(), 1), + configureServerConf(new CelebornConf(), 2), + configureServerConf(new CelebornConf(), 3), + false); } private static void stopAllRaftServers() { @@ -81,7 +85,17 @@ private static void stopAllRaftServers() { } } - public static void resetRaftServer() throws IOException, InterruptedException { + static CelebornConf configureServerConf(CelebornConf conf, int id) throws IOException { + File tmpDir = File.createTempFile("celeborn-ratis" + id, "for-test-only"); + tmpDir.delete(); + tmpDir.mkdirs(); + conf.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(), tmpDir.getAbsolutePath()); + return conf; + } + + public static void resetRaftServer( + CelebornConf conf1, CelebornConf conf2, CelebornConf conf3, boolean sslEnabled) + throws IOException, InterruptedException { Mockito.when(mockRpcEnv.setupEndpointRef(Mockito.any(), Mockito.any())) .thenReturn(mockRpcEndpoint); when(mockRpcEnv.setupEndpointRef(any(), any())).thenReturn(dummyRef); @@ -101,24 +115,6 @@ public static void resetRaftServer() throws IOException, InterruptedException { MetaHandler handler2 = new MetaHandler(STATUSSYSTEM2); MetaHandler handler3 = new MetaHandler(STATUSSYSTEM3); - CelebornConf conf1 = new CelebornConf(); - File tmpDir1 = File.createTempFile("celeborn-ratis1", "for-test-only"); - tmpDir1.delete(); - tmpDir1.mkdirs(); - conf1.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(), tmpDir1.getAbsolutePath()); - - CelebornConf conf2 = new CelebornConf(); - File tmpDir2 = File.createTempFile("celeborn-ratis2", "for-test-only"); - tmpDir2.delete(); - tmpDir2.mkdirs(); - conf2.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(), tmpDir2.getAbsolutePath()); - - CelebornConf conf3 = new CelebornConf(); - File tmpDir3 = File.createTempFile("celeborn-ratis3", "for-test-only"); - tmpDir3.delete(); - tmpDir3.mkdirs(); - conf3.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR().key(), tmpDir3.getAbsolutePath()); - String id1 = UUID.randomUUID().toString(); String id2 = UUID.randomUUID().toString(); String id3 = UUID.randomUUID().toString(); @@ -133,6 +129,7 @@ public static void resetRaftServer() throws IOException, InterruptedException { .setRatisPort(ratisPort1) .setRpcPort(ratisPort1) .setInternalRpcPort(ratisPort1) + .setSslEnabled(sslEnabled) .setNodeId(id1) .build(); MasterNode masterNode2 = @@ -141,6 +138,7 @@ public static void resetRaftServer() throws IOException, InterruptedException { .setRatisPort(ratisPort2) .setRpcPort(ratisPort2) .setInternalRpcPort(ratisPort2) + .setSslEnabled(sslEnabled) .setNodeId(id2) .build(); MasterNode masterNode3 = @@ -149,6 +147,7 @@ public static void resetRaftServer() throws IOException, InterruptedException { .setRatisPort(ratisPort3) .setRpcPort(ratisPort3) .setInternalRpcPort(ratisPort3) + .setSslEnabled(sslEnabled) .setNodeId(id3) .build(); @@ -304,7 +303,7 @@ public void testRaftSystemException() throws Exception { } catch (CelebornRuntimeException e) { Assert.assertTrue(true); } finally { - resetRaftServer(); + init(); } } diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java new file mode 100644 index 00000000000..f3d7bca0aa2 --- /dev/null +++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/SSLRatisMasterStatusSystemSuiteJ.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.master.clustermeta.ha; + +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.security.KeyPair; +import java.security.cert.X509Certificate; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; + +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.CelebornConf$; +import org.apache.celeborn.common.network.ssl.SSLFactory; +import org.apache.celeborn.common.network.ssl.SslSampleConfigs; +import org.apache.celeborn.common.protocol.TransportModuleConstants; +import org.apache.celeborn.common.util.Utils; + +public class SSLRatisMasterStatusSystemSuiteJ extends RatisMasterStatusSystemSuiteJ { + + private static final CelebornConf confWithHostPreferred = new CelebornConf(); + + static { + confWithHostPreferred.set(CelebornConf$.MODULE$.NETWORK_BIND_PREFER_IP(), false); + } + + private static class CertificateData { + final File file; + final KeyPair keyPair; + final X509Certificate cert; + + // If caData is null, we are generating for CA - else for a cert which is using the ca + // from caData + CertificateData(CertificateData caData) throws Exception { + this.file = File.createTempFile("file", ".jks"); + file.deleteOnExit(); + + this.keyPair = SslSampleConfigs.generateKeyPair("RSA"); + + // for both ca and cert, we are simply using the same machien as CN + String hostname = Utils.localHostName(confWithHostPreferred); + final String dn = "CN=" + hostname + ",O=MyCompany,C=US"; + + if (null != caData) { + this.cert = + SslSampleConfigs.generateCertificate( + dn, + keyPair, + 365, + "SHA256withRSA", + false, + new String[] {hostname}, + caData.keyPair, + caData.cert); + SslSampleConfigs.createKeyStore( + file, "password", "password", "cert", keyPair.getPrivate(), cert); + } else { + this.cert = + SslSampleConfigs.generateCertificate( + dn, keyPair, 365, "SHA256withRSA", true, null, null, null); + SslSampleConfigs.createTrustStore(file, "password", "ca", cert); + } + } + } + + private static final AtomicReference caData; + + static { + try { + caData = new AtomicReference<>(new CertificateData(null)); + } catch (Exception ex) { + throw new IllegalStateException("Unable to initialize", ex); + } + } + + @BeforeClass + public static void init() throws Exception { + + resetRaftServer( + configureSsl(caData.get(), configureServerConf(new CelebornConf(), 1)), + configureSsl(caData.get(), configureServerConf(new CelebornConf(), 2)), + configureSsl(caData.get(), configureServerConf(new CelebornConf(), 3)), + true); + } + + static CelebornConf configureSsl(CertificateData ca, CelebornConf conf) throws Exception { + conf.set("celeborn.master.ha.ratis.raft.rpc.type", "GRPC"); + + CertificateData server = new CertificateData(ca); + + final String module = TransportModuleConstants.RPC_SERVICE_MODULE; + + conf.set("celeborn.ssl." + module + ".enabled", "true"); + conf.set("celeborn.ssl." + module + ".keyStore", server.file.getAbsolutePath()); + + conf.set("celeborn.ssl." + module + ".keyStorePassword", "password"); + conf.set("celeborn.ssl." + module + ".keyPassword", "password"); + conf.set("celeborn.ssl." + module + ".privateKeyPassword", "password"); + conf.set("celeborn.ssl." + module + ".protocol", "TLSv1.2"); + conf.set("celeborn.ssl." + module + ".trustStore", ca.file.getAbsolutePath()); + conf.set("celeborn.ssl." + module + ".trustStorePassword", "password"); + + return conf; + } + + @Test + public void testSslEnabled() throws Exception { + assertTrue(isSslServer(RATISSERVER1.getRaftAddress(), RATISSERVER1.getRaftPort())); + assertTrue(isSslServer(RATISSERVER2.getRaftAddress(), RATISSERVER2.getRaftPort())); + assertTrue(isSslServer(RATISSERVER3.getRaftAddress(), RATISSERVER3.getRaftPort())); + } + + // Validate if the server listening at the port is using TLS or not. + static boolean isSslServer(InetAddress address, int port) throws Exception { + try (SSLSocket socket = createSslSocket(address, port)) { + socket.setSoTimeout(5000); + socket.startHandshake(); + // handshake succeeded, this will always return true in this case + return socket.getSession().isValid(); + } + } + + private static SSLSocket createSslSocket(InetAddress address, int port) throws Exception { + TrustManager trustStore = SSLFactory.defaultTrustManagers(caData.get().file, "password")[0]; + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, new TrustManager[] {trustStore}, null); + SSLSocketFactory factory = context.getSocketFactory(); + Socket socket = new Socket(); + socket.connect(new InetSocketAddress(address, port), 5000); + socket.setSoTimeout(5000); + + return (SSLSocket) factory.createSocket(socket, address.getHostAddress(), port, true); + } +} diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala index 4c889c63533..d75c6ba4604 100644 --- a/project/CelebornBuild.scala +++ b/project/CelebornBuild.scala @@ -528,6 +528,7 @@ object CelebornService { object CelebornMaster { lazy val master = Project("celeborn-master", file("master")) .dependsOn(CelebornCommon.common) + .dependsOn(CelebornCommon.common % "test->test;compile->compile") .dependsOn(CelebornService.service % "test->test;compile->compile") .settings ( commonSettings,