Skip to content

Commit

Permalink
增加TaskExecutionGraph辅助工具,用于执行相互依赖的一组任务
Browse files Browse the repository at this point in the history
  • Loading branch information
entropy-cloud committed Nov 18, 2024
1 parent de3bfb3 commit ea8b302
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 4 deletions.
183 changes: 183 additions & 0 deletions nop-core/src/main/java/io/nop/core/model/graph/TaskExecutionGraph.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package io.nop.core.model.graph;

import io.nop.api.core.time.CoreMetrics;
import io.nop.api.core.util.Guard;
import io.nop.api.core.util.ICancelToken;
import io.nop.core.model.graph.dag.Dag;
import io.nop.core.model.graph.dag.DagNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;

public class TaskExecutionGraph {
static final Logger LOG = LoggerFactory.getLogger(TaskExecutionGraph.class);
private final String taskGraphName;

private final Map<String, Runnable> tasks = new LinkedHashMap<>();
private final Dag dag = new Dag(Dag.DEFAULT_ROOT_NAME);

public TaskExecutionGraph(String taskGraphName) {
this.taskGraphName = taskGraphName;
}

public TaskExecutionGraph addTask(String taskName, Runnable task) {
Guard.checkArgument(!tasks.containsKey(taskName), "duplicate task name");
this.tasks.put(taskName, task);
this.dag.addNextNode(Dag.DEFAULT_ROOT_NAME, taskName);
return this;
}

public boolean isRootNode(String taskName) {
return Dag.DEFAULT_ROOT_NAME.equals(taskName);
}

public boolean containsTask(String taskName) {
return tasks.containsKey(taskName);
}

public TaskExecutionGraph addTaskWithDepends(String taskName, Runnable task, Collection<String> depends) {
return addTask(taskName, task).addDepends(taskName, depends);
}

public TaskExecutionGraph addDepends(String taskName, Collection<String> depends) {
if (depends == null || depends.isEmpty())
return this;

for (String depend : depends) {
dag.addNextNode(Dag.DEFAULT_ROOT_NAME, depend);
dag.addNextNode(depend, taskName);
}
return this;
}

public Set<String> getDepends(String taskName) {
return dag.getNode(taskName).getPrevNodeNames();
}

public TaskExecutionGraph addDepend(String taskName, String depend) {
Guard.notEmpty(taskName, "taskName");
Guard.notEmpty(depend, "depend");

dag.addNextNode(depend, taskName);
return this;
}

public TaskExecutionGraph analyze() {
dag.analyze();
return this;
}

public CompletableFuture<Void> runOnExecutor(Executor executor, ICancelToken cancelToken) {
long beginTime = CoreMetrics.currentTimeMillis();

Map<String, CompletableFuture<Void>> futures = new HashMap<>();
for (String taskName : tasks.keySet()) {
futures.put(taskName, new CompletableFuture<>());
}

Set<String> noDepends = new LinkedHashSet<>();
for (String taskName : tasks.keySet()) {
CompletableFuture<Void> future = waitPrevTasks(futures, taskName);
if (future != null) {
future.whenComplete((ret, err) -> {
if (err != null) {
futures.get(taskName).completeExceptionally(err);
} else {
runTask(executor, cancelToken, futures, taskName);
}
});
} else {
noDepends.add(taskName);
}
}

for (String taskName : noDepends) {
runTask(executor, cancelToken, futures, taskName);
}

CompletableFuture<?>[] endFutures = new CompletableFuture[dag.getEndNodeNames().size()];
int index = 0;
for (String endNode : dag.getEndNodeNames()) {
endFutures[index++] = futures.get(endNode);
}
return CompletableFuture.allOf(endFutures).whenComplete((ret, err) -> {
LOG.info("nop.task.graph-execute-finished:taskGraphName={}, usedTime={}",
taskGraphName, CoreMetrics.currentTimeMillis() - beginTime);
});
}

private CompletableFuture<Void> waitPrevTasks(Map<String, CompletableFuture<Void>> futures, String taskName) {
if (!tasks.containsKey(taskName))
return null;

DagNode node = dag.getNode(taskName);
Set<String> prevNames = node.getPrevNodeNames();
if (prevNames == null || prevNames.isEmpty())
return null;

if (prevNames.contains(Dag.DEFAULT_ROOT_NAME)) {
if (prevNames.size() == 1)
return null;

prevNames = new HashSet<>(prevNames);
prevNames.retainAll(futures.keySet());
if (prevNames.isEmpty())
return null;
}

LOG.debug("nop.task.wait-prev:taskName={},prevNames={}", taskName, prevNames);
if (prevNames.size() == 1)
return futures.get(prevNames.iterator().next());

CompletableFuture<?>[] prevFutures = new CompletableFuture[prevNames.size()];
int index = 0;
for (String prevName : prevNames) {
prevFutures[index++] = futures.get(prevName);
}
return CompletableFuture.allOf(prevFutures);
}

private void runTask(Executor executor, ICancelToken cancelToken, Map<String, CompletableFuture<Void>> futures, String taskName) {
CompletableFuture<Void> future = futures.get(taskName);
executor.execute(() -> {
if (cancelToken != null) {
if (cancelToken.isCancelled()) {
LOG.info("nop.task.skip-cancelled:taskName={}", taskName);
future.complete(null);
return;
}
}
LOG.info("nop.task.run.start:taskName={}", taskName);
long beginTime = CoreMetrics.currentTimeMillis();
try {
tasks.get(taskName).run();
LOG.info("nop.task.run.finish:taskName={},usedTime={}", taskName, CoreMetrics.currentTimeMillis() - beginTime);
if (LOG.isDebugEnabled())
LOG.debug("nop.task.unfinished-count:{}", getUnfinishedCount(futures));
future.complete(null);
} catch (Exception e) {
LOG.error("nop.task.run.error:taskName={}", taskName, e);
future.completeExceptionally(e);
}
});
}

private int getUnfinishedCount(Map<String, CompletableFuture<Void>> futures) {
int count = 0;
for (Map.Entry<String, CompletableFuture<Void>> entry : futures.entrySet()) {
CompletableFuture<Void> future = entry.getValue();
if (!future.isDone())
count++;
}
return count;
}
}
43 changes: 40 additions & 3 deletions nop-core/src/main/java/io/nop/core/model/graph/dag/Dag.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -35,14 +37,21 @@

