Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Nemo-470] Synchronize work stealing process between task executors #312

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ public static String generateTaskId(final String stageId, final int index, final
return stageId + SPLITTER + index + SPLITTER + attempt;
}

/**
 * Builds the ID of a task that is created by Work Stealing.
 * The stage ID and the index are extracted from the given task ID, and the literal {@code "*"}
 * is placed in the last position of the generated ID.
 *
 * @param taskId the ID of original task.
 * @return the generated ID.
 */
public static String generateWorkStealingTaskId(final String taskId) {
  final StringBuilder workStealingTaskId = new StringBuilder();
  workStealingTaskId.append(getStageIdFromTaskId(taskId));
  workStealingTaskId.append(SPLITTER);
  workStealingTaskId.append(getIndexFromTaskId(taskId));
  workStealingTaskId.append(SPLITTER);
  workStealingTaskId.append("*");
  return workStealingTaskId.toString();
}

/**
* Generates the ID for executor.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public class TaskMetric implements StateMetric<TaskState.State> {
private long shuffleReadTime = -1;
private long shuffleWriteBytes = -1;
private long shuffleWriteTime = -1;
// Metrics added for work stealing; each one starts at -1 until a value is stored through its setter.
private int currentIteratorIndex = -1;
private int totalIteratorNumber = -1;
private long taskPreparationTime = -1;

private static final Logger LOG = LoggerFactory.getLogger(TaskMetric.class.getName());

Expand Down Expand Up @@ -252,6 +255,30 @@ private void setShuffleWriteTime(final long shuffleWriteTime) {
this.shuffleWriteTime = shuffleWriteTime;
}

/**
 * @return the stored current iterator index (the field is initialized to -1).
 */
public final int getCurrentIteratorIndex() {
  return currentIteratorIndex;
}

/**
 * Stores the current iterator index.
 * Invoked when a metric message for the "currentIteratorIndex" field is processed.
 *
 * @param currentIteratorIndex the value to store.
 */
private void setCurrentIteratorIndex(final int currentIteratorIndex) {
this.currentIteratorIndex = currentIteratorIndex;
}

/**
 * @return the stored total iterator number (the field is initialized to -1).
 */
public final int getTotalIteratorNumber() {
  return totalIteratorNumber;
}

/**
 * Stores the total iterator number.
 * Invoked when a metric message for the "totalIteratorNumber" field is processed.
 *
 * @param totalIteratorNumber the value to store.
 */
private void setTotalIteratorNumber(final int totalIteratorNumber) {
this.totalIteratorNumber = totalIteratorNumber;
}

/**
 * @return the stored task preparation time (the field is initialized to -1).
 */
public final long getTaskPreparationTime() {
  return taskPreparationTime;
}

/**
 * Stores the task preparation time.
 * Invoked when a metric message for the "taskPreparationTime" field is processed.
 *
 * @param taskPreparationTime the value to store.
 */
private void setTaskPreparationTime(final long taskPreparationTime) {
this.taskPreparationTime = taskPreparationTime;
}

