From cf39d486dddf49e29baa5dea78ca8a539ac74e23 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 16:58:02 +0900 Subject: [PATCH 01/17] calculate processed block size in outputwriter --- .../executor/datatransfer/BlockOutputWriter.java | 14 +++++++++++++- .../executor/datatransfer/OutputWriter.java | 6 ++++++ .../executor/datatransfer/PipeOutputWriter.java | 6 ++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java index 97cde037c6..b9cf9ef129 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java @@ -52,6 +52,8 @@ public final class BlockOutputWriter implements OutputWriter { private long writtenBytes; + private Optional> partitionSizeMap; + /** * Constructor. * @@ -109,7 +111,7 @@ public void close() { final DataPersistenceProperty.Value persistence = (DataPersistenceProperty.Value) runtimeEdge .getPropertyValue(DataPersistenceProperty.class).orElseThrow(IllegalStateException::new); - final Optional> partitionSizeMap = blockToWrite.commit(); + partitionSizeMap = blockToWrite.commit(); // Return the total size of the committed block. if (partitionSizeMap.isPresent()) { long blockSizeTotal = 0; @@ -123,6 +125,16 @@ public void close() { blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, getExpectedRead(), persistence); } + @Override + public Optional> getPartitionSizeMap() { + if (partitionSizeMap.isPresent()) { + return partitionSizeMap; + } else { + return Optional.empty(); + } + } + + @Override public Optional getWrittenBytes() { if (writtenBytes == -1) { return Optional.empty(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java index bf6ff84e69..a1862f5f2d 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java @@ -20,6 +20,7 @@ import org.apache.nemo.common.punctuation.Watermark; +import java.util.Map; import java.util.Optional; /** @@ -45,5 +46,10 @@ public interface OutputWriter { */ Optional getWrittenBytes(); + /** + * @return the map of hashed key to partition size. 
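+   *         Writers that do not track per-partition sizes (e.g. PipeOutputWriter)
+   *         return an empty Optional.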
+ */ + Optional> getPartitionSizeMap(); + void close(); } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java index 544d64d921..d0025428aa 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; /** @@ -113,6 +114,11 @@ public Optional getWrittenBytes() { return Optional.empty(); } + @Override + public Optional> getPartitionSizeMap() { + return Optional.empty(); + } + @Override public void close() { if (!initialized) { From 1b94e85c262f329ef54a045baa0adf40ba82f488 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 17:15:15 +0900 Subject: [PATCH 02/17] modify datafetcher to gather trace of the serialized read bytes --- .../runtime/executor/task/DataFetcher.java | 16 +++++++ .../MultiThreadParentTaskDataFetcher.java | 6 +++ .../executor/task/ParentTaskDataFetcher.java | 45 +++++++++++++++++++ .../task/SourceVertexDataFetcher.java | 7 +++ 4 files changed, 74 insertions(+) diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java index 7af08852eb..b1a828c13c 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java @@ -20,6 +20,7 @@ import org.apache.nemo.common.ir.OutputCollector; import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.runtime.executor.MetricMessageSender; import java.io.IOException; @@ -49,6 +50,21 @@ abstract class DataFetcher implements AutoCloseable { */ abstract Object fetchDataElement() throws IOException; + /** + * Identical with fetchDataElement(), except it sends intermediate serializedReadBytes to MetricStore + * on every iterator advance. + * This method is for WorkStealing implementation in Nemo. 
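+   * Fetchers that do not track serialized read bytes (e.g. SourceVertexDataFetcher)
+   * simply delegate to fetchDataElement().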
+ * + * @param taskId task id + * @param metricMessageSender metricMessageSender + * + * @return data element + * @throws IOException upon I/O error + * @throws java.util.NoSuchElementException if no more element is available + */ + abstract Object fetchDataElementWithTrace(String taskId, + MetricMessageSender metricMessageSender) throws IOException; + OutputCollector getOutputCollector() { return outputCollector; } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java index 797818ce44..c1361cb125 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java @@ -22,6 +22,7 @@ import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; import org.apache.nemo.common.punctuation.Watermark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.*; import org.slf4j.Logger; @@ -100,6 +101,11 @@ Object fetchDataElement() throws IOException { } } + @Override + Object fetchDataElementWithTrace(String taskId, MetricMessageSender metricMessageSender) throws IOException { + return fetchDataElement(); + } + private void fetchDataLazily() { final List> futures = readersForParentTask.read(); numOfIterators = futures.size(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java index a8ae4a9306..4c376ff6b9 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -18,10 +18,12 @@ */ package org.apache.nemo.runtime.executor.task; +import org.apache.commons.lang3.SerializationUtils; import org.apache.nemo.common.ir.OutputCollector; import org.apache.nemo.common.ir.edge.executionproperty.BlockFetchFailureProperty; import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.InputReader; import org.slf4j.Logger; @@ -100,6 +102,49 @@ Object fetchDataElement() throws IOException { return Finishmark.getInstance(); } + @Override + Object fetchDataElementWithTrace(String taskId, + MetricMessageSender metricMessageSender) throws IOException { + try { + if (firstFetch) { + fetchDataLazily(); + advanceIterator(); + firstFetch = false; + } + + while (true) { + // This iterator has the element + if (this.currentIterator.hasNext()) { + return this.currentIterator.next(); + } + + // This iterator does not have the element + if (currentIteratorIndex < expectedNumOfIterators) { + // Next iterator has the element + countBytes(currentIterator); + // Send the cumulative serBytes to MetricStore + metricMessageSender.send("TaskMetric", taskId, "serializedReadBytes", + SerializationUtils.serialize(serBytes)); + advanceIterator(); + continue; + } else { + // We've consumed all the iterators + break; + } + + } + } catch 
(final Throwable e) { + // Any failure is caught and thrown as an IOException, so that the task is retried. + // In particular, we catch unchecked exceptions like RuntimeException thrown by DataUtil.IteratorWithNumBytes + // when remote data fetching fails for whatever reason. + // Note that we rely on unchecked exceptions because the Iterator interface does not provide the standard + // "throw Exception" that the TaskExecutor thread can catch and handle. + throw new IOException(e); + } + + return Finishmark.getInstance(); + } + private void advanceIterator() throws IOException { // Take from iteratorQueue final Object iteratorOrThrowable; diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java index 2d82898d7a..8ac8c27eee 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java @@ -23,7 +23,9 @@ import org.apache.nemo.common.ir.vertex.SourceVertex; import org.apache.nemo.common.punctuation.Finishmark; import org.apache.nemo.common.punctuation.Watermark; +import org.apache.nemo.runtime.executor.MetricMessageSender; +import java.io.IOException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -74,6 +76,11 @@ Object fetchDataElement() { } } + @Override + Object fetchDataElementWithTrace(String taskId, MetricMessageSender metricMessageSender) { + return fetchDataElement(); + } + final long getBoundedSourceReadTime() { return boundedSourceReadTime; } From 5f53288749d3383eae2a6045c3513eeef4f10379 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 17:28:23 +0900 Subject: [PATCH 03/17] handle checkstyle --- .../executor/task/MultiThreadParentTaskDataFetcher.java | 3 ++- .../nemo/runtime/executor/task/ParentTaskDataFetcher.java | 4 ++-- .../nemo/runtime/executor/task/SourceVertexDataFetcher.java | 3 +-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java index c1361cb125..d7947e8c78 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java @@ -102,7 +102,8 @@ Object fetchDataElement() throws IOException { } @Override - Object fetchDataElementWithTrace(String taskId, MetricMessageSender metricMessageSender) throws IOException { + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender) throws IOException { return fetchDataElement(); } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java index 4c376ff6b9..3a92cbc8a9 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -103,8 +103,8 @@ Object fetchDataElement() throws IOException { } @Override - Object 
fetchDataElementWithTrace(String taskId, - MetricMessageSender metricMessageSender) throws IOException { + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender) throws IOException { try { if (firstFetch) { fetchDataLazily(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java index 8ac8c27eee..68a3362d27 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java @@ -25,7 +25,6 @@ import org.apache.nemo.common.punctuation.Watermark; import org.apache.nemo.runtime.executor.MetricMessageSender; -import java.io.IOException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -77,7 +76,7 @@ Object fetchDataElement() { } @Override - Object fetchDataElementWithTrace(String taskId, MetricMessageSender metricMessageSender) { + Object fetchDataElementWithTrace(final String taskId, final MetricMessageSender metricMessageSender) { return fetchDataElement(); } From aed78c66f1a75d0756fe39b211d772351841df8f Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 17:33:37 +0900 Subject: [PATCH 04/17] replace fetchDataElement with fetchDataElementWithTrace in TaskExecutor --- .../org/apache/nemo/runtime/executor/task/TaskExecutor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index 2bf574d396..fc3cc4e8b8 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -458,7 +458,7 @@ private boolean handleDataFetchers(final List fetchers) { while (availableIterator.hasNext()) { final DataFetcher dataFetcher = availableIterator.next(); try { - final Object element = dataFetcher.fetchDataElement(); + final Object element = dataFetcher.fetchDataElementWithTrace(taskId, metricMessageSender); onEventFromDataFetcher(element, dataFetcher); if (element instanceof Finishmark) { availableIterator.remove(); From e4401f05291527d6ff4848ca3d0976b926667264 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 17:50:54 +0900 Subject: [PATCH 05/17] add work stealing thread in runtime master --- .../apache/nemo/runtime/master/RuntimeMaster.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index d3b48f266a..4c4b5bb35e 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -85,9 +85,11 @@ public final class RuntimeMaster { private static final int METRIC_ARRIVE_TIMEOUT = 10000; private static final int REST_SERVER_PORT = 10101; private static final int SPECULATION_CHECKING_PERIOD_MS = 100; + private static final int WORK_STEALING_CHECKING_PERIOD_MS = 100; private final ExecutorService runtimeMasterThread; private final ScheduledExecutorService 
speculativeTaskCloningThread; + private final ScheduledExecutorService workStealingThread; private final Scheduler scheduler; private final ContainerManager containerManager; @@ -160,6 +162,16 @@ private RuntimeMaster(final Scheduler scheduler, SPECULATION_CHECKING_PERIOD_MS, TimeUnit.MILLISECONDS); + // Check for work stealing every second + this.workStealingThread = Executors + .newSingleThreadScheduledExecutor(runnable -> new Thread(runnable, "WorkStealing master thread")); + this.workStealingThread.scheduleWithFixedDelay( + () -> this.runtimeMasterThread.submit(scheduler::onWorkStealingCheck), + WORK_STEALING_CHECKING_PERIOD_MS, + WORK_STEALING_CHECKING_PERIOD_MS, + TimeUnit.MILLISECONDS); + + this.scheduler = scheduler; this.containerManager = containerManager; this.executorRegistry = executorRegistry; From 398ff94c23e052381193dfa7cb9e0ef9b5e3b653 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 18:14:03 +0900 Subject: [PATCH 06/17] send accumulated KV statistics to nemo driver --- .../src/main/proto/ControlMessage.proto | 7 +++ .../runtime/executor/task/TaskExecutor.java | 50 +++++++++++++++++++ .../nemo/runtime/master/RuntimeMaster.java | 14 ++++-- .../master/scheduler/BatchScheduler.java | 21 ++++++++ 4 files changed, 88 insertions(+), 4 deletions(-) diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index 97e30fb4e7..3d43f10c64 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -86,6 +86,7 @@ enum MessageType { PipeInit = 13; RequestPipeLoc = 14; PipeLocInfo = 15; + ParentTaskDataCollected = 16; } message Message { @@ -107,6 +108,7 @@ message Message { optional PipeInitMessage pipeInitMsg = 16; optional RequestPipeLocationMessage requestPipeLocMsg = 17; optional PipeLocationInfoMessage pipeLocInfoMsg = 18; + optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19; } // Messages from Master to Executors @@ -256,3 +258,8 @@ message PipeLocationInfoMessage { required int64 requestId = 1; // To find the matching request msg required string executorId = 2; } + +message ParentTaskDataCollectMsg { + required string taskId = 1; + required bytes partitionSizeMap = 2; +} diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index fc3cc4e8b8..758e32212e 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -19,6 +19,7 @@ package org.apache.nemo.runtime.executor.task; import com.google.common.collect.Lists; +import com.google.protobuf.ByteString; import org.apache.commons.lang3.SerializationUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.nemo.common.Pair; @@ -688,12 +689,21 @@ public void setIRVertexPutOnHold(final IRVertex irVertex) { */ private void finalizeOutputWriters(final VertexHarness vertexHarness) { final List writtenBytesList = new ArrayList<>(); + final HashMap partitionSizeMap = new HashMap<>(); // finalize OutputWriters for main children vertexHarness.getWritersToMainChildrenTasks().forEach(outputWriter -> { outputWriter.close(); final Optional writtenBytes = outputWriter.getWrittenBytes(); writtenBytes.ifPresent(writtenBytesList::add); + + // Send partitionSizeMap to Scheduler + if (true) { + final Optional> 
partitionSizes = outputWriter.getPartitionSizeMap(); + if (partitionSizes.isPresent()) { + computePartitionSizeMap(partitionSizeMap, partitionSizes.get()); + } + } }); // finalize OutputWriters for additional tagged children @@ -702,6 +712,14 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { outputWriter.close(); final Optional writtenBytes = outputWriter.getWrittenBytes(); writtenBytes.ifPresent(writtenBytesList::add); + + // Send partitionSizeMap to Scheduler + if (true) { + final Optional> partitionSizes = outputWriter.getPartitionSizeMap(); + if (partitionSizes.isPresent()) { + computePartitionSizeMap(partitionSizeMap, partitionSizes.get()); + } + } }) ); @@ -713,5 +731,37 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { // TODO #236: Decouple metric collection and sending logic metricMessageSender.send(TASK_METRIC_ID, taskId, "taskOutputBytes", SerializationUtils.serialize(totalWrittenBytes)); + + if (!partitionSizeMap.isEmpty()) { + persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( + ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.ParentTaskDataCollected) + .setParentTaskDataCollected(ControlMessage.ParentTaskDataCollectMsg.newBuilder() + .setTaskId(taskId) + .setPartitionSizeMap(ByteString.copyFrom(SerializationUtils.serialize(partitionSizeMap))) + .build()) + .build()); + } + } + + /** + * Gather the KV statistics of processed data. + * This method is for work stealing implementation. + * + * @param totalPartitionSizeMap accumulated partitionSizeMap of task. + * @param singlePartitionSizeMap partitionSizeMap gained from single OutputWriter. 
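+   *                               Sizes for the same hashed key are summed into the total map.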
+ */ + private void computePartitionSizeMap(final Map totalPartitionSizeMap, + final Map singlePartitionSizeMap) { + for (Integer hashedKey : singlePartitionSizeMap.keySet()) { + final Long partitionSize = singlePartitionSizeMap.get(hashedKey); + if (totalPartitionSizeMap.containsKey(hashedKey)) { + totalPartitionSizeMap.compute(hashedKey, (existingKey, existingValue) -> existingValue + partitionSize); + } else { + totalPartitionSizeMap.put(hashedKey, partitionSize); + } + } } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index d3b48f266a..dd6b144e2a 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -58,10 +58,7 @@ import javax.inject.Inject; import java.io.Serializable; import java.nio.file.Paths; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; @@ -481,6 +478,15 @@ private void handleControlMessage(final ControlMessage.Message message) { .setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(serializedData).build()) .build()); break; + case ParentTaskDataCollected: + if (scheduler instanceof BatchScheduler) { + final ControlMessage.ParentTaskDataCollectMsg workStealingMsg = message.getParentTaskDataCollected(); + final String taskId = workStealingMsg.getTaskId(); + final Map partitionSizeMap = SerializationUtils + .deserialize(workStealingMsg.getPartitionSizeMap().toByteArray()); + ((BatchScheduler) scheduler).aggregateStageIdToPartitionSizeMap(taskId, partitionSizeMap); + } + break; case MetricFlushed: metricCountDownLatch.countDown(); break; diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 086c9d08bd..23ed8326e7 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -77,6 +77,11 @@ public final class BatchScheduler implements Scheduler { */ private List> sortedScheduleGroups; // Stages, sorted in the order to be scheduled. + /** + * Data Structures for work stealing. 
+ */ + private final Map> stageIdToOutputPartitionSizeMap = new HashMap<>(); + @Inject private BatchScheduler(final PlanRewriter planRewriter, final TaskDispatcher taskDispatcher, @@ -383,4 +388,20 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId, return false; } + + // Methods for work stealing + public void aggregateStageIdToPartitionSizeMap(final String taskId, + final Map partitionSizeMap) { + final Map partitionSizeMapForThisStage = stageIdToOutputPartitionSizeMap + .getOrDefault(RuntimeIdManager.getStageIdFromTaskId(taskId), new HashMap<>()); + for (Integer hashedKey : partitionSizeMap.keySet()) { + final Long partitionSize = partitionSizeMap.get(hashedKey); + if (partitionSizeMapForThisStage.containsKey(hashedKey)) { + partitionSizeMapForThisStage.put(hashedKey, partitionSize + partitionSizeMapForThisStage.get(hashedKey)); + } else { + partitionSizeMapForThisStage.put(hashedKey, partitionSize); + } + } + stageIdToOutputPartitionSizeMap.put(RuntimeIdManager.getStageIdFromTaskId(taskId), partitionSizeMapForThisStage); + } } From 9e2d04775f8c072be7ecbc334e64745740756766 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 18:35:42 +0900 Subject: [PATCH 07/17] track the processed bytes of the current stage: send it to driver --- .../src/main/proto/ControlMessage.proto | 7 ++++++ .../runtime/executor/task/TaskExecutor.java | 24 +++++++++++++++++-- .../nemo/runtime/master/RuntimeMaster.java | 7 ++++++ .../master/scheduler/BatchScheduler.java | 20 ++++++++++++++++ 4 files changed, 56 insertions(+), 2 deletions(-) diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index 3d43f10c64..53adea3d3a 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -87,6 +87,7 @@ enum MessageType { RequestPipeLoc = 14; PipeLocInfo = 15; ParentTaskDataCollected = 16; + CurrentlyProcessedBytesCollected = 17; } message Message { @@ -109,6 +110,7 @@ message Message { optional RequestPipeLocationMessage requestPipeLocMsg = 17; optional PipeLocationInfoMessage pipeLocInfoMsg = 18; optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19; + optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20; } // Messages from Master to Executors @@ -263,3 +265,8 @@ message ParentTaskDataCollectMsg { required string taskId = 1; required bytes partitionSizeMap = 2; } + +message CurrentlyProcessedBytesCollectMsg { + required string taskId = 1; + required int64 processedDataBytes = 2; +} diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index 758e32212e..91e8212640 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -746,9 +746,11 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { } } + // Methods for work stealing /** - * Gather the KV statistics of processed data. - * This method is for work stealing implementation. + * Gather the KV statistics of processed data when execution is completed. + * This method is for work stealing implementation: the accumulated statistics will be used to + * detect skewed tasks of the child stage. * * @param totalPartitionSizeMap accumulated partitionSizeMap of task. 
* @param singlePartitionSizeMap partitionSizeMap gained from single OutputWriter. @@ -764,4 +766,22 @@ private void computePartitionSizeMap(final Map totalPartitionSize } } } + + /** + * Send the temporally processed bytes of the current task on request from the scheduler. + * This method is for work stealing implementation. + */ + public void onRequestForProcessedData() { + LOG.error("{}, bytes {}, replying for the request", taskId, serializedReadBytes); + persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( + ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.CurrentlyProcessedBytesCollected) + .setCurrentlyProcessedBytesCollected(ControlMessage.CurrentlyProcessedBytesCollectMsg.newBuilder() + .setTaskId(this.taskId) + .setProcessedDataBytes(serializedReadBytes) + .build()) + .build()); + } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index dd6b144e2a..40fb5e86fc 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -487,6 +487,13 @@ private void handleControlMessage(final ControlMessage.Message message) { ((BatchScheduler) scheduler).aggregateStageIdToPartitionSizeMap(taskId, partitionSizeMap); } break; + case CurrentlyProcessedBytesCollected: + if (scheduler instanceof BatchScheduler) { + ((BatchScheduler) scheduler).aggregateTaskIdToProcessedBytes( + message.getCurrentlyProcessedBytesCollected().getTaskId(), + message.getCurrentlyProcessedBytesCollected().getProcessedDataBytes() + ); + } case MetricFlushed: metricCountDownLatch.countDown(); break; diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 23ed8326e7..8941aa093c 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -81,6 +81,7 @@ public final class BatchScheduler implements Scheduler { * Data Structures for work stealing. */ private final Map> stageIdToOutputPartitionSizeMap = new HashMap<>(); + private final Map taskIdToProcessedBytes = new HashMap<>(); @Inject private BatchScheduler(final PlanRewriter planRewriter, @@ -390,6 +391,14 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId, } // Methods for work stealing + + /** + * Accumulate the execution result of each stage in Map[STAGE ID, Map[KEY, SIZE]] format. + * KEY is assumed to be Integer because of the HashPartition. + * + * @param taskId id of task to accumulate. + * @param partitionSizeMap map of (K) - (partition size) of the task. + */ public void aggregateStageIdToPartitionSizeMap(final String taskId, final Map partitionSizeMap) { final Map partitionSizeMapForThisStage = stageIdToOutputPartitionSizeMap @@ -404,4 +413,15 @@ public void aggregateStageIdToPartitionSizeMap(final String taskId, } stageIdToOutputPartitionSizeMap.put(RuntimeIdManager.getStageIdFromTaskId(taskId), partitionSizeMapForThisStage); } + + /** + * Store the tracked processed bytes per task by the current time. 
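+   * A new report for the same task overwrites the previously stored value.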
+ * + * @param taskId id of task to track. + * @param processedBytes size of the processed bytes till now. + */ + public void aggregateTaskIdToProcessedBytes(final String taskId, + final long processedBytes) { + taskIdToProcessedBytes.put(taskId, processedBytes); + } } From f88cd30adde973b6e310258a300809542818f6b3 Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 14:28:49 +0900 Subject: [PATCH 08/17] check work stealing on scheduler --- .../master/scheduler/BatchScheduler.java | 200 +++++++++++++++++- .../runtime/master/scheduler/Scheduler.java | 5 + .../master/scheduler/SimulationScheduler.java | 6 + .../master/scheduler/StreamingScheduler.java | 5 + 4 files changed, 215 insertions(+), 1 deletion(-) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 8941aa093c..f41b359f7b 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -20,16 +20,19 @@ import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.nemo.common.Pair; +import org.apache.nemo.common.dag.Vertex; import org.apache.nemo.common.exception.UnknownExecutionStateException; import org.apache.nemo.common.exception.UnrecoverableFailureException; import org.apache.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty; import org.apache.nemo.runtime.common.RuntimeIdManager; +import org.apache.nemo.runtime.common.metric.TaskMetric; import org.apache.nemo.runtime.common.plan.*; import org.apache.nemo.runtime.common.state.StageState; import org.apache.nemo.runtime.common.state.TaskState; import org.apache.nemo.runtime.master.BlockManagerMaster; import org.apache.nemo.runtime.master.PlanAppender; import org.apache.nemo.runtime.master.PlanStateManager; +import org.apache.nemo.runtime.master.metric.MetricStore; import org.apache.nemo.runtime.master.resource.ExecutorRepresenter; import org.apache.reef.annotations.audience.DriverSide; import org.slf4j.Logger; @@ -66,6 +69,7 @@ public final class BatchScheduler implements Scheduler { private final PendingTaskCollectionPointer pendingTaskCollectionPointer; // A 'pointer' to the list of pending tasks. private final ExecutorRegistry executorRegistry; // A registry for executors available for the job. private final PlanStateManager planStateManager; // A component that manages the state of the plan. + private final MetricStore metricStore = MetricStore.getStore(); /** * Other necessary components of this {@link org.apache.nemo.runtime.master.RuntimeMaster}. @@ -80,8 +84,10 @@ public final class BatchScheduler implements Scheduler { /** * Data Structures for work stealing. 
*/ + private final Set workStealingCandidates = new HashSet<>(); private final Map> stageIdToOutputPartitionSizeMap = new HashMap<>(); private final Map taskIdToProcessedBytes = new HashMap<>(); + private final Map stageIdToWorkStealingExecuted = new HashMap<>(); @Inject private BatchScheduler(final PlanRewriter planRewriter, @@ -117,6 +123,11 @@ public void updatePlan(final PhysicalPlan newPhysicalPlan) { private void updatePlan(final PhysicalPlan newPhysicalPlan, final int maxScheduleAttempt) { planStateManager.updatePlan(newPhysicalPlan, maxScheduleAttempt); + + for (Stage stage : planStateManager.getPhysicalPlan().getStageDAG().getVertices()) { + stageIdToWorkStealingExecuted.putIfAbsent(stage.getId(), false); + } + this.sortedScheduleGroups = newPhysicalPlan.getStageDAG().getVertices().stream() .collect(Collectors.groupingBy(Stage::getScheduleGroup)) .entrySet().stream() @@ -264,6 +275,24 @@ public void onSpeculativeExecutionCheck() { } } + @Override + public void onWorkStealingCheck() { + MutableBoolean isWorkStealingConditionSatisfied = new MutableBoolean(false); + List scheduleGroup = BatchSchedulerUtils + .selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager).orElse(new ArrayList<>()); + List scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList()); + isWorkStealingConditionSatisfied.setValue(checkForWorkStealingBaseConditions(scheduleGroupInId)); + + if (isWorkStealingConditionSatisfied.booleanValue()) { + taskIdToProcessedBytes.clear(); + final List skewedTasks = detectSkew(scheduleGroupInId); + } + + // TODO #469 Split tasks using iterator interface. + + return; + } + @Override public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) { LOG.info("{} added (node: {})", executorRepresenter.getExecutorId(), executorRepresenter.getNodeName()); @@ -310,6 +339,9 @@ public void terminate() { * - We make {@link TaskDispatcher} dispatch only the tasks that are READY. */ private void doSchedule() { + taskIdToProcessedBytes.clear(); + workStealingCandidates.clear(); + final Optional> earliest = BatchSchedulerUtils.selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager); @@ -390,7 +422,7 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId, return false; } - // Methods for work stealing + ///////////////////////////////////////////////////////////////// Methods for work stealing /** * Accumulate the execution result of each stage in Map[STAGE ID, Map[KEY, SIZE]] format. @@ -424,4 +456,170 @@ public void aggregateTaskIdToProcessedBytes(final String taskId, final long processedBytes) { taskIdToProcessedBytes.put(taskId, processedBytes); } + + /** + * Check if work stealing can be conducted. + * + * @param scheduleGroup schedule group. + */ + private boolean checkForWorkStealingBaseConditions(final List scheduleGroup) { + if (scheduleGroup.isEmpty()) { + return false; + } + + /* If the stage of the given schedule group contains sharded tasks, return false */ + if (scheduleGroup.stream().anyMatch(stageId -> stageIdToWorkStealingExecuted.get(stageId).equals(true))) { + return false; + } + + /* If there are idle executors and the number of remaining tasks are smaller than number of executors, + * return true. 
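+     * Note that the remaining tasks are compared against the total number of executor slots.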
+ */ + final boolean executorStatus = executorRegistry.isExecutorSlotAvailable(); + final int totalNumberOfSlots = executorRegistry.getTotalNumberOfExecutorSlots(); + int remainingTasks = 0; + for (String stage : scheduleGroup) { + remainingTasks += planStateManager.getNumberOfTasksRemainingInStage(stage); // ready + executing? + } + return executorStatus && (totalNumberOfSlots > remainingTasks); + } + + private Set getCurrentlyRunningTaskId(final List scheduleGroup) { + final Set onGoingTasksOfSchedulingGroup = new HashSet<>(); + for (String stageId : scheduleGroup) { + onGoingTasksOfSchedulingGroup.addAll(planStateManager.getOngoingTaskIdsInStage(stageId)); + } + return onGoingTasksOfSchedulingGroup; + } + + private Map> getParentStages(final List scheduleGroup) { + Map> parentStages = new HashMap<>(); + for (String stageId : scheduleGroup) { + parentStages.put(stageId, planStateManager.getPhysicalPlan().getStageDAG().getParents(stageId).stream() + .map(Vertex::getId) + .collect(Collectors.toSet())); + } + return parentStages; + } + + private Map getInputSizesOfRunningTaskIds(final Set parentStageIds, + final Set currentlyRunningTaskIds) { + Map currentlyRunningTaskIdsToTotalSize = new HashMap<>(); + for (String parent : parentStageIds) { + Map taskIdxToSize = stageIdToOutputPartitionSizeMap.get(parent); + for (String taskId : currentlyRunningTaskIds) { + if (currentlyRunningTaskIdsToTotalSize.containsKey(taskId)) { + final long existingValue = currentlyRunningTaskIdsToTotalSize.get(taskId); + currentlyRunningTaskIdsToTotalSize.put(taskId, + existingValue + taskIdxToSize.get(RuntimeIdManager.getIndexFromTaskId(taskId))); + } else { + currentlyRunningTaskIdsToTotalSize + .put(taskId, taskIdxToSize.get(RuntimeIdManager.getIndexFromTaskId(taskId))); + } + } + } + return currentlyRunningTaskIdsToTotalSize; + } + + private Map getCurrentExecutionTimeMsOfRunningTasks(final List scheduleGroup) { + final Map taskToExecutionTime = new HashMap<>(); + for (String stageId : scheduleGroup) { + taskToExecutionTime.putAll(planStateManager.getExecutingTaskToRunningTimeMs(stageId)); + } + return taskToExecutionTime; + } + + private List getScheduleGroupByStage(final String stageId) { + return sortedScheduleGroups.get( + planStateManager.getPhysicalPlan().getStageDAG().getVertexById(stageId).getScheduleGroup()) + .stream() + .map(Vertex::getId) + .collect(Collectors.toList()); + } + + /** + * Detect skewed tasks. + * + * @param scheduleGroup current schedule group. + * @return List of skewed tasks. 
+ */ + private List detectSkew(final List scheduleGroup) { + final Map> taskIdToIteratorInformation = new HashMap<>(); + final Map taskIdToInitializationOverhead = new HashMap<>(); + final Map inputSizeOfCandidateTasks = new HashMap<>(); + final Map> parentStageId = getParentStages(scheduleGroup); + + + /* if this schedule group contains a source stage, return empty list */ + if (scheduleGroup.stream().anyMatch(stage -> + planStateManager.getPhysicalPlan().getStageDAG().getParents(stage).isEmpty())) { + return new ArrayList<>(); + } + + workStealingCandidates.addAll(getCurrentlyRunningTaskId(scheduleGroup)); + + /* Gather statistics of work stealing candidates */ + + /* get size of running tasks */ + for (String stage : scheduleGroup) { + inputSizeOfCandidateTasks.putAll( + getInputSizesOfRunningTaskIds(parentStageId.get(stage), workStealingCandidates)); + } + + /* get elapsed time */ + Map taskIdToElapsedTime = getCurrentExecutionTimeMsOfRunningTasks(scheduleGroup); + + /* gather task metric */ + for (String taskId : workStealingCandidates) { + TaskMetric taskMetric = metricStore.getMetricWithId(TaskMetric.class, taskId); + + taskIdToProcessedBytes.put(taskId, taskMetric.getSerializedReadBytes()); + taskIdToIteratorInformation.put(taskId, Pair.of( + taskMetric.getCurrentIteratorIndex(), taskMetric.getTotalIteratorNumber())); + taskIdToInitializationOverhead.put(taskId, taskMetric.getTaskPreparationTime()); + } + + /* If gathered statistic is not sufficient for skew detection, return empty list. */ + if (taskIdToProcessedBytes.size() <= workStealingCandidates.size() / 2) { + return new ArrayList<>(); + } + + /* estimate the remaining time */ + List> estimatedTimeToFinishPerTask = new ArrayList<>(taskIdToElapsedTime.size()); + + for (String taskId : taskIdToProcessedBytes.keySet()) { + // if processed bytes are not available, do not detect skew. + if (taskIdToProcessedBytes.get(taskId) <= 0) { + return new ArrayList<>(); + } + + // if this task is almost finished, ignore it. + Pair iteratorInformation = taskIdToIteratorInformation.get(taskId); + if (iteratorInformation.right() - iteratorInformation.left() <= 2) { + continue; + } + + long timeToFinishExecute = taskIdToElapsedTime.get(taskId) * inputSizeOfCandidateTasks.get(taskId) + / taskIdToProcessedBytes.get(taskId); + + // if the estimated left time is shorter than the initialization overhead, stop! + if (timeToFinishExecute < taskIdToInitializationOverhead.get(taskId) * 2) { + continue; + } + + estimatedTimeToFinishPerTask.add(Pair.of(taskId, timeToFinishExecute)); + } + + // detect skew + Collections.sort(estimatedTimeToFinishPerTask, new Comparator>() { + @Override + public int compare(final Pair o1, final Pair o2) { + return o2.right().compareTo(o1.right()); + } + }); + + return estimatedTimeToFinishPerTask + .subList(0, estimatedTimeToFinishPerTask.size() / 2) + .stream().map(Pair::left).collect(Collectors.toList()); + } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java index cc4661df64..afe30f6e73 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java @@ -86,6 +86,11 @@ void onTaskStateReportFromExecutor(String executorId, */ void onSpeculativeExecutionCheck(); + /** + * Called to check for work stealing condition. 
+ */ + void onWorkStealingCheck(); + /** * To be called when a job should be terminated. * Any clean up code should be implemented in this method. diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java index 42870f609e..5885aa0ada 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java @@ -451,6 +451,12 @@ public void onSpeculativeExecutionCheck() { return; } + @Override + public void onWorkStealingCheck() { + // we don't simulate work stealing yet. + return; + } + @Override public void terminate() { this.taskDispatcher.terminate(); diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java index 24e30bec87..ffa2c586da 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java @@ -149,6 +149,11 @@ public void onSpeculativeExecutionCheck() { throw new UnsupportedOperationException(); } + @Override + public void onWorkStealingCheck() { + throw new UnsupportedOperationException(); + } + @Override public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) { LOG.info("{} added (node: {})", executorRepresenter.getExecutorId(), executorRepresenter.getNodeName()); From 3ba3a5edfcd5abf6dbd2f56f192f3070e1f33981 Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 17:32:37 +0900 Subject: [PATCH 09/17] cleanup skew detection code --- .../master/scheduler/BatchScheduler.java | 47 +++++++++++++++---- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index f41b359f7b..c9305494dc 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -461,6 +461,7 @@ public void aggregateTaskIdToProcessedBytes(final String taskId, * Check if work stealing can be conducted. * * @param scheduleGroup schedule group. + * @return true if work stealing is possible. */ private boolean checkForWorkStealingBaseConditions(final List scheduleGroup) { if (scheduleGroup.isEmpty()) { @@ -473,8 +474,7 @@ private boolean checkForWorkStealingBaseConditions(final List scheduleGr } /* If there are idle executors and the number of remaining tasks are smaller than number of executors, - * return true. - */ + * return true. */ final boolean executorStatus = executorRegistry.isExecutorSlotAvailable(); final int totalNumberOfSlots = executorRegistry.getTotalNumberOfExecutorSlots(); int remainingTasks = 0; @@ -484,7 +484,13 @@ private boolean checkForWorkStealingBaseConditions(final List scheduleGr return executorStatus && (totalNumberOfSlots > remainingTasks); } - private Set getCurrentlyRunningTaskId(final List scheduleGroup) { + /** + * Get the ids of tasks in execution. + * + * @param scheduleGroup schedule group. + * @return ids of running tasks. 
+ */ + private Set getRunningTaskId(final List scheduleGroup) { final Set onGoingTasksOfSchedulingGroup = new HashSet<>(); for (String stageId : scheduleGroup) { onGoingTasksOfSchedulingGroup.addAll(planStateManager.getOngoingTaskIdsInStage(stageId)); @@ -492,22 +498,36 @@ private Set getCurrentlyRunningTaskId(final List scheduleGroup) return onGoingTasksOfSchedulingGroup; } + /** + * Get parent stages of given schedule group. + * + * @param scheduleGroup schedule group. + * @return Map of stage and set of its parent. + */ private Map> getParentStages(final List scheduleGroup) { Map> parentStages = new HashMap<>(); for (String stageId : scheduleGroup) { - parentStages.put(stageId, planStateManager.getPhysicalPlan().getStageDAG().getParents(stageId).stream() + parentStages.put(stageId, planStateManager.getPhysicalPlan().getStageDAG().getParents(stageId) + .stream() .map(Vertex::getId) .collect(Collectors.toSet())); } return parentStages; } - private Map getInputSizesOfRunningTaskIds(final Set parentStageIds, - final Set currentlyRunningTaskIds) { + /** + * Get the input size of running tasks. + * + * @param parentStageIds id of parent stages. + * @param runningTaskIds id of running tasks. + * @return Map of task id to its input size. + */ + private Map getInputSizeOfRunningTasks(final Set parentStageIds, + final Set runningTaskIds) { Map currentlyRunningTaskIdsToTotalSize = new HashMap<>(); for (String parent : parentStageIds) { Map taskIdxToSize = stageIdToOutputPartitionSizeMap.get(parent); - for (String taskId : currentlyRunningTaskIds) { + for (String taskId : runningTaskIds) { if (currentlyRunningTaskIdsToTotalSize.containsKey(taskId)) { final long existingValue = currentlyRunningTaskIdsToTotalSize.get(taskId); currentlyRunningTaskIdsToTotalSize.put(taskId, @@ -521,6 +541,13 @@ private Map getInputSizesOfRunningTaskIds(final Set parent return currentlyRunningTaskIdsToTotalSize; } + /** + * get current execution time of running tasks in millisecond. + * Note that this is the execution time of incomplete tasks. + * + * @param scheduleGroup schedule group. + * @return Map of task id to its execution time. 
+ */ private Map getCurrentExecutionTimeMsOfRunningTasks(final List scheduleGroup) { final Map taskToExecutionTime = new HashMap<>(); for (String stageId : scheduleGroup) { @@ -556,14 +583,14 @@ private List detectSkew(final List scheduleGroup) { return new ArrayList<>(); } - workStealingCandidates.addAll(getCurrentlyRunningTaskId(scheduleGroup)); + workStealingCandidates.addAll(getRunningTaskId(scheduleGroup)); /* Gather statistics of work stealing candidates */ /* get size of running tasks */ for (String stage : scheduleGroup) { inputSizeOfCandidateTasks.putAll( - getInputSizesOfRunningTaskIds(parentStageId.get(stage), workStealingCandidates)); + getInputSizeOfRunningTasks(parentStageId.get(stage), workStealingCandidates)); } /* get elapsed time */ @@ -610,7 +637,7 @@ private List detectSkew(final List scheduleGroup) { estimatedTimeToFinishPerTask.add(Pair.of(taskId, timeToFinishExecute)); } - // detect skew + /* detect skew */ Collections.sort(estimatedTimeToFinishPerTask, new Comparator>() { @Override public int compare(final Pair o1, final Pair o2) { From 94cdf152bca3cf642f4617868857ebc936e78763 Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 17:36:35 +0900 Subject: [PATCH 10/17] get executor vacancy information from executor registry --- .../master/resource/DefaultExecutorRepresenter.java | 5 +++++ .../nemo/runtime/master/resource/ExecutorRepresenter.java | 5 +++++ .../nemo/runtime/master/scheduler/ExecutorRegistry.java | 8 ++++++++ 3 files changed, 18 insertions(+) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java index 16c9a70db9..ebec804132 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java @@ -170,6 +170,11 @@ public void onTaskExecutionFailed(final String taskId) { failedTasks.add(failedTask); } + @Override + public boolean isExecutorSlotAvailable() { + return getExecutorCapacity() - getNumOfRunningTasks() > 0; + } + /** * @return how many Tasks can this executor simultaneously run */ diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java index 26649a81db..dcfb53eb1c 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java @@ -108,4 +108,9 @@ public interface ExecutorRepresenter { * @param taskId id of the Task */ void onTaskExecutionFailed(String taskId); + + /** + * @return true if this executor has an available slot. 
+ */ + boolean isExecutorSlotAvailable(); } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java index 11d40c73b8..5cead6e290 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java @@ -126,6 +126,14 @@ private Set getRunningExecutors() { .collect(Collectors.toSet()); } + public int getTotalNumberOfExecutorSlots() { + return getRunningExecutors().stream().mapToInt(ExecutorRepresenter::getExecutorCapacity).sum(); + } + + public boolean isExecutorSlotAvailable() { + return getRunningExecutors().stream().anyMatch(ExecutorRepresenter::isExecutorSlotAvailable); + } + @Override public String toString() { return executors.toString(); From d269ca144ec25bca26272f846ed0aceb21838385 Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 17:52:18 +0900 Subject: [PATCH 11/17] add helper methods in plan state manager --- .../nemo/runtime/master/PlanStateManager.java | 71 ++++++++++++++++--- 1 file changed, 60 insertions(+), 11 deletions(-) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java index 53cab57810..65b5306f5b 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java @@ -85,6 +85,10 @@ public final class PlanStateManager { private final Map> stageIdToCompletedTaskTimeMsList = new HashMap<>(); private final Map> stageIdToTaskIndexToNumOfClones = new HashMap<>(); + /** + * Used for work stealing. + */ + private final Map>> stageIdToTaskIdxToWSAttemptStates = new HashMap<>(); /** * Represents the plan to manage. */ @@ -127,7 +131,7 @@ public static PlanStateManager newInstance(final String dagDirectory) { } /** - * @param metricStore set the metric store of the paln state manager. + * @param metricStore set the metric store of the plan state manager. 
*/ public void setMetricStore(final MetricStore metricStore) { this.metricStore = metricStore; @@ -326,16 +330,8 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState // Log not-yet-completed tasks for us humans to track progress final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); - final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.values().stream() - .filter(attempts -> { - final List states = attempts - .stream() - .map(state -> (TaskState.State) state.getStateMachine().getCurrentState()) - .collect(Collectors.toList()); - return states.stream().anyMatch(curState -> curState.equals(TaskState.State.ON_HOLD)) // one of them is ON_HOLD - || states.stream().anyMatch(curState -> curState.equals(TaskState.State.COMPLETE)); // one of them is COMPLETE - }) - .count(); + final long numOfCompletedTaskIndicesInThisStage = getNumberOfCompletedTasksInStage(taskStatesOfThisStage); + if (newTaskState.equals(TaskState.State.COMPLETE)) { LOG.info("{} completed: {} Task(s) out of {} are remaining in this stage", taskId, taskStatesOfThisStage.size() - numOfCompletedTaskIndicesInThisStage, taskStatesOfThisStage.size()); @@ -577,6 +573,59 @@ private List getPeerAttemptsForTheSameTaskIndex(final String ta .collect(Collectors.toList()); } + /** + * Get number of remaining tasks of the stage. + * + * @param stageId stage id. + * @return number of remaining tasks. + */ + public int getNumberOfTasksRemainingInStage(final String stageId) { + final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); + final Map> wsTaskStatesOfThisStage = stageIdToTaskIdxToWSAttemptStates + .getOrDefault(stageId, new HashMap<>()); + final long numOfCompletedTaskIndices = getNumberOfCompletedTasksInStage(taskStatesOfThisStage); + if (wsTaskStatesOfThisStage.isEmpty()) { + return (int) (taskStatesOfThisStage.size() - numOfCompletedTaskIndices); + } else { + final long numOfCompletedWorkStealingTaskIndices = getNumberOfCompletedTasksInStage(wsTaskStatesOfThisStage); + return (int) (taskStatesOfThisStage.size() - numOfCompletedTaskIndices + + wsTaskStatesOfThisStage.size() - numOfCompletedWorkStealingTaskIndices); + } + } + + /** + * Get tasks which are currently being executed. + * + * @param stageId stage id. + * @return Set of tasksIds in execution. + */ + public Set getOngoingTaskIdsInStage(final String stageId) { + final Map> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId); + final Set onGoingTaskIds = new HashSet<>(); + for (final int taskIndex : taskIdToState.keySet()) { + final List attemptStates = taskIdToState.get(taskIndex); + for (int attempt = 0; attempt < attemptStates.size(); attempt++) { + if (attemptStates.get(attempt).getStateMachine().getCurrentState().equals(TaskState.State.EXECUTING)) { + onGoingTaskIds.add(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt)); + } + } + } + return onGoingTaskIds; + } + + private long getNumberOfCompletedTasksInStage(final Map> taskIdxToState) { + return taskIdxToState.values().stream() + .filter(attempts -> { + final List states = attempts + .stream() + .map(state -> (TaskState.State) state.getStateMachine().getCurrentState()) + .collect(Collectors.toList()); + return states.stream().anyMatch(curState -> curState.equals(TaskState.State.ON_HOLD)) + || states.stream().anyMatch(curState -> curState.equals(TaskState.State.COMPLETE)); + }) + .count(); + } + /** * @return the physical plan. 
   */
 }

From e23b012d00c861df2839db1c4d35304a2c6cdb5b Mon Sep 17 00:00:00 2001
From: hwarim
Date: Fri, 16 Jul 2021 17:58:01 +0900
Subject: [PATCH 12/17] add task metrics needed for determining work stealing condition

---
 .../runtime/common/metric/TaskMetric.java     | 36 +++++++++++++++++++
 .../master/scheduler/BatchScheduler.java      |  7 ++--
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java
index 531e715a7b..7d98140cc3 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java
@@ -49,6 +49,9 @@ public class TaskMetric implements StateMetric {
   private long shuffleReadTime = -1;
   private long shuffleWriteBytes = -1;
   private long shuffleWriteTime = -1;
+  private int currentIteratorIndex = -1;
+  private int totalIteratorNumber = -1;
+  private long taskPreparationTime = -1;

   private static final Logger LOG = LoggerFactory.getLogger(TaskMetric.class.getName());

@@ -252,6 +255,30 @@ private void setShuffleWriteTime(final long shuffleWriteTime) {
     this.shuffleWriteTime = shuffleWriteTime;
   }

+  public final int getCurrentIteratorIndex() {
+    return this.currentIteratorIndex;
+  }
+
+  private void setCurrentIteratorIndex(final int currentIteratorIndex) {
+    this.currentIteratorIndex = currentIteratorIndex;
+  }
+
+  public final int getTotalIteratorNumber() {
+    return this.totalIteratorNumber;
+  }
+
+  private void setTotalIteratorNumber(final int totalIteratorNumber) {
+    this.totalIteratorNumber = totalIteratorNumber;
+  }
+
+  public final long getTaskPreparationTime() {
+    return this.taskPreparationTime;
+  }
+
+  private void setTaskPreparationTime(final long taskPreparationTime) {
+    this.taskPreparationTime = taskPreparationTime;
+  }
+
   @Override
   public final String getId() {
     return id;
@@ -317,6 +344,15 @@ public final boolean processMetricMessage(final String metricField, final byte[]
       case "shuffleWriteTime":
         setShuffleWriteTime(SerializationUtils.deserialize(metricValue));
         break;
+      case "currentIteratorIndex":
+        setCurrentIteratorIndex(SerializationUtils.deserialize(metricValue));
+        break;
+      case "totalIteratorNumber":
+        setTotalIteratorNumber(SerializationUtils.deserialize(metricValue));
+        break;
+      case "taskPreparationTime":
+        setTaskPreparationTime(SerializationUtils.deserialize(metricValue));
+        break;
       default:
         LOG.warn("metricField {} is not supported.", metricField);
         return false;
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
index c9305494dc..3caa4118ce 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -283,10 +283,11 @@ public void onWorkStealingCheck() {
     MutableBoolean isWorkStealingConditionSatisfied = new MutableBoolean(false);
     List scheduleGroup = BatchSchedulerUtils
       .selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager).orElse(new ArrayList<>());
     List scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList());

     isWorkStealingConditionSatisfied.setValue(checkForWorkStealingBaseConditions(scheduleGroupInId));
-    if (isWorkStealingConditionSatisfied.booleanValue()) {
-      taskIdToProcessedBytes.clear();
-      final List skewedTasks = detectSkew(scheduleGroupInId);
+    if (!isWorkStealingConditionSatisfied.booleanValue()) {
+      return;
     }
+    taskIdToProcessedBytes.clear();
+    final List skewedTasks = 
detectSkew(scheduleGroupInId); // TODO #469 Split tasks using iterator interface. From 73825621a3fbf26ce873ccd78e002d638f9512ac Mon Sep 17 00:00:00 2001 From: hwarim Date: Tue, 20 Jul 2021 15:20:59 +0900 Subject: [PATCH 13/17] edit plan state manager --- .../nemo/runtime/master/PlanStateManager.java | 51 ++++++++++++++++--- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java index 65b5306f5b..5e4c1db022 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java @@ -160,6 +160,21 @@ public synchronized void updatePlan(final PhysicalPlan physicalPlanToUpdate, initializeStates(); } + /** + * Add work stealing tasks to the plan. + * @param workStealingTasks work stealing tasks. + */ + public synchronized void addWorkStealingTasks(final Set workStealingTasks) { + for (String taskId : workStealingTasks) { + final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); + final int taskIdx = RuntimeIdManager.getIndexFromTaskId(taskId); + stageIdToTaskIdxToWSAttemptStates.putIfAbsent(stageId, new HashMap<>()); + List attemptStatesForThisTask = new ArrayList<>(); + attemptStatesForThisTask.add(new TaskState()); + stageIdToTaskIdxToWSAttemptStates.get(stageId).putIfAbsent(taskIdx, attemptStatesForThisTask); + } + } + /** * Initializes the states for the plan/stages/tasks for this plan. * TODO #182: Consider reshaping in run-time optimization. At now, we only consider plan appending. @@ -330,7 +345,10 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState // Log not-yet-completed tasks for us humans to track progress final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); - final long numOfCompletedTaskIndicesInThisStage = getNumberOfCompletedTasksInStage(taskStatesOfThisStage); + final Map> wsTaskStatesOfThisStage = + stageIdToTaskIdxToWSAttemptStates.getOrDefault(stageId, new HashMap<>()); + final long numOfCompletedTaskIndicesInThisStage = getNumberOfCompletedTasksInStage(taskStatesOfThisStage) + + getNumberOfCompletedTasksInStage(wsTaskStatesOfThisStage); if (newTaskState.equals(TaskState.State.COMPLETE)) { LOG.info("{} completed: {} Task(s) out of {} are remaining in this stage", @@ -360,9 +378,18 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState // COMPLETE stage case COMPLETE: case ON_HOLD: - if (numOfCompletedTaskIndicesInThisStage - == physicalPlan.getStageDAG().getVertexById(stageId).getTaskIndices().size()) { - onStageStateChanged(stageId, StageState.State.COMPLETE); + // if work stealing enabled + if (!physicalPlan.getStageDAG().getVertexById(stageId).getWorkStealingTaskIds().isEmpty()) { + if (numOfCompletedTaskIndicesInThisStage + == physicalPlan.getStageDAG().getVertexById(stageId).getTaskIndices().size() + + physicalPlan.getStageDAG().getVertexById(stageId).getWorkStealingTaskIds().size()) { + onStageStateChanged(stageId, StageState.State.COMPLETE); + } + } else { + if (numOfCompletedTaskIndicesInThisStage + == physicalPlan.getStageDAG().getVertexById(stageId).getTaskIndices().size()) { + onStageStateChanged(stageId, StageState.State.COMPLETE); + } } break; @@ -546,10 +573,18 @@ private Map getTaskAttemptIdsToItsState(final 
String st } private TaskState getTaskStateHelper(final String taskId) { - return stageIdToTaskIdxToAttemptStates - .get(RuntimeIdManager.getStageIdFromTaskId(taskId)) - .get(RuntimeIdManager.getIndexFromTaskId(taskId)) - .get(RuntimeIdManager.getAttemptFromTaskId(taskId)); + final boolean isWorkStealingTask = taskId.split("-")[2].equals("*"); + if (isWorkStealingTask) { + return stageIdToTaskIdxToWSAttemptStates + .get(RuntimeIdManager.getStageIdFromTaskId(taskId)) + .get(RuntimeIdManager.getIndexFromTaskId(taskId)) + .get(RuntimeIdManager.getAttemptFromTaskId(taskId)); + } else { + return stageIdToTaskIdxToAttemptStates + .get(RuntimeIdManager.getStageIdFromTaskId(taskId)) + .get(RuntimeIdManager.getIndexFromTaskId(taskId)) + .get(RuntimeIdManager.getAttemptFromTaskId(taskId)); + } } private boolean isTaskNotDone(final TaskState taskState) { From 9f94638d209afa8a0cc87e4f364cdf38f376635a Mon Sep 17 00:00:00 2001 From: hwarim Date: Tue, 20 Jul 2021 15:22:55 +0900 Subject: [PATCH 14/17] add work stealing information in stage and task --- .../nemo/runtime/common/RuntimeIdManager.java | 9 +++++ .../nemo/runtime/common/plan/Stage.java | 17 +++++++-- .../apache/nemo/runtime/common/plan/Task.java | 37 ++++++++++++++++++- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java index 255cca70f5..0c4f74231c 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java @@ -74,6 +74,15 @@ public static String generateTaskId(final String stageId, final int index, final return stageId + SPLITTER + index + SPLITTER + attempt; } + /** + * Generates the ID of a task created by Work Stealing. + * @param taskId the ID of original task. + * @return the generated ID. + */ + public static String generateWorkStealingTaskId(final String taskId) { + return getStageIdFromTaskId(taskId) + SPLITTER + getIndexFromTaskId(taskId) + SPLITTER + "*"; + } + /** * Generates the ID for executor. * diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java index a7f472c0da..e60cf1648f 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java @@ -33,15 +33,14 @@ import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty; import java.io.Serializable; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; /** * Stage. */ public final class Stage extends Vertex { private final List taskIndices; + private final Set workStealingTaskIds = new HashSet<>(); private final DAG> irDag; private final byte[] serializedIRDag; private final List> vertexIdToReadables; @@ -93,6 +92,18 @@ public List getTaskIndices() { return taskIndices; } + /** + * Set IDs for work stealing. + * @param workStealingTaskIds IDs of work stealer tasks. + */ + public void setWorkStealingTaskIds(final Set workStealingTaskIds) { + this.workStealingTaskIds.addAll(workStealingTaskIds); + } + + public Set getWorkStealingTaskIds() { + return this.workStealingTaskIds; + } + /** * @return the parallelism. 
*/ diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java index 719075b456..5d827e4d07 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; /** * A Task (attempt) is a self-contained executable that can be executed on a machine. @@ -40,8 +41,13 @@ public final class Task implements Serializable { private final byte[] serializedIRDag; private final Map irVertexIdToReadable; + /* For work stealing */ + private final AtomicInteger iteratorStartIndex; + private final AtomicInteger iteratorEndIndex; + /** - * Constructor. + * Default Constructor. + * It initializes iteratorStartIndex as 0 and iteratorEndIndex as Integer.MAX_VALUE. * * @param planId the id of the physical plan. * @param taskId the ID of this task attempt. @@ -58,6 +64,33 @@ public Task(final String planId, final List taskIncomingEdges, final List taskOutgoingEdges, final Map irVertexIdToReadable) { + this(planId, taskId, executionProperties, serializedIRDag, taskIncomingEdges, taskOutgoingEdges, + irVertexIdToReadable, new AtomicInteger(0), new AtomicInteger(Integer.MAX_VALUE)); + } + + /** + * Constructor with iterator information. + * This constructor is used when creating work stealer tasks. + * + * @param planId the id of the physical plan. + * @param taskId the ID of this task attempt. + * @param executionProperties {@link VertexExecutionProperty} map for the corresponding stage + * @param serializedIRDag the serialized DAG of the task. + * @param taskIncomingEdges the incoming edges of the task. + * @param taskOutgoingEdges the outgoing edges of the task. + * @param irVertexIdToReadable the map between IRVertex id to readable. + * @param iteratorStartIndex starting index of iterator. + * @param iteratorEndIndex ending index of iterator. 
+ */ + public Task(final String planId, + final String taskId, + final ExecutionPropertyMap executionProperties, + final byte[] serializedIRDag, + final List taskIncomingEdges, + final List taskOutgoingEdges, + final Map irVertexIdToReadable, + final AtomicInteger iteratorStartIndex, + final AtomicInteger iteratorEndIndex) { this.planId = planId; this.taskId = taskId; this.executionProperties = executionProperties; @@ -65,6 +98,8 @@ public Task(final String planId, this.taskIncomingEdges = taskIncomingEdges; this.taskOutgoingEdges = taskOutgoingEdges; this.irVertexIdToReadable = irVertexIdToReadable; + this.iteratorStartIndex = iteratorStartIndex; + this.iteratorEndIndex = iteratorEndIndex; } /** From e06b7705871ecab1a18ae8ee7842c840a76748ea Mon Sep 17 00:00:00 2001 From: hwarim Date: Tue, 20 Jul 2021 18:00:51 +0900 Subject: [PATCH 15/17] split skewed tasks and make new tasks to allocate --- .../src/main/proto/ControlMessage.proto | 8 +- .../nemo/runtime/master/PlanStateManager.java | 4 +- .../master/scheduler/BatchScheduler.java | 141 +++++++++++++++++- 3 files changed, 146 insertions(+), 7 deletions(-) diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index 53adea3d3a..bd5ed2d30f 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -87,7 +87,8 @@ enum MessageType { RequestPipeLoc = 14; PipeLocInfo = 15; ParentTaskDataCollected = 16; - CurrentlyProcessedBytesCollected = 17; + CurrentlyProcessedBytesCollected = 17; + SendWorkStealingResult = 18; } message Message { @@ -111,6 +112,7 @@ message Message { optional PipeLocationInfoMessage pipeLocInfoMsg = 18; optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19; optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20; + optional WorkStealingResultMsg sendWorkStealingResult = 22; } // Messages from Master to Executors @@ -270,3 +272,7 @@ message CurrentlyProcessedBytesCollectMsg { required string taskId = 1; required int64 processedDataBytes = 2; } + +message WorkStealingResultMsg { + required bytes workStealingResult = 1; +} diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java index 5e4c1db022..b8f8f7b335 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java @@ -347,8 +347,8 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); final Map> wsTaskStatesOfThisStage = stageIdToTaskIdxToWSAttemptStates.getOrDefault(stageId, new HashMap<>()); - final long numOfCompletedTaskIndicesInThisStage = getNumberOfCompletedTasksInStage(taskStatesOfThisStage) + - getNumberOfCompletedTasksInStage(wsTaskStatesOfThisStage); + final long numOfCompletedTaskIndicesInThisStage = getNumberOfCompletedTasksInStage(taskStatesOfThisStage) + + getNumberOfCompletedTasksInStage(wsTaskStatesOfThisStage); if (newTaskState.equals(TaskState.State.COMPLETE)) { LOG.info("{} completed: {} Task(s) out of {} are remaining in this stage", diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 
3caa4118ce..288d3acfc9 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -18,13 +18,18 @@ */ package org.apache.nemo.runtime.master.scheduler; +import com.google.protobuf.ByteString; import org.apache.commons.lang.mutable.MutableBoolean; +import org.apache.commons.lang3.SerializationUtils; import org.apache.nemo.common.Pair; import org.apache.nemo.common.dag.Vertex; import org.apache.nemo.common.exception.UnknownExecutionStateException; import org.apache.nemo.common.exception.UnrecoverableFailureException; +import org.apache.nemo.common.ir.Readable; import org.apache.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty; import org.apache.nemo.runtime.common.RuntimeIdManager; +import org.apache.nemo.runtime.common.comm.ControlMessage; +import org.apache.nemo.runtime.common.message.MessageEnvironment; import org.apache.nemo.runtime.common.metric.TaskMetric; import org.apache.nemo.runtime.common.plan.*; import org.apache.nemo.runtime.common.state.StageState; @@ -41,7 +46,9 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import javax.inject.Inject; +import java.io.Serializable; import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; /** @@ -277,10 +284,12 @@ public void onSpeculativeExecutionCheck() { @Override public void onWorkStealingCheck() { - MutableBoolean isWorkStealingConditionSatisfied = new MutableBoolean(false); - List scheduleGroup = BatchSchedulerUtils + final List scheduleGroup = BatchSchedulerUtils .selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager).orElse(new ArrayList<>()); - List scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList()); + final List scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList()); + final Map> wsResult = new HashMap<>(); + final MutableBoolean isWorkStealingConditionSatisfied = new MutableBoolean(false); + isWorkStealingConditionSatisfied.setValue(checkForWorkStealingBaseConditions(scheduleGroupInId)); if (!isWorkStealingConditionSatisfied.booleanValue()) { @@ -290,7 +299,29 @@ public void onWorkStealingCheck() { final List skewedTasks = detectSkew(scheduleGroupInId); // TODO #469 Split tasks using iterator interface. 
+ if (skewedTasks.isEmpty()) { + return; + } + + final Map> taskToSplitIteratorInfo = splitIterator(skewedTasks); + final List wsTasks = generateWorkStealingTasks(scheduleGroup, skewedTasks, taskToSplitIteratorInfo); + + // accumulate the Victim tasks and non skewed tasks result + for (String taskId : workStealingCandidates) { + if (skewedTasks.contains(taskId)) { // this is for skewed task + Pair iteratorInfo = taskToSplitIteratorInfo.get(taskId); + wsResult.put(taskId, Pair.of(0, iteratorInfo.left())); + } else { // this is for non skewed tasks + wsResult.put(taskId, Pair.of(0, Integer.MAX_VALUE)); + } + } + + /* notify the updated information to executors */ + sendWorkStealingResultToExecutor(wsResult); + // schedule new tasks + pendingTaskCollectionPointer.setToOverwrite(wsTasks); + taskDispatcher.onNewPendingTaskCollectionAvailable(); return; } @@ -587,7 +618,6 @@ private List detectSkew(final List scheduleGroup) { workStealingCandidates.addAll(getRunningTaskId(scheduleGroup)); /* Gather statistics of work stealing candidates */ - /* get size of running tasks */ for (String stage : scheduleGroup) { inputSizeOfCandidateTasks.putAll( @@ -650,4 +680,107 @@ public int compare(final Pair o1, final Pair o2) { .subList(0, estimatedTimeToFinishPerTask.size() / 2) .stream().map(Pair::left).collect(Collectors.toList()); } + + private Map> splitIterator(final List skewedTasks) { + final Map> taskToIteratorInfo = new HashMap<>(); + + for (String taskId : skewedTasks) { + TaskMetric taskMetric = metricStore.getMetricWithId(TaskMetric.class, taskId); + int currIterIdx = taskMetric.getCurrentIteratorIndex(); + int totalIterIndex = taskMetric.getTotalIteratorNumber(); + int changePoint = (int) Math.floor((totalIterIndex + currIterIdx) / 2 + 1); + + taskToIteratorInfo.put(taskId, Pair.of(changePoint, totalIterIndex)); + } + + return taskToIteratorInfo; + } + + private List generateWorkStealingTasks(final List scheduleGroup, + final List skewedTasks, + final Map> taskToIteratorInfo) { + /* Split the skewed tasks */ + final List tasksToSchedule = new ArrayList<>(skewedTasks.size()); + + + // tasks are generated in "stage" based : loop on stages, not schedule group + for (Stage stageToSchedule : scheduleGroup) { + String stageId = stageToSchedule.getId(); + + // make new task ids and store that information in corresponding stage and plan state manager + // for now, id logic for robber tasks are as follows: + // - same stage id (obvious) + // - same index number (need to fetch the same data as the victim task) + // - attempt number is replaced with "*", similar withe the block wildcard id. 
+ + //generate the robber tasks' id + final Set newTaskIds = skewedTasks.stream() + .filter(taskId -> taskId.contains(stageId)) + .map(taskId -> RuntimeIdManager.generateWorkStealingTaskId(taskId)) + .collect(Collectors.toSet()); + + if (newTaskIds.isEmpty()) { + continue; + } + + // update the work stealing tasks in Stage and PlanStateManager + planStateManager.getPhysicalPlan().getStageDAG() + .getVertexById(stageId).setWorkStealingTaskIds(newTaskIds); + planStateManager.addWorkStealingTasks(newTaskIds); + + // house keeping stuffs needed for initializing tasks + // create and return Robber tasks + final List stageIncomingEdges = + planStateManager.getPhysicalPlan().getStageDAG().getIncomingEdgesOf(stageToSchedule.getId()); + final List stageOutgoingEdges = + planStateManager.getPhysicalPlan().getStageDAG().getOutgoingEdgesOf(stageToSchedule.getId()); + final List> vertexIdToReadable = stageToSchedule.getVertexIdToReadables(); + + skewedTasks.forEach(taskId -> { + final Set blockIds = BatchSchedulerUtils.getOutputBlockIds(planStateManager, taskId); + blockManagerMaster.onProducerTaskScheduled(taskId, blockIds); + final int taskIdx = RuntimeIdManager.getIndexFromTaskId(taskId); + + int startIterIdx = taskToIteratorInfo.get(taskId).left(); + int endIterIndex = taskToIteratorInfo.get(taskId).right(); + + tasksToSchedule.add(new Task( + planStateManager.getPhysicalPlan().getPlanId(), + RuntimeIdManager.generateWorkStealingTaskId(taskId), + stageToSchedule.getExecutionProperties(), + stageToSchedule.getSerializedIRDAG(), + stageIncomingEdges, + stageOutgoingEdges, + vertexIdToReadable.get(taskIdx), + new AtomicInteger(startIterIdx), + new AtomicInteger(endIterIndex))); + }); + + + // do work stealing for only once : this is because of the index based task state tracking system + // Need to be handled in the near future! + stageIdToWorkStealingExecuted.put(stageId, true); + } + + return tasksToSchedule; + } + + /** + * Send the accumulated iterator information (work stealing result) to executor. + * @param result result to send. 
+ */ + private void sendWorkStealingResultToExecutor(final Map> result) { + // driver sends message to executors + // ask executors to flush metric + final byte[] serialized = SerializationUtils.serialize((Serializable) result); + ControlMessage.Message message = ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.SendWorkStealingResult) + .setSendWorkStealingResult(ControlMessage.WorkStealingResultMsg.newBuilder() + .setWorkStealingResult(ByteString.copyFrom(serialized)) + .build()) + .build(); + executorRegistry.viewExecutors(executors -> executors.forEach(executor -> executor.sendControlMessage(message))); + } } From 05fcbec4dcafa197a9eb343e206236eeb3f64b1b Mon Sep 17 00:00:00 2001 From: hwarim Date: Tue, 20 Jul 2021 18:21:52 +0900 Subject: [PATCH 16/17] add comments --- .../master/scheduler/BatchScheduler.java | 59 ++++++++++++------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 288d3acfc9..d30ea42394 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -287,26 +287,29 @@ public void onWorkStealingCheck() { final List scheduleGroup = BatchSchedulerUtils .selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager).orElse(new ArrayList<>()); final List scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList()); - final Map> wsResult = new HashMap<>(); final MutableBoolean isWorkStealingConditionSatisfied = new MutableBoolean(false); + final Map> wsResult = new HashMap<>(); + /* check if work stealing is possible. If not, return */ isWorkStealingConditionSatisfied.setValue(checkForWorkStealingBaseConditions(scheduleGroupInId)); - if (!isWorkStealingConditionSatisfied.booleanValue()) { return; } + + /* detect skewed tasks */ taskIdToProcessedBytes.clear(); final List skewedTasks = detectSkew(scheduleGroupInId); - // TODO #469 Split tasks using iterator interface. + /* if there are no skewed tasks, return */ if (skewedTasks.isEmpty()) { return; } + /* generate work stealing tasks */ final Map> taskToSplitIteratorInfo = splitIterator(skewedTasks); final List wsTasks = generateWorkStealingTasks(scheduleGroup, skewedTasks, taskToSplitIteratorInfo); - // accumulate the Victim tasks and non skewed tasks result + /* accumulate result */ for (String taskId : workStealingCandidates) { if (skewedTasks.contains(taskId)) { // this is for skewed task Pair iteratorInfo = taskToSplitIteratorInfo.get(taskId); @@ -319,10 +322,9 @@ public void onWorkStealingCheck() { /* notify the updated information to executors */ sendWorkStealingResultToExecutor(wsResult); - // schedule new tasks + /* schedule new tasks */ pendingTaskCollectionPointer.setToOverwrite(wsTasks); taskDispatcher.onNewPendingTaskCollectionAvailable(); - return; } @Override @@ -676,11 +678,20 @@ public int compare(final Pair o1, final Pair o2) { } }); + /* return only longer half */ return estimatedTimeToFinishPerTask .subList(0, estimatedTimeToFinishPerTask.size() / 2) .stream().map(Pair::left).collect(Collectors.toList()); } + /** + * Calculate the iterator range of work stealing tasks. 
+ * Given a skewed task, it calculates the iterator range which work stealing task will take from the task. + * + * @param skewedTasks List of skewed (original) tasks. + * @return Map of skewed task ID to iterator information. + * pair.left() is the starting index (inclusive) and pair.right() ending index (exclusive). + */ private Map> splitIterator(final List skewedTasks) { final Map> taskToIteratorInfo = new HashMap<>(); @@ -696,40 +707,47 @@ private Map> splitIterator(final List ske return taskToIteratorInfo; } + /** + * Generate work stealing tasks. + * + * @param scheduleGroup schedule group. + * @param skewedTasks List of skewed (original) tasks. + * @param taskToIteratorInfo Map of work stealing task ID to its iterator range information. + * @return List of work stealer tasks. + */ private List generateWorkStealingTasks(final List scheduleGroup, final List skewedTasks, final Map> taskToIteratorInfo) { - /* Split the skewed tasks */ final List tasksToSchedule = new ArrayList<>(skewedTasks.size()); - - // tasks are generated in "stage" based : loop on stages, not schedule group + /* tasks are generated in stage based: loop by stage, not schedule group */ for (Stage stageToSchedule : scheduleGroup) { String stageId = stageToSchedule.getId(); - // make new task ids and store that information in corresponding stage and plan state manager - // for now, id logic for robber tasks are as follows: - // - same stage id (obvious) - // - same index number (need to fetch the same data as the victim task) - // - attempt number is replaced with "*", similar withe the block wildcard id. + /* make new task ids and store that information in stage and plan state manager. + * for now, id logic for work stealing tasks is as follows: + * - same stage id + * - same index number + * - attempt number is replaced with "*", similar with the block wildcard id. + */ - //generate the robber tasks' id + /* generate work stealing task id */ final Set newTaskIds = skewedTasks.stream() .filter(taskId -> taskId.contains(stageId)) - .map(taskId -> RuntimeIdManager.generateWorkStealingTaskId(taskId)) + .map(RuntimeIdManager::generateWorkStealingTaskId) .collect(Collectors.toSet()); + /* if there are no work stealing tasks in this stage, pass */ if (newTaskIds.isEmpty()) { continue; } - // update the work stealing tasks in Stage and PlanStateManager + /* update the work stealing tasks in Stage and PlanStateManager */ planStateManager.getPhysicalPlan().getStageDAG() .getVertexById(stageId).setWorkStealingTaskIds(newTaskIds); planStateManager.addWorkStealingTasks(newTaskIds); - // house keeping stuffs needed for initializing tasks - // create and return Robber tasks + /* create work stealing task */ final List stageIncomingEdges = planStateManager.getPhysicalPlan().getStageDAG().getIncomingEdgesOf(stageToSchedule.getId()); final List stageOutgoingEdges = @@ -767,11 +785,10 @@ private List generateWorkStealingTasks(final List scheduleGroup, /** * Send the accumulated iterator information (work stealing result) to executor. + * * @param result result to send. 
*/ private void sendWorkStealingResultToExecutor(final Map> result) { - // driver sends message to executors - // ask executors to flush metric final byte[] serialized = SerializationUtils.serialize((Serializable) result); ControlMessage.Message message = ControlMessage.Message.newBuilder() .setId(RuntimeIdManager.generateMessageId()) From 436fed09b79d4e01f45225388daa3e4cab0f85d6 Mon Sep 17 00:00:00 2001 From: hwarim Date: Wed, 28 Jul 2021 13:36:41 +0900 Subject: [PATCH 17/17] halting and resuming executors --- .../apache/nemo/runtime/common/plan/Task.java | 8 ++ .../src/main/proto/ControlMessage.proto | 12 ++- .../nemo/runtime/executor/Executor.java | 78 ++++++++++++++++++- .../runtime/executor/task/DataFetcher.java | 4 +- .../MultiThreadParentTaskDataFetcher.java | 4 +- .../executor/task/ParentTaskDataFetcher.java | 32 +++++++- .../task/SourceVertexDataFetcher.java | 5 +- .../runtime/executor/task/TaskExecutor.java | 20 ++++- .../task/ParentTaskDataFetcherTest.java | 5 +- .../executor/task/TaskExecutorTest.java | 3 +- .../master/scheduler/BatchScheduler.java | 18 +++++ 11 files changed, 173 insertions(+), 16 deletions(-) diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java index 5d827e4d07..4172715335 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java @@ -177,6 +177,14 @@ public Map getIrVertexIdToReadable() { return irVertexIdToReadable; } + public AtomicInteger getIteratorStartIndex() { + return this.iteratorStartIndex; + } + + public AtomicInteger getIteratorEndIndex() { + return this.iteratorEndIndex; + } + @Override public String toString() { final StringBuilder sb = new StringBuilder(); diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index bd5ed2d30f..ef0b95b6b7 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -89,6 +89,8 @@ enum MessageType { ParentTaskDataCollected = 16; CurrentlyProcessedBytesCollected = 17; SendWorkStealingResult = 18; + HaltExecutors = 19; + ResumeTask = 20; } message Message { @@ -112,7 +114,9 @@ message Message { optional PipeLocationInfoMessage pipeLocInfoMsg = 18; optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19; optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20; - optional WorkStealingResultMsg sendWorkStealingResult = 22; + optional WorkStealingResultMsg sendWorkStealingResult = 21; + optional HaltExecutorsMsg haltExecutorsMsg = 22; + optional ResumeTaskMsg resumeTaskMsg = 23; } // Messages from Master to Executors @@ -276,3 +280,9 @@ message CurrentlyProcessedBytesCollectMsg { message WorkStealingResultMsg { required bytes workStealingResult = 1; } + +message HaltExecutorsMsg { +} + +message ResumeTaskMsg { +} diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/Executor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/Executor.java index 75c2e43afc..060492378b 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/Executor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/Executor.java @@ -21,6 +21,7 @@ import com.google.protobuf.ByteString; import org.apache.commons.lang3.SerializationUtils; import 
org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.apache.nemo.common.Pair; import org.apache.nemo.common.coder.BytesDecoderFactory; import org.apache.nemo.common.coder.BytesEncoderFactory; import org.apache.nemo.common.coder.DecoderFactory; @@ -53,8 +54,13 @@ import org.slf4j.LoggerFactory; import javax.inject.Inject; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; /** * Executor. @@ -68,6 +74,8 @@ public final class Executor { * To be used for a thread pool to execute tasks. */ private final ExecutorService executorService; + /* For work stealing */ + private final ScheduledExecutorService workStealingManager; /** * In charge of this executor's intermediate data transfer. @@ -85,6 +93,12 @@ public final class Executor { private final MetricMessageSender metricMessageSender; + /** + * For runtime optimizations. + */ + private final List> listOfWorkingTaskExecutors; + private final Map> taskIdToIteratorInfo; + @Inject private Executor(@Parameter(JobConf.ExecutorId.class) final String executorId, final PersistentConnectionToMasterMap persistentConnectionToMasterMap, @@ -97,12 +111,19 @@ private Executor(@Parameter(JobConf.ExecutorId.class) final String executorId, this.executorService = Executors.newCachedThreadPool(new BasicThreadFactory.Builder() .namingPattern("TaskExecutor thread-%d") .build()); + this.workStealingManager = Executors.newSingleThreadScheduledExecutor(new BasicThreadFactory.Builder() + .namingPattern("workstealing manager in executorSide") + .build()); this.persistentConnectionToMasterMap = persistentConnectionToMasterMap; this.serializerManager = serializerManager; this.intermediateDataIOFactory = intermediateDataIOFactory; this.broadcastManagerWorker = broadcastManagerWorker; this.metricMessageSender = metricMessageSender; messageEnvironment.setupListener(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID, new ExecutorMessageReceiver()); + + this.listOfWorkingTaskExecutors = Collections.synchronizedList(new LinkedList<>()); + this.taskIdToIteratorInfo = new ConcurrentHashMap(); + } public String getExecutorId() { @@ -148,8 +169,18 @@ private void launchTask(final Task task) { e.getPropertyValue(CompressionProperty.class).orElse(null), e.getPropertyValue(DecompressionProperty.class).orElse(null)))); - new TaskExecutor(task, irDag, taskStateManager, intermediateDataIOFactory, broadcastManagerWorker, - metricMessageSender, persistentConnectionToMasterMap).execute(); + final AtomicBoolean onHold = new AtomicBoolean(false); + final TaskExecutor taskExecutor = new TaskExecutor(task, irDag, taskStateManager, intermediateDataIOFactory, + broadcastManagerWorker, metricMessageSender, persistentConnectionToMasterMap, onHold); + Pair taskExecutorPair = Pair.of(taskExecutor, onHold); + + listOfWorkingTaskExecutors.add(taskExecutorPair); + taskIdToIteratorInfo.put(task.getTaskId(), + Pair.of(task.getIteratorStartIndex(), task.getIteratorEndIndex())); + taskExecutor.execute(); + listOfWorkingTaskExecutors.remove(taskExecutorPair); + taskIdToIteratorInfo.remove(task.getTaskId()); + } catch (final Exception e) { persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( ControlMessage.Message.newBuilder() @@ -206,6 +237,32 @@ public void 
terminate() { } } + /* Methods for work stealing */ + private synchronized void onDataRequestReceived() { + listOfWorkingTaskExecutors.forEach(pair -> pair.right().set(true)); + //listOfWorkingTaskExecutors.forEach(TaskExecutor::onRequestForProcessedData); + } + + private synchronized void resumePausedTasks() { + listOfWorkingTaskExecutors.forEach(pair -> pair.right().set(false)); + } + + private synchronized void resumePausedTasksWithWorkStealing(final Map> result) { + // update iterator information + // skewed tasks: set iterator value + // non skewed tasks: do not change iterator + for (String taskId : taskIdToIteratorInfo.keySet()) { + Pair startAndEndIndex = result.get(taskId); + if (startAndEndIndex.left() == 0 && startAndEndIndex.right() == Integer.MAX_VALUE) { + continue; + } + Pair currentInfo = taskIdToIteratorInfo.get(taskId); + currentInfo.left().set(startAndEndIndex.left()); // null pointer exception here! + currentInfo.right().set(startAndEndIndex.right()); + } + resumePausedTasks(); + } + /** * MessageListener for Executor. */ @@ -220,6 +277,19 @@ public void onMessage(final ControlMessage.Message message) { SerializationUtils.deserialize(scheduleTaskMsg.getTask().toByteArray()); onTaskReceived(task); break; + case HaltExecutors: + onDataRequestReceived(); + break; + case ResumeTask: + resumePausedTasks(); + break; + case SendWorkStealingResult: + final ControlMessage.WorkStealingResultMsg workStealingResultMsg = message.getSendWorkStealingResult(); + final Map> iteratorInformationMap = + SerializationUtils.deserialize(workStealingResultMsg.getWorkStealingResult().toByteArray()); + LOG.error("received: {}", iteratorInformationMap); + resumePausedTasksWithWorkStealing(iteratorInformationMap); + break; case RequestMetricFlush: metricMessageSender.flush(); break; diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java index b1a828c13c..cd2bb5a56f 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java @@ -23,6 +23,7 @@ import org.apache.nemo.runtime.executor.MetricMessageSender; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; /** * An abstraction for fetching data from task-external sources. 
@@ -63,7 +64,8 @@ abstract class DataFetcher implements AutoCloseable { * @throws java.util.NoSuchElementException if no more element is available */ abstract Object fetchDataElementWithTrace(String taskId, - MetricMessageSender metricMessageSender) throws IOException; + MetricMessageSender metricMessageSender, + AtomicBoolean onHold) throws IOException; OutputCollector getOutputCollector() { return outputCollector; diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java index d7947e8c78..80e607ac6b 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java @@ -36,6 +36,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; /** * Task thread -> fetchDataElement() -> (((QUEUE))) <- List of iterators <- queueInsertionThreads @@ -103,7 +104,8 @@ Object fetchDataElement() throws IOException { @Override Object fetchDataElementWithTrace(final String taskId, - final MetricMessageSender metricMessageSender) throws IOException { + final MetricMessageSender metricMessageSender, + final AtomicBoolean onHold) throws IOException { return fetchDataElement(); } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java index 3a92cbc8a9..ae3fbf9c75 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -34,6 +34,8 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; /** * Fetches data from parent tasks. 
@@ -53,14 +55,21 @@ class ParentTaskDataFetcher extends DataFetcher {
   private long serBytes = 0;
   private long encodedBytes = 0;

+  private AtomicInteger iteratorStartIndex;
+  private AtomicInteger iteratorEndIndex;
+
   ParentTaskDataFetcher(final IRVertex dataSource,
                         final InputReader inputReader,
-                        final OutputCollector outputCollector) {
+                        final OutputCollector outputCollector,
+                        final AtomicInteger iteratorStartIndex,
+                        final AtomicInteger iteratorEndIndex) {
     super(dataSource, outputCollector);
     this.inputReader = inputReader;
     this.firstFetch = true;
     this.currentIteratorIndex = 0;
     this.iteratorQueue = new LinkedBlockingQueue<>();
+    this.iteratorStartIndex = iteratorStartIndex;
+    this.iteratorEndIndex = iteratorEndIndex;
   }

   @Override
@@ -104,27 +113,41 @@ Object fetchDataElement() throws IOException {

   @Override
   Object fetchDataElementWithTrace(final String taskId,
-                                   final MetricMessageSender metricMessageSender) throws IOException {
+                                   final MetricMessageSender metricMessageSender,
+                                   final AtomicBoolean onHold) throws IOException {
     try {
       if (firstFetch) {
         fetchDataLazily();
         advanceIterator();
+
+        // if this is a work stealing task, move the iterator index to its starting point
+        while (currentIteratorIndex < iteratorStartIndex.get()) {
+          advanceIterator();
+        }
         firstFetch = false;
       }

-      while (true) {
+      while (!onHold.get()) {
         // This iterator has the element
         if (this.currentIterator.hasNext()) {
           return this.currentIterator.next();
         }

         // This iterator does not have the element
+        if (currentIteratorIndex >= iteratorEndIndex.get()) {
+          break;
+        }
         if (currentIteratorIndex < expectedNumOfIterators) {
           // Next iterator has the element
           countBytes(currentIterator);
           // Send the cumulative serBytes to MetricStore
           metricMessageSender.send("TaskMetric", taskId, "serializedReadBytes",
             SerializationUtils.serialize(serBytes));
+          metricMessageSender.send("TaskMetric", taskId, "currentIteratorIndex",
+            SerializationUtils.serialize(currentIteratorIndex));
+          metricMessageSender.send("TaskMetric", taskId, "totalIteratorNumber",
+            SerializationUtils.serialize(expectedNumOfIterators));
+          metricMessageSender.flush();
           advanceIterator();
           continue;
         } else {
@@ -202,6 +225,9 @@ private void handleIncomingBlock(final int index,
   private void fetchDataLazily() {
     final List> futures = inputReader.read();
     this.expectedNumOfIterators = futures.size();
+    if (iteratorEndIndex.get() > expectedNumOfIterators) {
+      iteratorEndIndex.set(expectedNumOfIterators);
+    }
     for (int i = 0; i < futures.size(); i++) {
       final int index = i;
       final CompletableFuture future = futures.get(i);
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java
index 68a3362d27..d1bdc55f26 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java
@@ -28,6 +28,7 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;

 /**
  * Fetches data from a data source. 
@@ -76,7 +77,9 @@ Object fetchDataElement() { } @Override - Object fetchDataElementWithTrace(final String taskId, final MetricMessageSender metricMessageSender) { + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender, + final AtomicBoolean onHold) { return fetchDataElement(); } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index 91e8212640..009c77dba4 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -55,6 +55,7 @@ import javax.annotation.concurrent.NotThreadSafe; import java.io.IOException; import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; /** @@ -83,6 +84,7 @@ public final class TaskExecutor { // Dynamic optimization private String idOfVertexPutOnHold; + private final AtomicBoolean onHold; private final PersistentConnectionToMasterMap persistentConnectionToMasterMap; @@ -103,8 +105,10 @@ public TaskExecutor(final Task task, final IntermediateDataIOFactory intermediateDataIOFactory, final BroadcastManagerWorker broadcastManagerWorker, final MetricMessageSender metricMessageSender, - final PersistentConnectionToMasterMap persistentConnectionToMasterMap) { + final PersistentConnectionToMasterMap persistentConnectionToMasterMap, + final AtomicBoolean onHold) { // Essential information + final long taskPrepareStart = System.currentTimeMillis(); this.isExecuted = false; this.taskId = task.getTaskId(); this.taskStateManager = taskStateManager; @@ -116,6 +120,7 @@ public TaskExecutor(final Task task, // Dynamic optimization // Assigning null is very bad, but we are keeping this for now this.idOfVertexPutOnHold = null; + this.onHold = onHold; this.persistentConnectionToMasterMap = persistentConnectionToMasterMap; @@ -125,6 +130,8 @@ public TaskExecutor(final Task task, this.sortedHarnesses = pair.right(); this.timeSinceLastExecution = System.currentTimeMillis(); + metricMessageSender.send("TaskMetric", taskId, "taskPreparationTime", + SerializationUtils.serialize(System.currentTimeMillis() - taskPrepareStart)); } // Get all of the intra-task edges + inter-task edges @@ -292,7 +299,9 @@ irVertex, outputCollector, new TransformContextImpl(broadcastManagerWorker), new ParentTaskDataFetcher( parentTaskReader.getSrcIrVertex(), parentTaskReader, - dataFetcherOutputCollector)); + dataFetcherOutputCollector, + task.getIteratorStartIndex(), + task.getIteratorEndIndex())); } } }); @@ -459,7 +468,7 @@ private boolean handleDataFetchers(final List fetchers) { while (availableIterator.hasNext()) { final DataFetcher dataFetcher = availableIterator.next(); try { - final Object element = dataFetcher.fetchDataElementWithTrace(taskId, metricMessageSender); + final Object element = dataFetcher.fetchDataElementWithTrace(taskId, metricMessageSender, onHold); onEventFromDataFetcher(element, dataFetcher); if (element instanceof Finishmark) { availableIterator.remove(); @@ -784,4 +793,9 @@ public void onRequestForProcessedData() { .build()) .build()); } + + public String getTaskId() { + return this.taskId; + } + } diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java index 
ab774e968a..2e1ff61317 100644 --- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java +++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java @@ -40,6 +40,7 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -141,7 +142,9 @@ private ParentTaskDataFetcher createFetcher(final InputReader readerForParentTas return new ParentTaskDataFetcher( mock(IRVertex.class), readerForParentTask, // This is the only argument that affects the behavior of ParentTaskDataFetcher - mock(OutputCollector.class)); + mock(OutputCollector.class), + new AtomicInteger(0), + new AtomicInteger(Integer.MAX_VALUE)); } private InputReader generateInputReader(final CompletableFuture completableFuture, diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/TaskExecutorTest.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/TaskExecutorTest.java index 3e33fec1f4..ae5bfdd732 100644 --- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/TaskExecutorTest.java +++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/TaskExecutorTest.java @@ -66,6 +66,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -897,6 +898,6 @@ private List getRangedNumList(final int start, final int end) { private TaskExecutor getTaskExecutor(final Task task, final DAG> taskDag) { return new TaskExecutor(task, taskDag, taskStateManager, intermediateDataIOFactory, broadcastManagerWorker, - metricMessageSender, persistentConnectionToMasterMap); + metricMessageSender, persistentConnectionToMasterMap, new AtomicBoolean(false)); } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index d30ea42394..5c5f413079 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -305,6 +305,8 @@ public void onWorkStealingCheck() { return; } + haltExecutors(); + /* generate work stealing tasks */ final Map> taskToSplitIteratorInfo = splitIterator(skewedTasks); final List wsTasks = generateWorkStealingTasks(scheduleGroup, skewedTasks, taskToSplitIteratorInfo); @@ -800,4 +802,20 @@ private void sendWorkStealingResultToExecutor(final Map executors.forEach(executor -> executor.sendControlMessage(message))); } + + private void haltExecutors() { + // driver sends message to executors + // ask executors to flush metric + ControlMessage.Message haltExecutorsMessage = ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.HaltExecutors) + .build(); + executorRegistry.viewExecutors(executors -> executors.forEach(executor -> + executor.sendControlMessage(haltExecutorsMessage))); + try { + Thread.sleep(1000); // wait 
for 1 sec + } catch (InterruptedException ignored) { + } + } }
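
Note: the following is a minimal, self-contained sketch that is not part of the patch series above and not Nemo API. It only illustrates the midpoint arithmetic that splitIterator() in BatchScheduler applies when dividing the remaining iterators between a skewed (victim) task and its work stealing task: the victim keeps processing up to changePoint (exclusive), and the stealer takes [changePoint, totalIteratorNumber). The class and method names here are hypothetical.

import java.util.AbstractMap.SimpleEntry;
import java.util.Map;

/** Hypothetical sketch of the iterator-range split used for work stealing. */
public final class IteratorSplitSketch {

  private IteratorSplitSketch() {
  }

  /**
   * Mirrors the arithmetic of splitIterator():
   * changePoint = floor((totalIteratorNumber + currentIteratorIndex) / 2 + 1).
   *
   * @param currentIteratorIndex iterator index the victim task is currently processing.
   * @param totalIteratorNumber  total number of iterators the victim was assigned.
   * @return (victim end index, stealer end index); the stealer starts where the victim ends.
   */
  static Map.Entry<Integer, Integer> split(final int currentIteratorIndex, final int totalIteratorNumber) {
    final int changePoint = (totalIteratorNumber + currentIteratorIndex) / 2 + 1;
    return new SimpleEntry<>(changePoint, totalIteratorNumber);
  }

  public static void main(final String[] args) {
    // A victim that has reached iterator 10 of 100 keeps [10, 56); the stealer takes [56, 100).
    final Map.Entry<Integer, Integer> range = split(10, 100);
    System.out.println("victim ends at " + range.getKey() + ", stealer ends at " + range.getValue());
  }
}

Splitting at the midpoint of the remaining range, rather than of the whole range, leaves the work the victim has already done untouched and hands roughly half of the outstanding iterators to the stealer.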