@DataBean
public class Dag extends AbstractFreezable implements IGraphViewBase<DagNode, DefaultEdge<DagNode>> {
public static final String DEFAULT_ROOT_NAME = "__root__";

private String rootNodeName;
private Map<String, DagNode> nodes = new HashMap<>();
private Map<String, DagNode> nodes = new LinkedHashMap<>();

/**
* 为了从一般性的图结构抽取得到DAG,需要删除哪些链接
*/
private List<List<String>> loopEdges;

/**
* 没有后续节点的末端节点
*/
private Set<String> endNodeNames;

public Dag() {
}

Expand All @@ -51,6 +60,10 @@ public Dag(String rootName) {
nodes.put(rootName, new DagNode(rootName));
}

public Set<String> getNodeNames() {
return nodes.keySet();
}

public boolean containsLoop() {
return !loopEdges.isEmpty();
}
Expand All @@ -67,10 +80,18 @@ public List<List<String>> getLoopEdges() {
return loopEdges;
}

public void setLoopEdges(List<List<String>> loopEdges) {
void setLoopEdges(List<List<String>> loopEdges) {
this.loopEdges = loopEdges;
}

public Set<String> getEndNodeNames() {
return endNodeNames;
}

void setEndNodeNames(Set<String> endNodeNames) {
this.endNodeNames = endNodeNames;
}

public String toDot() {
return GraphvizHelper.toDot(DagNode::getName, this, true, "dag");
}
Expand Down Expand Up @@ -104,6 +125,18 @@ public List<DagNode> getNextNodes(DagNode node) {
return nextNames.stream().map(this::requireNode).collect(Collectors.toList());
}

public Set<String> getNoDependNodeNames() {
Set<String> names = new HashSet<>();
forEachNode(node -> {
if (!node.hasPrevNode()) {
names.add(node.getName());
} else if (node.getPrevNodeNames().size() == 1 && node.getPrevNodeNames().contains(rootNodeName)) {
names.add(node.getName());
}
});
return names;
}

public void forEachNextNode(DagNode node, Consumer<DagNode> action) {
Set<String> nextNames = node.getNextNodeNames();
if (nextNames == null || nextNames.isEmpty())
Expand Down Expand Up @@ -158,10 +191,13 @@ public void analyze() {
new DagAnalyzer(this).analyze();
}

public DagNode addNextNodes(String nodeName, Set<String> next) {
public DagNode addNextNodes(String nodeName, Collection<String> next) {
checkAllowChange();
DagNode node = nodes.computeIfAbsent(nodeName, DagNode::new);
if (next != null) {
for (String name : next) {
nodes.computeIfAbsent(name, DagNode::new);
}
node.addNextNodes(next);
}
return node;
Expand All @@ -170,6 +206,7 @@ public DagNode addNextNodes(String nodeName, Set<String> next) {
public DagNode addNextNode(String nodeName, String next) {
checkAllowChange();
DagNode node = nodes.computeIfAbsent(nodeName, DagNode::new);
nodes.computeIfAbsent(next, DagNode::new);
node.addNextNode(next);
return node;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
Expand Down Expand Up @@ -61,6 +62,7 @@ public void analyze() {
initDagNodes();
initNodeDepth();
initControlNode();
initEndNodes();
dag.setLoopEdges(removedEdges);
}

Expand Down Expand Up @@ -178,6 +180,20 @@ private void _initNodeDepth(DagNode node, int depth) {
}
}

private void initEndNodes() {
Set<String> endNodes = new HashSet<>();

dag.forEachNode(node -> {
if(node.getName().equals(dag.getRootNodeName()))
return;

if (node.getNextNodeNames() == null || node.getNextNodeNames().isEmpty()) {
endNodes.add(node.getName());
}
});
dag.setEndNodeNames(endNodes);
}

private Set<String> collectNormal(DagNode node, Function<DagNode, Set<String>> nextFetcher) {
Set<String> ret = new LinkedHashSet<>();
collectNormal(ret, node, nextFetcher);
Expand Down
11 changes: 10 additions & 1 deletion nop-core/src/main/java/io/nop/core/model/graph/dag/DagNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import io.nop.api.core.annotations.data.DataBean;

import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Set;

Expand Down Expand Up @@ -50,12 +51,20 @@ public int compareTo(DagNode o) {
return Integer.compare(this.nodeIndex, o.nodeIndex);
}

public boolean hasPrevNode() {
return prevNodeNames != null && !prevNodeNames.isEmpty();
}

public boolean hasNextNode() {
return nextNodeNames != null && !nextNodeNames.isEmpty();
}

public void removeNextNode(String name) {
if (this.nextNodeNames != null)
this.nextNodeNames.remove(name);
}

public void addNextNodes(Set<String> next) {
public void addNextNodes(Collection<String> next) {
if (next == null)
return;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.nop.core.model.graph;

import io.nop.api.core.util.FutureHelper;
import io.nop.commons.concurrent.executor.GlobalExecutors;
import io.nop.commons.concurrent.thread.ThreadHelper;
import io.nop.commons.util.MathHelper;
import org.junit.jupiter.api.Test;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class TestTaskExecutionGraph {
@Test
public void testRun() {
AtomicInteger count = new AtomicInteger();
TaskExecutionGraph graph = new TaskExecutionGraph("test");
Runnable task = () -> {
ThreadHelper.sleep(MathHelper.random().nextInt(100));
count.incrementAndGet();
};
graph.addTask("a", task);
graph.addTask("b", task);
graph.addTask("c", task);
graph.addDepend("b", "a");
graph.addDepend("c", "b");
graph.addDepend("a", "c");
graph.analyze();

CompletableFuture<?> future = graph.runOnExecutor(GlobalExecutors.cachedThreadPool(), null);
FutureHelper.syncGet(future);
assertEquals(3, count.get());
}
}

0 comments on commit ea8b302

Please sign in to comment.