diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index ba57a44b9b5..0b5bfb4013c 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -34,6 +34,7 @@ import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback} import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException} @@ -104,8 +105,16 @@ class CelebornShuffleReader[K, C]( val localFetchEnabled = conf.enableReadLocalShuffleFile val localHostAddress = Utils.localHostName(conf) val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId) - // startPartition is irrelevant - val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition) + var fileGroups: ReduceFileGroups = null + try { + // startPartition is irrelevant + fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition) + } catch { + case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => + handleFetchExceptions(shuffleId, 0, ce) + case e: Throwable => throw e + } + // host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList) val workerRequestMap = new util.HashMap[ String, @@ -245,18 +254,7 @@ class CelebornShuffleReader[K, C]( if (exceptionRef.get() != null) { exceptionRef.get() match { case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => - if (throwsFetchFailure && - shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) { - throw new FetchFailedException( - null, - handle.shuffleId, - -1, - -1, - partitionId, - SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId, - ce) - } else - throw ce + handleFetchExceptions(handle.shuffleId, partitionId, ce) case e => throw e } } @@ -289,18 +287,7 @@ class CelebornShuffleReader[K, C]( iter } catch { case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => - if (throwsFetchFailure && - shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) { - throw new FetchFailedException( - null, - handle.shuffleId, - -1, - -1, - partitionId, - SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId, - e) - } else - throw e + handleFetchExceptions(handle.shuffleId, partitionId, e) } } @@ -380,6 +367,22 @@ class CelebornShuffleReader[K, C]( } } + private def handleFetchExceptions(shuffleId: Int, partitionId: Int, ce: Throwable) = { + if (throwsFetchFailure && + shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) { + logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce) + throw new FetchFailedException( + null, + handle.shuffleId, + -1, + -1, + partitionId, + SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId, + ce) + } else + throw ce + } + def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { dep.serializer.newInstance() } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 7489d4f49b7..00018d95d78 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -138,7 +138,7 @@ protected Compressor initialValue() { private final ReviveManager reviveManager; - protected static class ReduceFileGroups { + public static class ReduceFileGroups { public Map> partitionGroups; public int[] mapAttempts; public Set partitionIds; diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index 23d6a7b8df7..1b1d8be3996 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -68,6 +68,8 @@ class ReducePartitionCommitHandler( private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]() private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int, Array[Int]]() private val stageEndTimeout = conf.clientPushStageEndTimeout + private val mockShuffleLost = conf.testMockShuffleLost + private val mockShuffleLostShuffle = conf.testMockShuffleLostShuffle private val rpcCacheSize = conf.clientRpcCacheSize private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel @@ -94,7 +96,11 @@ class ReducePartitionCommitHandler( } override def isStageDataLost(shuffleId: Int): Boolean = { - dataLostShuffleSet.contains(shuffleId) + if (mockShuffleLost) { + mockShuffleLostShuffle == shuffleId + } else { + dataLostShuffleSet.contains(shuffleId) + } } override def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 9fe6ebb0e80..dc0fd4e866d 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1245,6 +1245,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def testFetchFailure: Boolean = get(TEST_CLIENT_FETCH_FAILURE) def testMockDestroySlotsFailure: Boolean = get(TEST_CLIENT_MOCK_DESTROY_SLOTS_FAILURE) def testMockCommitFilesFailure: Boolean = get(TEST_CLIENT_MOCK_COMMIT_FILES_FAILURE) + def testMockShuffleLost: Boolean = get(TEST_CLIENT_MOCK_SHUFFLE_LOST) + def testMockShuffleLostShuffle: Int = get(TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE) def testPushPrimaryDataTimeout: Boolean = get(TEST_CLIENT_PUSH_PRIMARY_DATA_TIMEOUT) def testPushReplicaDataTimeout: Boolean = get(TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT) def testRetryRevive: Boolean = get(TEST_CLIENT_RETRY_REVIVE) @@ -3716,6 +3718,26 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(false) + val TEST_CLIENT_MOCK_SHUFFLE_LOST: ConfigEntry[Boolean] = + buildConf("celeborn.test.client.mockShuffleLost") + .internal + .categories("test", "client") + .doc("Mock shuffle lost.") + .version("0.5.2") + .internal + .booleanConf + .createWithDefault(false) + + val TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE: ConfigEntry[Int] = + buildConf("celeborn.test.client.mockShuffleLostShuffle") + .internal + .categories("test", "client") + .doc("Mock shuffle lost for shuffle") + .version("0.5.2") + .internal + .intConf + .createWithDefault(0) + val CLIENT_PUSH_REPLICATE_ENABLED: ConfigEntry[Boolean] = buildConf("celeborn.client.push.replicate.enabled") .withAlternative("celeborn.push.replicate.enabled") diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala new file mode 100644 index 00000000000..8c0e8b101b4 --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.spark + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.protocol.ShuffleMode + +class CelebornShuffleLostSuite extends AnyFunSuite + with SparkTestBase + with BeforeAndAfterEach { + + override def beforeEach(): Unit = { + ShuffleClient.reset() + } + + override def afterEach(): Unit = { + System.gc() + } + + test("celeborn shuffle data lost - hash") { + val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]") + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val combineResult = combine(sparkSession) + val groupbyResult = groupBy(sparkSession) + val repartitionResult = repartition(sparkSession) + val sqlResult = runsql(sparkSession) + + Thread.sleep(3000L) + sparkSession.stop() + + val conf = updateSparkConf(sparkConf, ShuffleMode.HASH) + conf.set("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + conf.set("spark.celeborn.test.client.mockShuffleLost", "true") + + val celebornSparkSession = SparkSession.builder() + .config(conf) + .getOrCreate() + val celebornCombineResult = combine(celebornSparkSession) + val celebornGroupbyResult = groupBy(celebornSparkSession) + val celebornRepartitionResult = repartition(celebornSparkSession) + val celebornSqlResult = runsql(celebornSparkSession) + + assert(combineResult.equals(celebornCombineResult)) + assert(groupbyResult.equals(celebornGroupbyResult)) + assert(repartitionResult.equals(celebornRepartitionResult)) + assert(combineResult.equals(celebornCombineResult)) + assert(sqlResult.equals(celebornSqlResult)) + + celebornSparkSession.stop() + } +}