@Override
public final String getId() {
return id;
Expand Down Expand Up @@ -317,6 +344,14 @@ public final boolean processMetricMessage(final String metricField, final byte[]
case "shuffleWriteTime":
setShuffleWriteTime(SerializationUtils.deserialize(metricValue));
break;
case "currentIteratorIndex":
setCurrentIteratorIndex(SerializationUtils.deserialize(metricValue));
break;
case "totalIteratorNumber":
setTotalIteratorNumber(SerializationUtils.deserialize(metricValue));
break;
case "taskPreparationTime":
setTaskPreparationTime(SerializationUtils.deserialize(metricValue));
// Terminate the case explicitly: without this break, control fell through to the default
// branch, which logged "not supported" and returned false even though the value was stored.
break;
default:
LOG.warn("metricField {} is not supported.", metricField);
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@
import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;

/**
* Stage.
*/
public final class Stage extends Vertex {
private final List<Integer> taskIndices;
private final Set<String> workStealingTaskIds = new HashSet<>();
private final DAG<IRVertex, RuntimeEdge<IRVertex>> irDag;
private final byte[] serializedIRDag;
private final List<Map<String, Readable>> vertexIdToReadables;
Expand Down Expand Up @@ -93,6 +92,18 @@ public List<Integer> getTaskIndices() {
return taskIndices;
}

/**
 * Registers IDs of work stealer tasks for this stage.
 * The given IDs are added to the set already held by this stage; IDs added earlier are kept.
 *
 * @param workStealingTaskIds IDs of work stealer tasks.
 */
public void setWorkStealingTaskIds(final Set<String> workStealingTaskIds) {
  for (final String workStealingTaskId : workStealingTaskIds) {
    this.workStealingTaskIds.add(workStealingTaskId);
  }
}

/**
 * @return the set of work stealer task IDs held by this stage (the internal set itself, not a copy).
 */
public Set<String> getWorkStealingTaskIds() {
  return workStealingTaskIds;
}

/**
* @return the parallelism.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

/**
* A Task (attempt) is a self-contained executable that can be executed on a machine.
Expand All @@ -40,8 +41,13 @@ public final class Task implements Serializable {
private final byte[] serializedIRDag;
private final Map<String, Readable> irVertexIdToReadable;

/* For work stealing: starting and ending index of the iterator, as passed to the constructor. */
private final AtomicInteger iteratorStartIndex;
private final AtomicInteger iteratorEndIndex;

/**
* Constructor.
* Default Constructor.
* It initializes iteratorStartIndex as 0 and iteratorEndIndex as Integer.MAX_VALUE.
*
* @param planId the id of the physical plan.
* @param taskId the ID of this task attempt.
Expand All @@ -58,13 +64,42 @@ public Task(final String planId,
final List<StageEdge> taskIncomingEdges,
final List<StageEdge> taskOutgoingEdges,
final Map<String, Readable> irVertexIdToReadable) {
this(planId, taskId, executionProperties, serializedIRDag, taskIncomingEdges, taskOutgoingEdges,
irVertexIdToReadable, new AtomicInteger(0), new AtomicInteger(Integer.MAX_VALUE));
}

/**
 * Creates a task together with explicit iterator bounds.
 * Work stealer tasks are created through this constructor.
 *
 * @param planId               the id of the physical plan.
 * @param taskId               the ID of this task attempt.
 * @param executionProperties  {@link VertexExecutionProperty} map for the corresponding stage
 * @param serializedIRDag      the serialized DAG of the task.
 * @param taskIncomingEdges    the incoming edges of the task.
 * @param taskOutgoingEdges    the outgoing edges of the task.
 * @param irVertexIdToReadable the map between IRVertex id to readable.
 * @param iteratorStartIndex   starting index of iterator.
 * @param iteratorEndIndex     ending index of iterator.
 */
public Task(final String planId,
            final String taskId,
            final ExecutionPropertyMap<VertexExecutionProperty> executionProperties,
            final byte[] serializedIRDag,
            final List<StageEdge> taskIncomingEdges,
            final List<StageEdge> taskOutgoingEdges,
            final Map<String, Readable> irVertexIdToReadable,
            final AtomicInteger iteratorStartIndex,
            final AtomicInteger iteratorEndIndex) {
  // Iterator bounds used for work stealing; the references are stored as given.
  this.iteratorStartIndex = iteratorStartIndex;
  this.iteratorEndIndex = iteratorEndIndex;

  // Identity and execution properties.
  this.planId = planId;
  this.taskId = taskId;
  this.executionProperties = executionProperties;

  // DAG, edges and readables.
  this.serializedIRDag = serializedIRDag;
  this.taskIncomingEdges = taskIncomingEdges;
  this.taskOutgoingEdges = taskOutgoingEdges;
  this.irVertexIdToReadable = irVertexIdToReadable;
}

/**
Expand Down Expand Up @@ -142,6 +177,14 @@ public Map<String, Readable> getIrVertexIdToReadable() {
return irVertexIdToReadable;
}

/**
 * @return the starting index of iterator, as the {@link AtomicInteger} instance held by this task.
 */
public AtomicInteger getIteratorStartIndex() {
  return iteratorStartIndex;
}

/**
 * @return the ending index of iterator, as the {@link AtomicInteger} instance held by this task.
 */
public AtomicInteger getIteratorEndIndex() {
  return iteratorEndIndex;
}

@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
Expand Down
30 changes: 30 additions & 0 deletions runtime/common/src/main/proto/ControlMessage.proto
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ enum MessageType {
PipeInit = 13;
RequestPipeLoc = 14;
PipeLocInfo = 15;
ParentTaskDataCollected = 16;
CurrentlyProcessedBytesCollected = 17;
SendWorkStealingResult = 18;
HaltExecutors = 19;
ResumeTask = 20;
}

message Message {
Expand All @@ -107,6 +112,11 @@ message Message {
optional PipeInitMessage pipeInitMsg = 16;
optional RequestPipeLocationMessage requestPipeLocMsg = 17;
optional PipeLocationInfoMessage pipeLocInfoMsg = 18;
optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19;
optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20;
optional WorkStealingResultMsg sendWorkStealingResult = 21;
optional HaltExecutorsMsg haltExecutorsMsg = 22;
optional ResumeTaskMsg resumeTaskMsg = 23;
}

// Messages from Master to Executors
Expand Down Expand Up @@ -256,3 +266,23 @@ message PipeLocationInfoMessage {
required int64 requestId = 1; // To find the matching request msg
required string executorId = 2;
}

// Payload of the ParentTaskDataCollected field of Message: a task ID plus an opaque byte blob.
message ParentTaskDataCollectMsg {
required string taskId = 1;
required bytes partitionSizeMap = 2; // presumably a serialized map of partition sizes -- TODO confirm encoding
}

// Payload of the currentlyProcessedBytesCollected field of Message: a task ID plus a 64-bit byte count.
message CurrentlyProcessedBytesCollectMsg {
required string taskId = 1;
required int64 processedDataBytes = 2; // presumably the bytes processed so far by the task -- TODO confirm
}

// Payload of the sendWorkStealingResult field of Message.
message WorkStealingResultMsg {
required bytes workStealingResult = 1; // The executor deserializes this into a map from task ID to a pair of integers
}

// Payload of the haltExecutorsMsg field of Message; it carries no fields.
message HaltExecutorsMsg {
}

// Payload of the resumeTaskMsg field of Message; it carries no fields.
message ResumeTaskMsg {
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.protobuf.ByteString;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.coder.BytesDecoderFactory;
import org.apache.nemo.common.coder.BytesEncoderFactory;
import org.apache.nemo.common.coder.DecoderFactory;
Expand Down Expand Up @@ -53,8 +54,13 @@
import org.slf4j.LoggerFactory;

import javax.inject.Inject;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
* Executor.
Expand All @@ -68,6 +74,8 @@ public final class Executor {
* To be used for a thread pool to execute tasks.
*/
private final ExecutorService executorService;
/* For work stealing: created in the constructor as a single-thread scheduled executor. */
private final ScheduledExecutorService workStealingManager;

/**
* In charge of this executor's intermediate data transfer.
Expand All @@ -85,6 +93,12 @@ public final class Executor {

private final MetricMessageSender metricMessageSender;

/**
 * For runtime optimizations.
 * Each pair holds a task executor and the on-hold flag that was handed to it at construction.
 * A pair is added before the task executor runs and removed after it returns.
 */
private final List<Pair<TaskExecutor, AtomicBoolean>> listOfWorkingTaskExecutors;
// Maps a task ID to the (iterator start index, iterator end index) pair of that task.
private final Map<String, Pair<AtomicInteger, AtomicInteger>> taskIdToIteratorInfo;

@Inject
private Executor(@Parameter(JobConf.ExecutorId.class) final String executorId,
final PersistentConnectionToMasterMap persistentConnectionToMasterMap,
Expand All @@ -97,12 +111,19 @@ private Executor(@Parameter(JobConf.ExecutorId.class) final String executorId,
this.executorService = Executors.newCachedThreadPool(new BasicThreadFactory.Builder()
.namingPattern("TaskExecutor thread-%d")
.build());
this.workStealingManager = Executors.newSingleThreadScheduledExecutor(new BasicThreadFactory.Builder()
.namingPattern("workstealing manager in executorSide")
.build());
this.persistentConnectionToMasterMap = persistentConnectionToMasterMap;
this.serializerManager = serializerManager;
this.intermediateDataIOFactory = intermediateDataIOFactory;
this.broadcastManagerWorker = broadcastManagerWorker;
this.metricMessageSender = metricMessageSender;
messageEnvironment.setupListener(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID, new ExecutorMessageReceiver());

// Bookkeeping for work stealing; both collections are accessed from several threads.
this.listOfWorkingTaskExecutors = Collections.synchronizedList(new LinkedList<>());
// Diamond operator instead of a raw ConcurrentHashMap, so the assignment is type-checked.
this.taskIdToIteratorInfo = new ConcurrentHashMap<>();

}

public String getExecutorId() {
Expand Down Expand Up @@ -148,8 +169,18 @@ private void launchTask(final Task task) {
e.getPropertyValue(CompressionProperty.class).orElse(null),
e.getPropertyValue(DecompressionProperty.class).orElse(null))));

new TaskExecutor(task, irDag, taskStateManager, intermediateDataIOFactory, broadcastManagerWorker,
metricMessageSender, persistentConnectionToMasterMap).execute();
// The flag is shared between this executor (which toggles it) and the task executor (which receives it).
final AtomicBoolean onHold = new AtomicBoolean(false);
final TaskExecutor taskExecutor = new TaskExecutor(task, irDag, taskStateManager, intermediateDataIOFactory,
  broadcastManagerWorker, metricMessageSender, persistentConnectionToMasterMap, onHold);
final Pair<TaskExecutor, AtomicBoolean> taskExecutorPair = Pair.of(taskExecutor, onHold);

// Register the task before running it so that halt/resume messages can reach it while it executes.
listOfWorkingTaskExecutors.add(taskExecutorPair);
taskIdToIteratorInfo.put(task.getTaskId(),
  Pair.of(task.getIteratorStartIndex(), task.getIteratorEndIndex()));
try {
  taskExecutor.execute();
} finally {
  // Always unregister, even when execute() throws; otherwise stale entries would remain
  // in both collections and later work stealing messages would still be applied to them.
  listOfWorkingTaskExecutors.remove(taskExecutorPair);
  taskIdToIteratorInfo.remove(task.getTaskId());
}

} catch (final Exception e) {
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
ControlMessage.Message.newBuilder()
Expand Down Expand Up @@ -206,6 +237,32 @@ public void terminate() {
}
}

/* Methods for work stealing */
/**
 * Sets the on-hold flag of every registered task executor to {@code true}.
 */
private synchronized void onDataRequestReceived() {
  listOfWorkingTaskExecutors.forEach(workingTaskExecutor -> {
    final AtomicBoolean onHold = workingTaskExecutor.right();
    onHold.set(true);
  });
}

/**
 * Clears the on-hold flag of every registered task executor by setting it to {@code false}.
 */
private synchronized void resumePausedTasks() {
  listOfWorkingTaskExecutors.forEach(workingTaskExecutor -> {
    final AtomicBoolean onHold = workingTaskExecutor.right();
    onHold.set(false);
  });
}

/**
 * Applies a work stealing result to the tasks registered on this executor, then resumes them.
 * For each registered task, the iterator start and end indices are overwritten with the values
 * found in {@code result}, unless the result leaves the task untouched
 * (start 0 and end {@code Integer.MAX_VALUE}) or contains no entry for the task.
 *
 * @param result map from task ID to the (start index, end index) pair decided for that task.
 */
private synchronized void resumePausedTasksWithWorkStealing(final Map<String, Pair<Integer, Integer>> result) {
  // Iterate over entries rather than keys: a task may finish and be removed from the concurrent
  // map while we iterate, in which case a later get(taskId) would return null.
  for (final Map.Entry<String, Pair<AtomicInteger, AtomicInteger>> entry : taskIdToIteratorInfo.entrySet()) {
    final Pair<Integer, Integer> startAndEndIndex = result.get(entry.getKey());
    if (startAndEndIndex == null) {
      // The result does not mention this task: keep its iterator information as it is.
      continue;
    }
    final int startIndex = startAndEndIndex.left();
    final int endIndex = startAndEndIndex.right();
    if (startIndex == 0 && endIndex == Integer.MAX_VALUE) {
      // Non skewed task: do not change the iterator.
      continue;
    }
    // Skewed task: set the iterator values.
    final Pair<AtomicInteger, AtomicInteger> currentInfo = entry.getValue();
    currentInfo.left().set(startIndex);
    currentInfo.right().set(endIndex);
  }
  resumePausedTasks();
}

/**
* MessageListener for Executor.
*/
Expand All @@ -220,6 +277,19 @@ public void onMessage(final ControlMessage.Message message) {
SerializationUtils.deserialize(scheduleTaskMsg.getTask().toByteArray());
onTaskReceived(task);
break;
case HaltExecutors:
onDataRequestReceived();
break;
case ResumeTask:
resumePausedTasks();
break;
case SendWorkStealingResult:
final ControlMessage.WorkStealingResultMsg workStealingResultMsg = message.getSendWorkStealingResult();
final Map<String, Pair<Integer, Integer>> iteratorInformationMap =
  SerializationUtils.deserialize(workStealingResultMsg.getWorkStealingResult().toByteArray());
// Receiving a result is part of normal operation, so log it at debug level instead of error.
LOG.debug("received: {}", iteratorInformationMap);
resumePausedTasksWithWorkStealing(iteratorInformationMap);
break;
case RequestMetricFlush:
metricMessageSender.flush();
break;
Expand Down
Loading