[CELEBORN-1671] CelebornShuffleReader will try replica if create client failed

### What changes were proposed in this pull request?
1. Catch the exception when client creation fails in `CelebornShuffleReader` for Spark 3, instead of failing the shuffle read outright.
2. The client will try the location's replica when reading partition locations (a simplified sketch follows below).
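
A minimal sketch of the idea, using simplified stand-in names; `Location`, `Client`, `tryCreateClient`, and `markFetchExcluded` are illustrative only, not the actual Celeborn API (the real change is in the diff below):

```scala
import scala.util.{Failure, Success, Try}

// Illustrative stand-ins for the real Celeborn types in the diff below.
case class Location(hostAndFetchPort: String, peer: Option[Location])
class Client

object ReplicaFallbackSketch {
  // Client creation that may fail, e.g. when the worker is unreachable.
  def tryCreateClient(loc: Location): Try[Client] = Try(new Client)

  // Remember the failed worker so later reads skip it
  // (plays the role of ShuffleClient.excludeFailedFetchLocation).
  def markFetchExcluded(loc: Location, e: Throwable): Unit =
    println(s"exclude ${loc.hostAndFetchPort}: ${e.getMessage}")

  // Try the primary location first; on failure, exclude it and fall back to the replica.
  def clientFor(primary: Location): Option[Client] =
    tryCreateClient(primary) match {
      case Success(client) => Some(client)
      case Failure(e) =>
        markFetchExcluded(primary, e)
        primary.peer.flatMap(replica => tryCreateClient(replica).toOption)
    }
}
```

In the actual patch, the reader catches the exception from `createClient`, calls `excludeFailedFetchLocation`, and the existing retry logic in `CelebornInputStream` then moves on to the replica.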

### Why are the changes needed?
Allow the client to retry a location's replica when client creation fails.

### Does this PR introduce _any_ user-facing change?
NO.

### How was this patch tested?
Pass GA.

Closes apache#2854 from FMX/b1671.

Authored-by: mingji <[email protected]>
Signed-off-by: Shuang <[email protected]>
FMX authored and RexXiong committed Nov 6, 2024
1 parent f2e9043 commit 7dcd259
Showing 6 changed files with 60 additions and 29 deletions.
@@ -119,23 +119,34 @@ class CelebornShuffleReader[K, C](
           partCnt += 1
           val hostPort = location.hostAndFetchPort
           if (!workerRequestMap.containsKey(hostPort)) {
-            val client = shuffleClient.getDataClientFactory().createClient(
-              location.getHost,
-              location.getFetchPort)
-            val pbOpenStreamList = PbOpenStreamList.newBuilder()
-            pbOpenStreamList.setShuffleKey(shuffleKey)
-            workerRequestMap.put(
-              hostPort,
-              (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
+            try {
+              val client = shuffleClient.getDataClientFactory().createClient(
+                location.getHost,
+                location.getFetchPort)
+              val pbOpenStreamList = PbOpenStreamList.newBuilder()
+              pbOpenStreamList.setShuffleKey(shuffleKey)
+              workerRequestMap.put(
+                hostPort,
+                (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
+            } catch {
+              case ex: Exception =>
+                shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
+                logWarning(
+                  s"Failed to create client for $shuffleKey-$partitionId from host: ${location.hostAndFetchPort}. " +
+                    s"Shuffle reader will try its replica if exists.")
+            }
           }
-          val (_, locArr, pbOpenStreamListBuilder) = workerRequestMap.get(hostPort)
-
-          locArr.add(location)
-          pbOpenStreamListBuilder.addFileName(location.getFileName)
-            .addStartIndex(startMapIndex)
-            .addEndIndex(endMapIndex)
-          pbOpenStreamListBuilder.addReadLocalShuffle(
-            localFetchEnabled && location.getHost.equals(localHostAddress))
+          workerRequestMap.get(hostPort) match {
+            case (_, locArr, pbOpenStreamListBuilder) =>
+              locArr.add(location)
+              pbOpenStreamListBuilder.addFileName(location.getFileName)
+                .addStartIndex(startMapIndex)
+                .addEndIndex(endMapIndex)
+              pbOpenStreamListBuilder.addReadLocalShuffle(
+                localFetchEnabled && location.getHost.equals(localHostAddress))
+            case _ =>
+              logDebug(s"Empty client for host ${hostPort}")
+          }
         }
       }
     }
@@ -285,4 +285,6 @@ public abstract int getShuffleId(
   public abstract boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier);
 
   public abstract TransportClientFactory getDataClientFactory();
+
+  public abstract void excludeFailedFetchLocation(String hostAndFetchPort, Exception e);
 }
@@ -117,6 +117,8 @@ public class ShuffleClientImpl extends ShuffleClient {
   private final Set<String> pushExcludedWorkers = ConcurrentHashMap.newKeySet();
   private final ConcurrentHashMap<String, Long> fetchExcludedWorkers =
       JavaUtils.newConcurrentHashMap();
+  private boolean pushReplicateEnabled;
+  private boolean fetchExcludeWorkerOnFailureEnabled;
 
   private final ExecutorService pushDataRetryPool;
 
@@ -180,6 +182,8 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
     pushBufferMaxSize = conf.clientPushBufferMaxSize();
     pushExcludeWorkerOnFailureEnabled = conf.clientPushExcludeWorkerOnFailureEnabled();
     shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
+    pushReplicateEnabled = conf.clientPushReplicateEnabled();
+    fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
     if (conf.clientPushReplicateEnabled()) {
       pushDataTimeout = conf.pushDataTimeoutMs() * 2;
     } else {
@@ -1904,4 +1908,12 @@ private StatusCode getPushDataFailCause(String message) {
   public TransportClientFactory getDataClientFactory() {
     return dataClientFactory;
   }
+
+  public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) {
+    if (pushReplicateEnabled
+        && fetchExcludeWorkerOnFailureEnabled
+        && Utils.isCriticalCauseForFetch(e)) {
+      fetchExcludedWorkers.put(hostAndFetchPort, System.currentTimeMillis());
+    }
+  }
 }
@@ -159,8 +159,6 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
     private final boolean enabledReadLocalShuffle;
     private final String localHostAddress;
 
-    private boolean pushReplicateEnabled;
-    private boolean fetchExcludeWorkerOnFailureEnabled;
     private boolean shuffleCompressionEnabled;
     private long fetchExcludedWorkerExpireTimeout;
     private ConcurrentHashMap<String, Long> fetchExcludedWorkers;
@@ -205,8 +203,6 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
       this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
       this.enabledReadLocalShuffle = conf.enableReadLocalShuffleFile();
       this.localHostAddress = Utils.localHostName(conf);
-      this.pushReplicateEnabled = conf.clientPushReplicateEnabled();
-      this.fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
       this.shuffleCompressionEnabled =
           !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
       this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
@@ -299,12 +295,6 @@ private void moveToNextReader(boolean fetchChunk) throws IOException {
       }
     }
 
-    private void excludeFailedLocation(PartitionLocation location, Exception e) {
-      if (pushReplicateEnabled && fetchExcludeWorkerOnFailureEnabled && isCriticalCause(e)) {
-        fetchExcludedWorkers.put(location.hostAndFetchPort(), System.currentTimeMillis());
-      }
-    }
-
     private boolean isExcluded(PartitionLocation location) {
       Long timestamp = fetchExcludedWorkers.get(location.hostAndFetchPort());
       if (timestamp == null) {
@@ -354,7 +344,7 @@ private PartitionReader createReaderWithRetry(
           return createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry);
         } catch (Exception e) {
           lastException = e;
-          excludeFailedLocation(location, e);
+          shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e);
           fetchChunkRetryCnt++;
           if (location.hasPeer()) {
             // fetchChunkRetryCnt % 2 == 0 means both replicas have been tried,
@@ -392,7 +382,8 @@ private ByteBuf getNextChunk() throws IOException {
           }
           return currentReader.next();
         } catch (Exception e) {
-          excludeFailedLocation(currentReader.getLocation(), e);
+          shuffleClient.excludeFailedFetchLocation(
+              currentReader.getLocation().hostAndFetchPort(), e);
           fetchChunkRetryCnt++;
           currentReader.close();
           if (fetchChunkRetryCnt == fetchChunkMaxRetry) {
@@ -192,6 +192,9 @@ public TransportClientFactory getDataClientFactory() {
     return null;
   }
 
+  @Override
+  public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) {}
+
   public void initReducePartitionMap(int shuffleId, int numPartitions, int workerNum) {
     ConcurrentHashMap<Integer, PartitionLocation> map = JavaUtils.newConcurrentHashMap();
     String host = "host";
@@ -44,7 +44,7 @@ import org.roaringbitmap.RoaringBitmap
 
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.CelebornConf.PORT_MAX_RETRY
-import org.apache.celeborn.common.exception.CelebornException
+import org.apache.celeborn.common.exception.{CelebornException, CelebornIOException}
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{DiskStatus, WorkerInfo}
 import org.apache.celeborn.common.network.protocol.TransportMessage
@@ -1343,4 +1343,16 @@ object Utils extends Logging {
       throw e
     }
   }
+
+  def isCriticalCauseForFetch(e: Exception) = {
+    val rpcTimeout =
+      e.isInstanceOf[IOException] && e.getCause != null && e.getCause.isInstanceOf[TimeoutException]
+    val connectException =
+      e.isInstanceOf[CelebornIOException] && e.getMessage != null && (e.getMessage.startsWith(
+        "Connecting to") || e.getMessage.startsWith("Failed to"))
+    val fetchChunkTimeout = e.isInstanceOf[
+      CelebornIOException] && e.getCause != null && e.getCause.isInstanceOf[IOException]
+    connectException || rpcTimeout || fetchChunkTimeout
+  }
+
 }
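
For context, a hedged usage sketch of how the new `Utils.isCriticalCauseForFetch` check classifies exceptions. The demo object, the exception messages, and the exact import paths and `CelebornIOException` constructors used below are assumptions for illustration, not taken from the patch:

```scala
import java.io.IOException

import org.apache.celeborn.common.exception.CelebornIOException
import org.apache.celeborn.common.util.Utils

object IsCriticalCauseDemo extends App {
  // A connect failure surfaces as a CelebornIOException whose message starts with
  // "Connecting to", which the check treats as critical.
  println(Utils.isCriticalCauseForFetch(
    new CelebornIOException("Connecting to worker-1:9092 failed"))) // true

  // A fetch failure wrapped as CelebornIOException with an IOException cause is also critical.
  println(Utils.isCriticalCauseForFetch(
    new CelebornIOException("fetch chunk failed", new IOException("broken pipe")))) // true

  // Unrelated exceptions are not critical, so the worker is not excluded for them.
  println(Utils.isCriticalCauseForFetch(new RuntimeException("some application error"))) // false
}
```

In `excludeFailedFetchLocation` above, this check is additionally gated by `pushReplicateEnabled` and `fetchExcludeWorkerOnFailureEnabled`, so a worker is only excluded when a replica can actually be tried.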
