From cbdd43b5ec49c446309341ec420bc241c7b65f61 Mon Sep 17 00:00:00 2001 From: mingji Date: Wed, 22 Jan 2025 21:03:55 +0800 Subject: [PATCH] [CELEBORN-1838] Interrupt spark task should not report fetch failure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit What changes were proposed in this pull request? Do not trigger fetch failure if a Spark task attempt is interrupted (speculation enabled). Do not trigger fetch failure if the getReducerFileGroup RPC times out. This PR is intended for the celeborn-0.5 branch. Why are the changes needed? Avoid unnecessary fetch failures and stage re-runs. Does this PR introduce any user-facing change? NO. How was this patch tested? 1. GA. 2. Manually tested on a cluster with Spark speculation tasks. Here is the test case: ```scala sc.parallelize(1 to 100, 100).flatMap(i => { (1 to 150000).iterator.map(num => num) }).groupBy(i => i, 100) .map(i => { if (i._1 < 5) { Thread.sleep(15000) } i }) .repartition(400).count ``` 截屏2025-01-18 16 16 16 截屏2025-01-18 16 16 22 截屏2025-01-18 16 19 15 截屏2025-01-18 16 17 27 Closes #3070 from FMX/branch-0.5-b1838. 
Authored-by: mingji Signed-off-by: Wang, Fei --- .../readclient/FlinkShuffleClientImpl.java | 14 +- client-spark/spark-3/pom.xml | 5 + .../celeborn/CelebornShuffleReader.scala | 33 ++++- .../celeborn/CelebornShuffleReaderSuite.scala | 95 ++++++++++++ .../celeborn/client/DummyShuffleClient.java | 4 + .../celeborn/client/ShuffleClientImpl.java | 28 ++-- .../celeborn/client/ShuffleClientSuiteJ.java | 136 +++++++++++++++++- project/CelebornBuild.scala | 2 +- ...te.scala => CelebornStageRerunSuite.scala} | 4 +- 9 files changed, 293 insertions(+), 28 deletions(-) create mode 100644 client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala rename client/src/{test => main}/java/org/apache/celeborn/client/DummyShuffleClient.java (97%) rename tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/{CelebornShuffleLostSuite.scala => CelebornStageRerunSuite.scala} (96%) diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java index c7b7971a784..feffc7fd1e3 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import scala.Tuple2; +import scala.Tuple3; import scala.reflect.ClassTag$; import com.google.common.annotations.VisibleForTesting; @@ -195,9 +195,9 @@ public CelebornBufferStream readBufferedPartition( public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId) throws CelebornIOException { ReduceFileGroups reduceFileGroups = - reduceFileGroupsMap.computeIfAbsent( - shuffleId, (id) -> Tuple2.apply(new ReduceFileGroups(), null)) - ._1; + reduceFileGroupsMap + 
.computeIfAbsent(shuffleId, (id) -> Tuple3.apply(new ReduceFileGroups(), null, null)) + ._1(); if (reduceFileGroups.partitionIds != null && reduceFileGroups.partitionIds.contains(partitionId)) { logger.debug( @@ -211,11 +211,11 @@ public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId) Utils.makeReducerKey(shuffleId, partitionId)); } else { // refresh file groups - Tuple2 fileGroups = loadFileGroupInternal(shuffleId); - ReduceFileGroups newGroups = fileGroups._1; + Tuple3 fileGroups = loadFileGroupInternal(shuffleId); + ReduceFileGroups newGroups = fileGroups._1(); if (newGroups == null) { throw new CelebornIOException( - loadFileGroupException(shuffleId, partitionId, fileGroups._2)); + loadFileGroupException(shuffleId, partitionId, fileGroups._2())); } else if (!newGroups.partitionIds.contains(partitionId)) { throw new CelebornIOException( String.format( diff --git a/client-spark/spark-3/pom.xml b/client-spark/spark-3/pom.xml index 4acacbed0b7..727a3e79a3f 100644 --- a/client-spark/spark-3/pom.xml +++ b/client-spark/spark-3/pom.xml @@ -91,5 +91,10 @@ mockito-core test + + org.mockito + mockito-inline + test + diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index f4edc0f2ff3..52eae2ff26e 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,12 +18,14 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException +import java.nio.file.Files import java.util -import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import 
scala.collection.JavaConverters._ +import com.google.common.annotations.VisibleForTesting import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.internal.Logging @@ -33,14 +35,14 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient} import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback} import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException} import org.apache.celeborn.common.network.client.TransportClient import org.apache.celeborn.common.network.protocol.TransportMessage -import org.apache.celeborn.common.protocol.{MessageType, PartitionLocation, PbOpenStreamList, PbOpenStreamListResponse, PbStreamHandler} +import org.apache.celeborn.common.protocol._ import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.util.{ExceptionMaker, JavaUtils, ThreadUtils, Utils} @@ -57,7 +59,9 @@ class CelebornShuffleReader[K, C]( extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency - private val shuffleClient = ShuffleClient.get( + + @VisibleForTesting + val shuffleClient = ShuffleClient.get( handle.appUniqueId, handle.lifecycleManagerHost, handle.lifecycleManagerPort, @@ -111,7 +115,9 @@ class CelebornShuffleReader[K, C]( fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition) } catch { case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => - handleFetchExceptions(handle.shuffleId, shuffleId, 0, ce) + // if a task is interrupted, should not report 
fetch failure + // if a task update file group timeout, should not report fetch failure + checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce) case e: Throwable => throw e } @@ -369,7 +375,22 @@ class CelebornShuffleReader[K, C]( } } - private def handleFetchExceptions( + @VisibleForTesting + def checkAndReportFetchFailureForUpdateFileGroupFailure( + celebornShuffleId: Int, + ce: Throwable): Unit = { + if (ce.getCause != null && + (ce.getCause.isInstanceOf[InterruptedException] || ce.getCause.isInstanceOf[ + TimeoutException])) { + logWarning(s"fetch shuffle ${celebornShuffleId} timeout or interrupt", ce) + throw ce + } else { + handleFetchExceptions(handle.shuffleId, celebornShuffleId, 0, ce) + } + } + + @VisibleForTesting + def handleFetchExceptions( appShuffleId: Int, shuffleId: Int, partitionId: Int, diff --git a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala new file mode 100644 index 00000000000..29878fd76c0 --- /dev/null +++ b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn + +import java.nio.file.Files +import java.util.concurrent.TimeoutException + +import org.apache.spark.{Dependency, ShuffleDependency, TaskContext} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito +import org.mockito.Mockito._ +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient} +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.exception.CelebornIOException +import org.apache.celeborn.common.identity.UserIdentifier + +class CelebornShuffleReaderSuite extends AnyFunSuite { + + /** + * Due to spark limitations, spark local mode can not test speculation tasks , + * test the method `checkAndReportFetchFailureForUpdateFileGroupFailure` + */ + test("CELEBORN-1838 test check report fetch failure exceptions ") { + val dependency = Mockito.mock(classOf[ShuffleDependency[Int, Int, Int]]) + val handler = new CelebornShuffleHandle[Int, Int, Int]( + "APP", + "HOST1", + 1, + UserIdentifier.apply("a", "b"), + 0, + true, + 1, + dependency) + val context = Mockito.mock(classOf[TaskContext]) + val metricReporter = Mockito.mock(classOf[ShuffleReadMetricsReporter]) + val conf = new CelebornConf() + + val tmpFile = Files.createTempFile("test", ".tmp").toFile + mockStatic(classOf[ShuffleClient]).when(() => + ShuffleClient.get(any(), any(), any(), any(), any(), any())).thenReturn( + new DummyShuffleClient(conf, tmpFile)) + + val shuffleReader = + new CelebornShuffleReader[Int, Int](handler, 0, 0, 0, 0, context, conf, metricReporter, null) + + val exception1: Throwable = new CelebornIOException("test1", new InterruptedException("test1")) + val exception2: Throwable = new CelebornIOException("test2", new TimeoutException("test2")) + val exception3: Throwable = 
new CelebornIOException("test3") + val exception4: Throwable = new CelebornIOException("test4") + + try { + shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception1) + } catch { + case _: Throwable => + } + try { + shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception2) + } catch { + case _: Throwable => + } + try { + shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception3) + } catch { + case _: Throwable => + } + assert( + shuffleReader.shuffleClient.asInstanceOf[DummyShuffleClient].fetchFailureCount.get() === 1) + try { + shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception4) + } catch { + case _: Throwable => + } + assert( + shuffleReader.shuffleClient.asInstanceOf[DummyShuffleClient].fetchFailureCount.get() === 2) + + } +} diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java similarity index 97% rename from client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java rename to client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index a190c3e1bc7..e2f60aee928 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -55,6 +56,8 @@ public class DummyShuffleClient extends ShuffleClient { private final Map> reducePartitionMap = new HashMap<>(); + public AtomicInteger fetchFailureCount = new AtomicInteger(); + public DummyShuffleClient(CelebornConf conf, File file) throws Exception { this.os = new BufferedOutputStream(new FileOutputStream(file)); this.conf = conf; @@ -180,6 +183,7 @@ public int getShuffleId( 
@Override public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) { + fetchFailureCount.incrementAndGet(); return true; } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 00018d95d78..60704de8e0a 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -26,6 +26,7 @@ import java.util.concurrent.TimeUnit; import scala.Tuple2; +import scala.Tuple3; import scala.reflect.ClassTag$; import com.google.common.annotations.VisibleForTesting; @@ -166,7 +167,7 @@ public void update(ReduceFileGroups fileGroups) { } // key: shuffleId - protected final Map> reduceFileGroupsMap = + protected final Map> reduceFileGroupsMap = JavaUtils.newConcurrentHashMap(); public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier userIdentifier) { @@ -1633,10 +1634,11 @@ public boolean cleanupShuffle(int shuffleId) { return true; } - protected Tuple2 loadFileGroupInternal(int shuffleId) { + protected Tuple3 loadFileGroupInternal(int shuffleId) { { long getReducerFileGroupStartTime = System.nanoTime(); String exceptionMsg = null; + Exception exception = null; try { if (lifecycleManagerRef == null) { exceptionMsg = "Driver endpoint is null!"; @@ -1657,9 +1659,10 @@ protected Tuple2 loadFileGroupInternal(int shuffleId) shuffleId, TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime), response.fileGroup().size()); - return Tuple2.apply( + return Tuple3.apply( new ReduceFileGroups( response.fileGroup(), response.attempts(), response.partitionIds()), + null, null); case SHUFFLE_NOT_REGISTERED: logger.warn( @@ -1668,9 +1671,10 @@ protected Tuple2 loadFileGroupInternal(int shuffleId) response.status(), shuffleId); // return empty result - return Tuple2.apply( + return Tuple3.apply( new ReduceFileGroups( 
response.fileGroup(), response.attempts(), response.partitionIds()), + null, null); case STAGE_END_TIME_OUT: case SHUFFLE_DATA_LOST: @@ -1684,28 +1688,30 @@ protected Tuple2 loadFileGroupInternal(int shuffleId) } catch (Exception e) { logger.error("Exception raised while call GetReducerFileGroup for {}.", shuffleId, e); exceptionMsg = e.getMessage(); + exception = e; } - return Tuple2.apply(null, exceptionMsg); + return Tuple3.apply(null, exceptionMsg, exception); } } public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId) throws CelebornIOException { - Tuple2 fileGroupTuple = + Tuple3 fileGroupTuple = reduceFileGroupsMap.compute( shuffleId, (id, existsTuple) -> { - if (existsTuple == null || existsTuple._1 == null) { + if (existsTuple == null || existsTuple._1() == null) { return loadFileGroupInternal(shuffleId); } else { return existsTuple; } }); - if (fileGroupTuple._1 == null) { + if (fileGroupTuple._1() == null) { throw new CelebornIOException( - loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2))); + loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2())), + fileGroupTuple._3()); } else { - return fileGroupTuple._1; + return fileGroupTuple._1(); } } @@ -1774,7 +1780,7 @@ public CelebornInputStream readPartition( } @VisibleForTesting - public Map> getReduceFileGroupsMap() { + public Map> getReduceFileGroupsMap() { return reduceFileGroupsMap; } diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java index e3483fa574f..0457de405ed 100644 --- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java @@ -25,13 +25,20 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; import 
java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; import org.apache.commons.lang3.RandomStringUtils; +import org.junit.Assert; import org.junit.Test; import org.apache.celeborn.client.compress.Compressor; @@ -42,9 +49,11 @@ import org.apache.celeborn.common.network.client.TransportClientFactory; import org.apache.celeborn.common.protocol.CompressionCodec; import org.apache.celeborn.common.protocol.PartitionLocation; -import org.apache.celeborn.common.protocol.message.ControlMessages.*; +import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse$; +import org.apache.celeborn.common.protocol.message.ControlMessages.RegisterShuffleResponse$; import org.apache.celeborn.common.protocol.message.StatusCode; import org.apache.celeborn.common.rpc.RpcEndpointRef; +import org.apache.celeborn.common.rpc.RpcTimeoutException; public class ShuffleClientSuiteJ { @@ -384,4 +393,129 @@ public Void get(long timeout, TimeUnit unit) { shuffleClient.dataClientFactory = clientFactory; return conf; } + + @Test + public void testUpdateReducerFileGroupInterrupted() throws InterruptedException { + CelebornConf conf = new CelebornConf(); + conf.set("celeborn.client.spark.stageRerun.enabled", "true"); + Map> locations = new HashMap<>(); + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + Thread.sleep(60 * 1000); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference exceptionRef = new AtomicReference<>(); + Thread thread = + new Thread( + 
new Runnable() { + @Override + public void run() { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + exceptionRef.set(e); + } + } + }); + + thread.start(); + Thread.sleep(1000); + thread.interrupt(); + Thread.sleep(1000); + + Exception exception = exceptionRef.get(); + Assert.assertTrue(exception.getCause() instanceof InterruptedException); + } + + @Test + public void testUpdateReducerFileGroupNonFetchFailureExceptions() { + CelebornConf conf = new CelebornConf(); + conf.set("celeborn.client.spark.stageRerun.enabled", "true"); + Map> locations = new HashMap<>(); + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SHUFFLE_NOT_REGISTERED, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + Assert.assertTrue(e.getCause() == null); + } + + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.STAGE_END_TIME_OUT, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + Assert.assertTrue(e.getCause() == null); + } + + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SHUFFLE_DATA_LOST, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + 
shuffleClient.setupLifecycleManagerRef(endpointRef); + + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + Assert.assertTrue(e.getCause() == null); + } + } + + @Test + public void testUpdateReducerFileGroupTimeout() throws InterruptedException { + CelebornConf conf = new CelebornConf(); + conf.set("celeborn.client.rpc.getReducerFileGroup.askTimeout", "1ms"); + + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + invocation -> { + throw new RpcTimeoutException( + "Rpc timeout", new TimeoutException("ask sync timeout")); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference exceptionRef = new AtomicReference<>(); + + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + exceptionRef.set(e); + } + + Exception exception = exceptionRef.get(); + Assert.assertTrue(exception.getCause() instanceof TimeoutException); + } } diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala index e5350706f3e..191d74b2264 100644 --- a/project/CelebornBuild.scala +++ b/project/CelebornBuild.scala @@ -772,7 +772,7 @@ trait SparkClientProjects { libraryDependencies ++= Seq( "org.apache.spark" %% "spark-core" % sparkVersion % "provided", "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", - ) ++ commonUnitTestDependencies + ) ++ commonUnitTestDependencies ++ Seq(Dependencies.mockitoInline % "test") ) } diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornStageRerunSuite.scala similarity index 96% rename from tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala rename to tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornStageRerunSuite.scala index 8c0e8b101b4..91aee9db189 100644 
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornStageRerunSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.protocol.ShuffleMode -class CelebornShuffleLostSuite extends AnyFunSuite +class CelebornStageRerunSuite extends AnyFunSuite with SparkTestBase with BeforeAndAfterEach { @@ -37,7 +37,7 @@ class CelebornShuffleLostSuite extends AnyFunSuite System.gc() } - test("celeborn shuffle data lost - hash") { + test("stage rerun for data lost - hash") { val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]") val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() val combineResult = combine(sparkSession)