diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 1b7b6f1dd6b..f6405a6926b 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -119,23 +119,35 @@ class CelebornShuffleReader[K, C](
             partCnt += 1
             val hostPort = location.hostAndFetchPort
             if (!workerRequestMap.containsKey(hostPort)) {
-              val client = shuffleClient.getDataClientFactory().createClient(
-                location.getHost,
-                location.getFetchPort)
-              val pbOpenStreamList = PbOpenStreamList.newBuilder()
-              pbOpenStreamList.setShuffleKey(shuffleKey)
-              workerRequestMap.put(
-                hostPort,
-                (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
+              try {
+                val client = shuffleClient.getDataClientFactory().createClient(
+                  location.getHost,
+                  location.getFetchPort)
+                val pbOpenStreamList = PbOpenStreamList.newBuilder()
+                pbOpenStreamList.setShuffleKey(shuffleKey)
+                workerRequestMap.put(
+                  hostPort,
+                  (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
+              } catch {
+                case ex: Exception =>
+                  shuffleClient.excludeFailedFetchLocation(hostPort, ex)
+                  logWarning(
+                    s"Failed to create client for $shuffleKey-$partitionId from host: ${location.hostAndFetchPort}. " +
+                      s"Shuffle reader will try its replica if exists.", ex)
+              }
+            }
+            workerRequestMap.get(hostPort) match {
+              case (_, locArr, pbOpenStreamListBuilder) =>
+                locArr.add(location)
+                pbOpenStreamListBuilder.addFileName(location.getFileName)
+                  .addStartIndex(startMapIndex)
+                  .addEndIndex(endMapIndex)
+                pbOpenStreamListBuilder.addReadLocalShuffle(
+                  localFetchEnabled && location.getHost.equals(localHostAddress))
+              case _ =>
+                // No entry exists when creating the client failed above; skip this location.
+                logDebug(s"Empty client for host ${hostPort}")
             }
-            val (_, locArr, pbOpenStreamListBuilder) = workerRequestMap.get(hostPort)
-
-            locArr.add(location)
-            pbOpenStreamListBuilder.addFileName(location.getFileName)
-              .addStartIndex(startMapIndex)
-              .addEndIndex(endMapIndex)
-            pbOpenStreamListBuilder.addReadLocalShuffle(
-              localFetchEnabled && location.getHost.equals(localHostAddress))
           }
         }
       }
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 07ce7b10e88..efa9641f671 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -285,4 +285,14 @@ public abstract int getShuffleId(
   public abstract boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier);
 
   public abstract TransportClientFactory getDataClientFactory();
+
+  /**
+   * Reports a failed fetch from a worker so that the implementation can exclude that worker
+   * from later fetches.
+   *
+   * @param hostAndFetchPort identifies the worker, as given by
+   *     {@code PartitionLocation.hostAndFetchPort()}
+   * @param e the exception raised by the failed fetch or connection attempt
+   */
+  public abstract void excludeFailedFetchLocation(String hostAndFetchPort, Exception e);
 }
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index b43bd95989b..ed1dabc9192 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -117,6 +117,8 @@ public class ShuffleClientImpl extends ShuffleClient {
   private final Set pushExcludedWorkers = ConcurrentHashMap.newKeySet();
   private final ConcurrentHashMap fetchExcludedWorkers
       = JavaUtils.newConcurrentHashMap();
+  private boolean pushReplicateEnabled;
+  private boolean fetchExcludeWorkerOnFailureEnabled;
 
   private final ExecutorService pushDataRetryPool;
 
@@ -180,6 +182,8 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
     pushBufferMaxSize = conf.clientPushBufferMaxSize();
     pushExcludeWorkerOnFailureEnabled = conf.clientPushExcludeWorkerOnFailureEnabled();
     shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
+    pushReplicateEnabled = conf.clientPushReplicateEnabled();
+    fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
     if (conf.clientPushReplicateEnabled()) {
       pushDataTimeout = conf.pushDataTimeoutMs() * 2;
     } else {
@@ -1904,4 +1908,21 @@ private StatusCode getPushDataFailCause(String message) {
   public TransportClientFactory getDataClientFactory() {
     return dataClientFactory;
   }
+
+  /**
+   * Records the worker as excluded for fetching, stamped with the current time, when push
+   * replication and fetch-side exclusion are both enabled and the failure is classified as
+   * critical by {@code Utils.isCriticalCauseForFetch}.
+   *
+   * @param hostAndFetchPort key of the worker whose fetch failed
+   * @param e the failure that triggered the report
+   */
+  @Override
+  public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) {
+    if (pushReplicateEnabled
+        && fetchExcludeWorkerOnFailureEnabled
+        && Utils.isCriticalCauseForFetch(e)) {
+      fetchExcludedWorkers.put(hostAndFetchPort, System.currentTimeMillis());
+    }
+  }
 }
diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index bd0164cd6bc..dfbb7c502b8 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -159,8 +159,6 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
     private final boolean enabledReadLocalShuffle;
     private final String localHostAddress;
 
-    private boolean pushReplicateEnabled;
-    private boolean fetchExcludeWorkerOnFailureEnabled;
     private boolean shuffleCompressionEnabled;
     private long fetchExcludedWorkerExpireTimeout;
     private ConcurrentHashMap fetchExcludedWorkers;
@@ -205,8 +203,6 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
       this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
       this.enabledReadLocalShuffle = conf.enableReadLocalShuffleFile();
       this.localHostAddress = Utils.localHostName(conf);
-      this.pushReplicateEnabled = conf.clientPushReplicateEnabled();
-      this.fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
       this.shuffleCompressionEnabled =
           !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
       this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
@@ -299,12 +295,6 @@ private void moveToNextReader(boolean fetchChunk) throws IOException {
       }
     }
 
-    private void excludeFailedLocation(PartitionLocation location, Exception e) {
-      if (pushReplicateEnabled && fetchExcludeWorkerOnFailureEnabled && isCriticalCause(e)) {
-        fetchExcludedWorkers.put(location.hostAndFetchPort(), System.currentTimeMillis());
-      }
-    }
-
     private boolean isExcluded(PartitionLocation location) {
       Long timestamp = fetchExcludedWorkers.get(location.hostAndFetchPort());
       if (timestamp == null) {
@@ -354,7 +344,7 @@ private PartitionReader createReaderWithRetry(
           return createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry);
         } catch (Exception e) {
           lastException = e;
-          excludeFailedLocation(location, e);
+          shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e);
           fetchChunkRetryCnt++;
           if (location.hasPeer()) {
             // fetchChunkRetryCnt % 2 == 0 means both replicas have been tried,
@@ -392,7 +382,8 @@ private ByteBuf getNextChunk() throws IOException {
           }
           return currentReader.next();
         } catch (Exception e) {
-          excludeFailedLocation(currentReader.getLocation(), e);
+          shuffleClient.excludeFailedFetchLocation(
+              currentReader.getLocation().hostAndFetchPort(), e);
           fetchChunkRetryCnt++;
           currentReader.close();
           if (fetchChunkRetryCnt == fetchChunkMaxRetry) {
diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 49b6b5c546a..a190c3e1bc7 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -192,6 +192,9 @@ public TransportClientFactory getDataClientFactory() {
     return null;
   }
 
+  @Override
+  public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) {}
+
   public void initReducePartitionMap(int shuffleId, int numPartitions, int workerNum) {
     ConcurrentHashMap map = JavaUtils.newConcurrentHashMap();
     String host = "host";
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index f6709f69689..dc5b6ea346f 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -44,7 +44,7 @@ import org.roaringbitmap.RoaringBitmap
 
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.CelebornConf.PORT_MAX_RETRY
-import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.exception.{CelebornException, CelebornIOException}
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{DiskStatus, WorkerInfo}
 import org.apache.celeborn.common.network.protocol.TransportMessage
@@ -1343,4 +1343,23 @@ object Utils extends Logging {
       throw e
     }
   }
+
+  /**
+   * Tells whether a fetch failure is severe enough for the client to exclude the worker.
+   *
+   * A failure is critical when it is an IOException caused by a TimeoutException, a
+   * CelebornIOException whose message starts with "Connecting to" or "Failed to", or a
+   * CelebornIOException caused by an IOException.
+   */
+  def isCriticalCauseForFetch(e: Exception): Boolean = {
+    val rpcTimeout =
+      e.isInstanceOf[IOException] && e.getCause != null && e.getCause.isInstanceOf[TimeoutException]
+    val connectException =
+      e.isInstanceOf[CelebornIOException] && e.getMessage != null &&
+        (e.getMessage.startsWith("Connecting to") || e.getMessage.startsWith("Failed to"))
+    val fetchChunkTimeout =
+      e.isInstanceOf[CelebornIOException] && e.getCause != null && e.getCause.isInstanceOf[IOException]
+    connectException || rpcTimeout || fetchChunkTimeout
+  }
+
 }