Skip to content

Commit

Permalink
Ensure structured concurrency exits when all tasks have completed.
Browse files Browse the repository at this point in the history
  • Loading branch information
fluentfuture committed Jan 6, 2025
1 parent 899292d commit 1fb28ec
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 5 deletions.
41 changes: 41 additions & 0 deletions mug/src/main/java/com/google/mu/util/concurrent/Completion.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*****************************************************************************
* ------------------------------------------------------------------------- *
* Licensed under the Apache License, Version 2.0 (the "License"); *
* you may not use this file except in compliance with the License. *
* You may obtain a copy of the License at *
* *
* http://www.apache.org/licenses/LICENSE-2.0 *
* *
* Unless required by applicable law or agreed to in writing, software *
* distributed under the License is distributed on an "AS IS" BASIS, *
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
* See the License for the specific language governing permissions and *
* limitations under the License. *
*****************************************************************************/
package com.google.mu.util.concurrent;

import java.util.concurrent.Phaser;

/** Helper to ensure that all started tasks must have run to completion. */
final class Completion implements AutoCloseable {
private final Phaser phaser = new Phaser(1);

void run(Runnable task) {
wrap(task).run();
}

Runnable wrap(Runnable task) {
phaser.register();
return () -> {
try {
task.run();
} finally {
phaser.arrive();
}
};
}

@Override public void close() {
phaser.arriveAndAwaitAdvance();
}
}
8 changes: 5 additions & 3 deletions mug/src/main/java/com/google/mu/util/concurrent/Fanout.java
Original file line number Diff line number Diff line change
Expand Up @@ -481,16 +481,18 @@ Scope add(Runnable... tasks) {
}

void run() throws StructuredConcurrencyInterruptedException {
try {
withUnlimitedConcurrency().parallelize(runnables.stream());
try (Completion completion = new Completion()){
withUnlimitedConcurrency().parallelize(runnables.stream().map(completion::wrap));
} catch (InterruptedException e) {
throw new StructuredConcurrencyInterruptedException(e);
}
}

@Deprecated
void runUninterruptibly() {
withUnlimitedConcurrency().parallelizeUninterruptibly(runnables.stream());
try (Completion completion = new Completion()){
withUnlimitedConcurrency().parallelizeUninterruptibly(runnables.stream().map(completion::wrap));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,10 @@ public void parallelizeUninterruptibly(Stream<? extends Runnable> tasks) {
inputs -> {
List<O> outputs = new ArrayList<>(inputs.size());
outputs.addAll(Collections.nCopies(inputs.size(), null));
try {
try (Completion completion = new Completion()){
parallelize(
IntStream.range(0, inputs.size()).boxed(),
i -> outputs.set(i, concurrentFunction.apply(inputs.get(i))));
i -> completion.run(() -> outputs.set(i, concurrentFunction.apply(inputs.get(i)))));
} catch (InterruptedException e) {
throw new StructuredConcurrencyInterruptedException(e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.google.mu.util.concurrent;

import static com.google.common.truth.Truth.assertThat;

import java.util.concurrent.atomic.AtomicBoolean;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class CompletionTest {
@Test public void noTaskStarted() throws Exception {
try (Completion completion = new Completion()) {}
}

@Test public void singleTask_succeeded() throws Exception {
try (Completion completion = new Completion()) {
completion.run(() -> {});
}
}

@Test public void singleTask_failed() throws Exception {
try {
try (Completion completion = new Completion()) {
completion.run(() -> {
throw new RuntimeException("test");
});
}
} catch (RuntimeException e) {
assertThat(e).hasMessageThat().contains("test");
}
}

@Test public void twoTasks_succeeded() throws Exception {
AtomicBoolean done = new AtomicBoolean();
try (Completion completion = new Completion()) {
completion.run(() -> {});
new Thread(completion.wrap(() -> {
done.set(true);
})).start();
}
assertThat(done.get()).isTrue();
}

@Test public void twoTasks_failed() throws Exception {
AtomicBoolean done = new AtomicBoolean();
try (Completion completion = new Completion()) {
new Thread(completion.wrap(() -> {
done.set(true);
})).start();
completion.run(() -> {
throw new RuntimeException("test");
});
} catch (RuntimeException e) {
assertThat(e).hasMessageThat().contains("test");
}
assertThat(done.get()).isTrue();
}
}

0 comments on commit 1fb28ec

Please sign in to comment.