diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java index d2d3293182..129b7ac380 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java @@ -786,7 +786,6 @@ public final Publisher onErrorResume(Predicate predicate, * return results; * } * @param mapper Convert each item emitted by this {@link Publisher} into another {@link Publisher}. - * each mapped {@link Publisher}. * @param The type of mapped {@link Publisher}. * @return A new {@link Publisher} which flattens the emissions from all mapped {@link Publisher}s. * @see ReactiveX flatMap operator. @@ -871,7 +870,6 @@ public final Publisher flatMapMerge(Function * @param mapper Convert each item emitted by this {@link Publisher} into another {@link Publisher}. - * each mapped {@link Publisher}. * @param The type of mapped {@link Publisher}. * @return A new {@link Publisher} which flattens the emissions from all mapped {@link Publisher}s. * @see ReactiveX flatMap operator. @@ -1618,6 +1616,86 @@ public final Publisher flatMapConcatIterable(Function(this, mapper); } + /** + * Return a {@link Publisher} that will switch to the latest {@link Publisher} emitted from {@code mapper} and the + * prior {@link Publisher} will be cancelled. Both upstream and the last switched {@link Publisher} must complete + * before the returned {@link Publisher} completes. If either upstream or the currently active {@link Publisher} + * terminate in error the returned {@link Publisher} is terminated with that error. + *
{@code
+     *     ExecutorService e = ...;
+     *     List>> futures = ...; // assume this is thread safe
+     *
+     *     for (T t : resultOfThisPublisher()) {
+     *         // Note that flatMap process results in parallel.
+     *         futures.add(e.submit(() -> {
+     *             // Approximation: control flow is simplified here but when a later mapper is applied any incomplete
+     *             // results from a previous mapper are cancelled and result in empty results.
+     *             return mapper.apply(t); // Asynchronous result is flatten into a value by this operator.
+     *         }));
+     *     }
+     *     List results = new ArrayList<>(futures.size());
+     *     // This is an approximation, this operator does not provide any ordering guarantees for the results.
+     *     for (Future> future : futures) {
+     *         List rList = future.get(); // Throws if the processing for this item failed.
+     *         results.addAll(rList);
+     *     }
+     *     return results;
+     * }
+ * @param mapper Convert each item emitted by this {@link Publisher} into another {@link Publisher}. + * @param The type of mapped {@link Publisher}. + * @return A {@link Publisher} that will switch to the latest {@link Publisher} emitted from {@code mapper} and the + * prior {@link Publisher} will be cancelled. + * @see ReactiveX switch operator. + * @see +Kotlin flatMapLatest + * @see #switchMapDelayError(Function) + */ + public final Publisher switchMap(Function> mapper) { + return new PublisherSwitchMap<>(this, 0, mapper); + } + + /** + * Return a {@link Publisher} that will switch to the latest {@link Publisher} emitted from {@code mapper} and the + * prior {@link Publisher} will be cancelled. Both upstream and the last switched {@link Publisher} must terminate + * before the returned {@link Publisher} terminates (including errors). + * @param mapper Convert each item emitted by this {@link Publisher} into another {@link Publisher}. + * @param The type of mapped {@link Publisher}. + * @return A {@link Publisher} that will switch to the latest {@link Publisher} emitted from {@code mapper} and the + * prior {@link Publisher} will be cancelled. + * @see ReactiveX switch operator. + * @see +Kotlin flatMapLatest + * @see #switchMap(Function) + * @see #switchMapDelayError(Function, int) + */ + public final Publisher switchMapDelayError(Function> mapper) { + return new PublisherSwitchMap<>(this, true, mapper); + } + + /** + * Return a {@link Publisher} that will switch to the latest {@link Publisher} emitted from {@code mapper} and the + * prior {@link Publisher} will be cancelled. Both upstream and the last switched {@link Publisher} must terminate + * before the returned {@link Publisher} terminates (including errors). + * @param mapper Convert each item emitted by this {@link Publisher} into another {@link Publisher}. + * @param maxDelayedErrorsHint The maximum amount of errors that will be queued. After this point exceptions maybe + * discarded to reduce memory consumption. + * @param The type of mapped {@link Publisher}. + * @return A {@link Publisher} that will switch to the latest {@link Publisher} emitted from {@code mapper} and the + * prior {@link Publisher} will be cancelled. + * @see ReactiveX switch operator. + * @see +Kotlin flatMapLatest + * @see #switchMap(Function) + * @see #switchMapDelayError(Function) + */ + public final Publisher switchMapDelayError(Function> mapper, + int maxDelayedErrorsHint) { + return new PublisherSwitchMap<>(this, maxDelayedErrorsHint, mapper); + } + /** * Merge two {@link Publisher}s together. There is no guaranteed ordering of events emitted from the returned * {@link Publisher}. diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherSwitchMap.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherSwitchMap.java new file mode 100644 index 0000000000..75a0d827a9 --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherSwitchMap.java @@ -0,0 +1,387 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.internal.TerminalNotification; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.CompositeExceptionUtils.addPendingError; +import static io.servicetalk.concurrent.api.CompositeExceptionUtils.maxDelayedErrors; +import static io.servicetalk.concurrent.api.SourceAdapters.toSource; +import static io.servicetalk.concurrent.internal.EmptySubscriptions.EMPTY_SUBSCRIPTION; +import static io.servicetalk.concurrent.internal.TerminalNotification.complete; +import static io.servicetalk.concurrent.internal.TerminalNotification.error; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.atomic.AtomicIntegerFieldUpdater.newUpdater; + +final class PublisherSwitchMap extends AbstractAsynchronousPublisherOperator { + private static final Logger LOGGER = LoggerFactory.getLogger(PublisherSwitchMap.class); + private final int maxDelayedErrors; + private final Function> mapper; + + PublisherSwitchMap(final Publisher original, + final boolean delayError, + final Function> mapper) { + this(original, maxDelayedErrors(delayError), mapper); + } + + PublisherSwitchMap(final Publisher original, + final int maxDelayedErrors, + final Function> mapper) { + super(original); + if (maxDelayedErrors < 0) { + throw new IllegalArgumentException("maxDelayedErrors: " + maxDelayedErrors + " (expected >=0)"); + } + this.maxDelayedErrors = maxDelayedErrors; + this.mapper = requireNonNull(mapper); + } + + @Override + public Subscriber apply(final Subscriber subscriber) { + return new SwitchSubscriber<>(subscriber, this); + } + + private static final class SwitchSubscriber implements Subscriber { + @SuppressWarnings("rawtypes") + private static final AtomicIntegerFieldUpdater stateUpdater = + newUpdater(SwitchSubscriber.RSubscriber.class, "state"); + @SuppressWarnings("rawtypes") + private static final AtomicIntegerFieldUpdater pendingErrorCountUpdater = + AtomicIntegerFieldUpdater.newUpdater(SwitchSubscriber.class, "pendingErrorCount"); + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater pendingErrorUpdater = + AtomicReferenceFieldUpdater.newUpdater(SwitchSubscriber.class, Throwable.class, "pendingError"); + private static final int INNER_STATE_IDLE = 0; + private static final int INNER_STATE_EMITTING = 1; + private static final int INNER_STATE_DISPOSED = 2; + private static final int INNER_STATE_COMPLETE = 3; + private static final int INNER_STATE_TERMINATED = 4; + private static final int OUTER_STATE_SHIFT = 3; + private static final int OUTER_STATE_MASK = -8; + private static final int INNER_STATE_MASK = ~OUTER_STATE_MASK; + private static final int OUTER_STATE_COMPLETE = 1; + private static final int OUTER_STATE_TERMINATED = 2; + private final SequentialSubscription rSubscription = new SequentialSubscription(); + private final PublisherSwitchMap parent; + private final Subscriber target; + @Nullable + private Subscription tSubscription; + @Nullable + private RSubscriber currPublisher; + @SuppressWarnings("unused") + private volatile int pendingErrorCount; + @Nullable + private volatile Throwable pendingError; + + private SwitchSubscriber(final Subscriber target, + final PublisherSwitchMap parent) { + this.target = target; + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription subscription) { + tSubscription = requireNonNull(subscription); + target.onSubscribe(rSubscription); + tSubscription.request(1); + } + + @Override + public void onNext(@Nullable T t) { + final Publisher nextPub = parent.mapper.apply(t); + if (nextPub == null) { + assert tSubscription != null; + tSubscription.request(1); + return; + } + currPublisher = new RSubscriber(currPublisher); + toSource(nextPub).subscribe(currPublisher); + } + + @Override + public void onError(Throwable t) { + if (currPublisher != null) { + try { + if (parent.maxDelayedErrors <= 0) { + currPublisher.dispose(EMPTY_SUBSCRIPTION); + } + } finally { + final Throwable cause = outerErrorUpdateState(t); + if (cause != null) { + target.onError(cause); + } + } + } else { + target.onError(t); + } + } + + @Override + public void onComplete() { + // If current publisher isn't null defer terminal signals to that publisher, otherwise terminate here. + TerminalNotification terminalNotification = complete(); + if (currPublisher == null || (terminalNotification = outerCompleteUpdateState()) != null) { + terminalNotification.terminate(target); + } + } + + @Nullable + private Throwable outerErrorUpdateState(Throwable t) { + assert currPublisher != null; + t = localAddPendingError(t); + for (;;) { + final int cState = currPublisher.state; + final int nState = setOuterState(cState, OUTER_STATE_TERMINATED); + if (stateUpdater.compareAndSet(currPublisher, cState, nState)) { + final int innerState = getInnerState(nState); + return parent.maxDelayedErrors <= 0 && innerState != INNER_STATE_TERMINATED || + (parent.maxDelayedErrors > 0 && + (innerState == INNER_STATE_TERMINATED || innerState == INNER_STATE_COMPLETE)) ? t : null; + } + } + } + + @Nullable + private TerminalNotification outerCompleteUpdateState() { + assert currPublisher != null; + for (;;) { + final int cState = currPublisher.state; + final int nState = setOuterState(cState, OUTER_STATE_COMPLETE); + if (stateUpdater.compareAndSet(currPublisher, cState, nState)) { + final int innerState = getInnerState(nState); + if (parent.maxDelayedErrors <= 0) { + return innerState == INNER_STATE_COMPLETE ? complete() : null; + } else if (innerState == INNER_STATE_TERMINATED || innerState == INNER_STATE_COMPLETE) { + final Throwable cPendingError = pendingError; + return cPendingError != null ? error(cPendingError) : complete(); + } + return null; + } + } + } + + private Throwable localAddPendingError(Throwable t) { + Throwable currPendingError = pendingError; + if (currPendingError == null) { + if (pendingErrorUpdater.compareAndSet(SwitchSubscriber.this, null, t)) { + currPendingError = t; + } else { + currPendingError = pendingError; + assert currPendingError != null; + addPendingError(pendingErrorCountUpdater, SwitchSubscriber.this, + parent.maxDelayedErrors, currPendingError, t); + } + } else { + addPendingError(pendingErrorCountUpdater, SwitchSubscriber.this, + parent.maxDelayedErrors, currPendingError, t); + } + return currPendingError; + } + + private static int setOuterState(int currState, int newState) { + return (newState << OUTER_STATE_SHIFT) | (currState & INNER_STATE_MASK); + } + + private static int setInnerState(int currState, int newState) { + return (currState & OUTER_STATE_MASK) | newState; + } + + private static int getOuterState(int state) { + return state >> OUTER_STATE_SHIFT; + } + + private static int getInnerState(int state) { + return state & INNER_STATE_MASK; + } + + private final class RSubscriber implements Subscriber { + volatile int state; + @Nullable + private final RSubscriber prevPublisher; + @Nullable + private Subscription localSubscription; + @Nullable + private Subscription nextSubscriptionIfDisposePending; + + private RSubscriber(@Nullable RSubscriber prevPublisher) { + this.prevPublisher = prevPublisher; + } + + @Override + public void onSubscribe(Subscription subscription) { + localSubscription = requireNonNull(subscription); + if (prevPublisher != null) { + prevPublisher.dispose(subscription); + } else { + switchTo(subscription); + } + } + + @Override + public void onNext(@Nullable R result) { + int innerState; + for (;;) { + final int cState = state; + innerState = getInnerState(cState); + final int outerState = getOuterState(cState); + if (outerState == OUTER_STATE_TERMINATED) { + return; + } else if (innerState == INNER_STATE_IDLE) { + if (stateUpdater.compareAndSet(this, cState, setInnerState(cState, INNER_STATE_EMITTING))) { + break; + } + } else if (innerState == INNER_STATE_EMITTING) { + // Allow reentry because we don't want to drop data. + break; + } else { + LOGGER.debug("Disposed Subscriber ignoring signal state={} subscriber='{}' onNext='{}'", + cState, SwitchSubscriber.this, result); + // Only states are TERMINATED and DISPOSED. + // DISPOSED -> Subscriber is no longer the newest subscriber, and it is OK to drop data + // because the "newest"/"active" Subscriber is assumed to get the "current" state as the + // first onNext signal and indicated a "switch" and downstream will do a delta between "old" + // and "current" state. + // TERMINATED -> This is a terminal state, and re-try/subscribe must happen to reset state. + + // It is OK to not call dataSubscription.itemReceived() because we won't be propagating it + // downstream when we switch to a newer subscriber we want to request more items to preserve + // the demand from downstream Subscription. + return; + } + } + try { + rSubscription.itemReceived(); + target.onNext(result); + } finally { + // Only attempt "unlock" if we acquired the lock, otherwise this is reentry and when the stack + // unwinds the lock will be released. + if (innerState == INNER_STATE_IDLE) { + for (;;) { + final int cState = state; + innerState = getInnerState(cState); + if (innerState == INNER_STATE_DISPOSED) { + assert nextSubscriptionIfDisposePending != null; + assert localSubscription != null; + switchWhenDisposed(localSubscription, nextSubscriptionIfDisposePending); + } else if (innerState == INNER_STATE_TERMINATED || innerState == INNER_STATE_COMPLETE || + stateUpdater.compareAndSet(this, cState, setInnerState(cState, INNER_STATE_IDLE))) { + break; + } + } + } + } + } + + @Override + public void onError(Throwable t) { + Throwable currPendingError = null; + for (;;) { + final int cState = state; + final int innerState = getInnerState(cState); + if (innerState == INNER_STATE_DISPOSED) { + break; + } else if (parent.maxDelayedErrors <= 0) { + if (stateUpdater.compareAndSet(this, cState, + setInnerState(cState, INNER_STATE_TERMINATED))) { + final int outerState = getOuterState(cState); + if (outerState != OUTER_STATE_TERMINATED) { + target.onError(t); + } + break; + } + } else { + if (currPendingError == null) { + currPendingError = localAddPendingError(t); + } + + if (stateUpdater.compareAndSet(this, cState, + setInnerState(cState, INNER_STATE_TERMINATED))) { + final int outerState = getOuterState(cState); + if (outerState == OUTER_STATE_TERMINATED || outerState == OUTER_STATE_COMPLETE) { + target.onError(currPendingError); + } + break; + } + } + } + } + + @Override + public void onComplete() { + for (;;) { + final int cState = state; + final int innerState = getInnerState(cState); + if (innerState == INNER_STATE_DISPOSED) { + break; + } else if (stateUpdater.compareAndSet(this, cState, setInnerState(cState, INNER_STATE_COMPLETE))) { + final int outerState = getOuterState(cState); + if (outerState == OUTER_STATE_COMPLETE) { + target.onComplete(); + } else if (parent.maxDelayedErrors > 0 && outerState == OUTER_STATE_TERMINATED) { + final Throwable cause = pendingError; + assert cause != null; + target.onError(cause); + } + break; + } + } + } + + void dispose(Subscription nextSubscription) { + nextSubscriptionIfDisposePending = nextSubscription; + for (;;) { + final int cState = state; + final int innerState = getInnerState(cState); + if (innerState == INNER_STATE_TERMINATED || innerState == INNER_STATE_COMPLETE) { + break; + } else if (stateUpdater.compareAndSet(this, cState, setInnerState(cState, INNER_STATE_DISPOSED))) { + // if emitting -> onNext will handle after it is done to avoid concurrency + // if disposed -> no need to switch again + if (innerState == INNER_STATE_IDLE) { + assert localSubscription != null; + switchWhenDisposed(localSubscription, nextSubscription); + } + break; + } + } + } + + private void switchWhenDisposed(Subscription mySubscription, Subscription nextSubscription) { + try { + mySubscription.cancel(); + } finally { + switchTo(nextSubscription); + } + } + + private void switchTo(Subscription nextSubscription) { + try { + rSubscription.switchTo(nextSubscription); + } finally { + assert tSubscription != null; + tSubscription.request(1); + } + } + } + } +} diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherSwitchMapTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherSwitchMapTest.java new file mode 100644 index 0000000000..74207d75b2 --- /dev/null +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherSwitchMapTest.java @@ -0,0 +1,429 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.PublisherSource; +import io.servicetalk.concurrent.internal.DeliberateException; +import io.servicetalk.concurrent.test.internal.TestPublisherSubscriber; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Supplier; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.Publisher.never; +import static io.servicetalk.concurrent.api.PublisherSwitchMapTest.SwitchMapSignal.toSwitchFunction; +import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; +import static io.servicetalk.concurrent.api.SourceAdapters.toSource; +import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; +import static java.util.Arrays.asList; +import static java.util.Objects.requireNonNull; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +final class PublisherSwitchMapTest { + private final TestSubscription testSubscription = new TestSubscription(); + private final TestPublisher publisher = new TestPublisher.Builder() + .disableAutoOnSubscribe().build(sub -> { + sub.onSubscribe(testSubscription); + return sub; + }); + private final TestPublisherSubscriber> subscriber = + new TestPublisherSubscriber<>(); + + @ParameterizedTest(name = "onError={0}, delayError={1}") + @CsvSource({"true,true", "true,false", "false,true", "false,false"}) + void noSignalsTerminal(boolean onError, boolean delayError) { + final TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); + toSource((delayError ? publisher.switchMapDelayError(i -> never()) : publisher.switchMap(i -> never()))) + .subscribe(subscriber); + + validateTerminal(publisher, subscriber, onError); + } + + @ParameterizedTest(name = "onError={0}, delayError={1}") + @CsvSource({"true,true", "true,false", "false,true", "false,false"}) + void noSwitchMultipleSignals(boolean onError, boolean delayError) throws InterruptedException { + final TestSubscription testSubscription2 = new TestSubscription(); + final TestPublisher publisher2 = new TestPublisher.Builder() + .disableAutoOnSubscribe().build(sub -> { + sub.onSubscribe(testSubscription2); + return sub; + }); + + final int firstT = 0; + final String firstR = "foo"; + final String secondR = "bar"; + Function>> func = + toSwitchFunction(i -> i == firstT ? publisher2 : never()); + toSource(delayError ? publisher.switchMapDelayError(func) : publisher.switchMap(func)) + .subscribe(subscriber); + + subscriber.awaitSubscription().request(2); + testSubscription.awaitRequestN(1); + publisher.onNext(firstT); + publisher.onComplete(); + + assertThat(subscriber.pollOnNext(10, TimeUnit.MILLISECONDS), nullValue()); + + testSubscription2.awaitRequestN(1); + publisher2.onNext(firstR); + validateSignal(subscriber.takeOnNext(), firstR, false); + + testSubscription2.awaitRequestN(2); + publisher2.onNext(secondR); + validateSignal(subscriber.takeOnNext(), secondR, false); + + validateTerminal(publisher2, subscriber, onError); + } + + @ParameterizedTest(name = "onError={0}, delayError={1}") + @CsvSource({"true,true", "true,false", "false,true", "false,false"}) + void multipleSwitches(boolean onError, boolean delayError) throws InterruptedException { + final TestSubscription testSubscription2 = new TestSubscription(); + final TestPublisher publisher2 = new TestPublisher.Builder() + .disableAutoOnSubscribe().build(sub -> { + sub.onSubscribe(testSubscription2); + return sub; + }); + final TestSubscription testSubscription3 = new TestSubscription(); + final TestPublisher publisher3 = new TestPublisher.Builder() + .disableAutoOnSubscribe().build(sub -> { + sub.onSubscribe(testSubscription3); + return sub; + }); + final TestSubscription testSubscription4 = new TestSubscription(); + final TestPublisher publisher4 = new TestPublisher.Builder() + .disableAutoOnSubscribe().build(sub -> { + sub.onSubscribe(testSubscription4); + return sub; + }); + + final int firstT = 0; + final int secondT = 1; + final int thirdT = 2; + final String firstR = "foo"; + final String secondR = "bar"; + final String thirdR = "baz"; + final String ignoredR = "IGNORED"; + Function>> func = toSwitchFunction( + i -> i == firstT ? publisher2 : + i == secondT ? publisher3 : + i == thirdT ? publisher4 : never()); + toSource(delayError ? publisher.switchMapDelayError(func) : publisher.switchMap(func)) + .subscribe(subscriber); + + subscriber.awaitSubscription().request(3); + testSubscription.awaitRequestN(1); + publisher.onNext(firstT); + + testSubscription2.awaitRequestN(1); + publisher2.onNext(firstR); + validateSignal(subscriber.takeOnNext(), firstR, false); + + testSubscription.awaitRequestN(2); + publisher.onNext(secondT); + + testSubscription2.awaitCancelled(); + testSubscription3.awaitRequestN(2); + + // Send signals on "old" publishers that we switched from an ignore all the signals. + publisher2.onNext(ignoredR); + publisher2.onError(new IllegalStateException("should be ignored")); + + // Don't emit any items, and assert that switch is still done + testSubscription.awaitRequestN(3); + publisher.onNext(thirdT); + publisher.onComplete(); + testSubscription3.awaitCancelled(); + testSubscription4.awaitRequestN(2); + + // Send signals on "old" publishers that we switched from an ignore all the signals. + publisher3.onNext(ignoredR); + publisher3.onError(new IllegalStateException("should be ignored")); + + publisher4.onNext(secondR, thirdR); + validateSignal(subscriber.takeOnNext(), secondR, true); + validateSignal(subscriber.takeOnNext(), thirdR, false); + + validateTerminal(publisher4, subscriber, onError); + } + + @SuppressWarnings("unchecked") + @ParameterizedTest(name = "offloadFirstDemand={0}, delayError={1}") + @CsvSource({"true,true", "true,false", "false,true", "false,false"}) + void reentry(boolean offloadFirstDemand, boolean delayError) throws InterruptedException, ExecutionException { + final String completeSignal = "complete"; + BlockingQueue signals = new LinkedBlockingQueue<>(); + Executor executor = Executors.newCachedThreadExecutor(); + try { + Function>> func = toSwitchFunction(i -> i == 1 ? + fromSource(new ReentryPublisher(100, 103)) : never()); + Publisher pub = Publisher.from(1); + toSource(delayError ? pub.switchMapDelayError(func) : pub.switchMap(func) + ).subscribe(new PublisherSource.Subscriber>() { + @Nullable + private PublisherSource.Subscription subscription; + private boolean seenOnNext; + + @Override + public void onSubscribe(PublisherSource.Subscription s) { + subscription = s; + subscription.request(1); + } + + @Override + public void onNext(@Nullable SwitchMapSignal next) { + assert subscription != null; + signals.add(requireNonNull(next)); + final boolean localSeenOnNext = seenOnNext; + seenOnNext = true; + if (localSeenOnNext || !offloadFirstDemand) { + subscription.request(1); + } else { + // SequentialSubscription will prevent reentry from onNext when invoked from switchTo, so we + // offload here just to be sure reentry cases are covered. + executor.execute(() -> { + try { + // If this task executes quickly the Publisher may see demand in its loop from + // the request(n) in onSubscribe without triggering reentry from onNext. + Thread.sleep(100); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + subscription.request(1); + }); + } + } + + @Override + public void onError(Throwable t) { + signals.add(t); + } + + @Override + public void onComplete() { + signals.add(completeSignal); + } + }); + + Object signal = signals.take(); + assertThat(signal, instanceOf(SwitchMapSignal.class)); + validateSignal((SwitchMapSignal) signal, 100, false); + signal = signals.take(); + assertThat(signal, instanceOf(SwitchMapSignal.class)); + validateSignal((SwitchMapSignal) signal, 101, false); + signal = signals.take(); + assertThat(signal, instanceOf(SwitchMapSignal.class)); + validateSignal((SwitchMapSignal) signal, 102, false); + signal = signals.take(); + assertThat(signal, is(completeSignal)); + } finally { + executor.closeAsync().toFuture().get(); + } + } + + @ParameterizedTest(name = "delayError={0}") + @ValueSource(booleans = {true, false}) + void nullPublisher(boolean delayError) throws InterruptedException { + Function>> func = + toSwitchFunction(i -> i == 0 ? null : Publisher.from("foo")); + toSource(delayError ? publisher.switchMapDelayError(func) : publisher.switchMap(func)) + .subscribe(subscriber); + + subscriber.awaitSubscription().request(1); + testSubscription.awaitRequestN(1); + + publisher.onNext(0); + + testSubscription.awaitRequestN(2); + publisher.onNext(1); + + validateSignal(subscriber.takeOnNext(), "foo", false); + validateTerminal(publisher, subscriber, false); + } + + @ParameterizedTest(name = "delayError={0}, onComplete={1} rootErrorFirst={2}") + @CsvSource({"true,true,true", "true,false,true", "false,true,true", + "true,true,false", "true,false,false", "false,true,false"}) + void onErrorCancelMappedPublisher(boolean delayError, boolean onComplete, + boolean rootErrorFirst) throws InterruptedException { + final TestSubscription testSubscription2 = new TestSubscription(); + final TestPublisher publisher2 = new TestPublisher.Builder() + .disableAutoOnSubscribe().build(sub -> { + sub.onSubscribe(testSubscription2); + return sub; + }); + + Function>> func = toSwitchFunction(i -> publisher2); + toSource(delayError ? publisher.switchMapDelayError(func, 2) : publisher.switchMap(func)) + .subscribe(subscriber); + + subscriber.awaitSubscription().request(1); + testSubscription.awaitRequestN(1); + publisher.onNext(0); + + testSubscription2.awaitRequestN(1); + publisher2.onNext("foo"); + validateSignal(subscriber.takeOnNext(), "foo", false); + + if (delayError) { + DeliberateException deliberateException = new DeliberateException(); + if (rootErrorFirst) { + publisher.onError(deliberateException); + } + + Throwable secondCause = new IllegalStateException("second exception"); + if (onComplete) { + publisher2.onComplete(); + } else { + publisher2.onError(secondCause); + } + + if (rootErrorFirst) { + Throwable cause = subscriber.awaitOnError(); + assertThat(cause, is(deliberateException)); + if (!onComplete) { + assertThat(asList(cause.getSuppressed()), contains(secondCause)); + } + } else { + publisher.onError(deliberateException); + Throwable cause = subscriber.awaitOnError(); + if (onComplete) { + assertThat(cause, is(deliberateException)); + } else { + assertThat(cause, is(secondCause)); + assertThat(asList(cause.getSuppressed()), contains(deliberateException)); + } + } + assertThat(testSubscription2.isCancelled(), equalTo(false)); + } else { + publisher.onError(DELIBERATE_EXCEPTION); + testSubscription2.awaitCancelled(); + assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); + } + } + + private static void validateSignal(@Nullable SwitchMapSignal signal, R data, boolean isSwitched) { + assertThat(signal, notNullValue()); + assertThat(signal.isSwitched(), equalTo(isSwitched)); + assertThat(signal.data(), equalTo(data)); + } + + private static void validateTerminal(TestPublisher

publisher, TestPublisherSubscriber subscriber, + boolean onError) { + if (onError) { + publisher.onError(DELIBERATE_EXCEPTION); + assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); + } else { + publisher.onComplete(); + subscriber.awaitOnComplete(); + } + } + + /** + * A signal containing the data from a series of {@link Publisher}s connected in a serial fashion. + * @param Type of the data. + */ + interface SwitchMapSignal { + /** + * Returns {@code true} on the first signal from a new {@link Publisher}. + * @return {@code true} on the first signal from a new {@link Publisher}. + */ + boolean isSwitched(); + + /** + * Get the data. + * @return the data. + */ + @Nullable + T data(); + + /** + * Convert from a regular {@link Function} to a {@link Function} that emits {@link SwitchMapSignal}. + *

+ * This function has state, if used in an operator chain use {@link Publisher#defer(Supplier)} so the state is + * unique per each subscribe. + * @param function The original function to convert. + * @param The input data type. + * @param The resulting data type. + * @return a {@link Function} that emits {@link SwitchMapSignal}. + */ + static Function>> toSwitchFunction( + Function> function) { + return new Function>>() { + private boolean seenFirstPublisher; + + @Nullable + @Override + public Publisher> apply(T t) { + final Publisher rawPublisher = function.apply(t); + if (rawPublisher == null) { + return null; + } + final boolean localSeenFirstPublisher = seenFirstPublisher; + seenFirstPublisher = true; + if (localSeenFirstPublisher) { + return Publisher.defer(() -> { + final boolean[] seenOnNext = new boolean[1]; // modifiable boolean + return rawPublisher.map(res -> { + final boolean localSeenOnNext = seenOnNext[0]; + seenOnNext[0] = true; + return new SwitchMapSignal() { + @Override + public boolean isSwitched() { + return !localSeenOnNext; + } + + @Nullable + @Override + public R data() { + return res; + } + }; + }).shareContextOnSubscribe(); + }); + } else { + return rawPublisher.map(res -> new SwitchMapSignal() { + @Override + public boolean isSwitched() { + return false; + } + + @Nullable + @Override + public R data() { + return res; + } + }); + } + } + }; + } + } +} diff --git a/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherSwitchMapTckTest.java b/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherSwitchMapTckTest.java new file mode 100644 index 0000000000..5263f07979 --- /dev/null +++ b/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherSwitchMapTckTest.java @@ -0,0 +1,166 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.reactivestreams.tck; + +import io.servicetalk.concurrent.PublisherSource; +import io.servicetalk.concurrent.PublisherSource.Subscription; +import io.servicetalk.concurrent.api.Publisher; +import io.servicetalk.concurrent.api.PublisherOperator; +import io.servicetalk.concurrent.internal.DuplicateSubscribeException; +import io.servicetalk.concurrent.internal.FlowControlUtils; + +import org.testng.annotations.Test; + +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.Publisher.defer; +import static io.servicetalk.concurrent.api.Publisher.from; +import static io.servicetalk.concurrent.api.Publisher.never; +import static io.servicetalk.concurrent.internal.SubscriberUtils.deliverErrorFromSource; + +@Test +public class PublisherSwitchMapTckTest extends AbstractPublisherOperatorTckTest { + private boolean applyFlatMap = true; + + @Override + protected Publisher composePublisher(Publisher publisher, int elements) { + return defer(() -> { + final AtomicLong itemCount = new AtomicLong(); + // The test harness delivers items synchronously and in a reentrant fashion, so we delay upstream demand + // until items are delivered downstream as otherwise delivery from upstream will happen before we deliver + // downstream. + final SingleUpstreamDemandOperator demandOperator = new SingleUpstreamDemandOperator<>(); + return publisher.liftAsync(demandOperator) + .switchMap(i -> { + Publisher p = from(i); + // flatMapMerge optimistically requests upstream, and for tests that try to validate illegal + // demand we skip this operator to prevent the from(.) operator above from completing before + // the illegal demand is requested and therefore wouldn't be able to propagate the error. + if (applyFlatMap) { + p = p.flatMapMerge(x -> itemCount.incrementAndGet() >= elements ? + from(x) : from(x).concat(never())); + } + return p.afterOnNext(x -> demandOperator.subscriberRef.get().decrementDemand()); + }); + }); + } + + @Override + public void required_spec309_requestNegativeNumberMustSignalIllegalArgumentException() throws Throwable { + applyFlatMap = false; + try { + super.required_spec309_requestNegativeNumberMustSignalIllegalArgumentException(); + } finally { + applyFlatMap = true; + } + } + + @Override + public void required_spec309_requestZeroMustSignalIllegalArgumentException() throws Throwable { + applyFlatMap = false; + try { + super.required_spec309_requestZeroMustSignalIllegalArgumentException(); + } finally { + applyFlatMap = true; + } + } + + private static final class SingleUpstreamDemandOperator implements PublisherOperator { + final AtomicReference> subscriberRef = new AtomicReference<>(); + @Override + public PublisherSource.Subscriber apply(final PublisherSource.Subscriber subscriber) { + SingleUpstreamDemandSubscriber sub = new SingleUpstreamDemandSubscriber<>(subscriber); + if (subscriberRef.compareAndSet(null, sub)) { + return sub; + } else { + return new PublisherSource.Subscriber() { + @Override + public void onSubscribe(final Subscription subscription) { + deliverErrorFromSource(subscriber, + new DuplicateSubscribeException(subscriberRef.get(), subscriber)); + } + + @Override + public void onNext(@Nullable final T t) { + } + + @Override + public void onError(final Throwable t) { + } + + @Override + public void onComplete() { + } + }; + } + } + + private static final class SingleUpstreamDemandSubscriber implements PublisherSource.Subscriber { + private final AtomicLong demand = new AtomicLong(); + private final PublisherSource.Subscriber subscriber; + @Nullable + private Subscription subscription; + + SingleUpstreamDemandSubscriber(final PublisherSource.Subscriber subscriber) { + this.subscriber = subscriber; + } + + @Override + public void onSubscribe(final Subscription s) { + this.subscription = s; + subscriber.onSubscribe(new Subscription() { + @Override + public void request(final long n) { + if (n <= 0) { + subscription.request(n); + } else if (demand.getAndAccumulate(n, FlowControlUtils::addWithOverflowProtection) == 0) { + subscription.request(1); + } + } + + @Override + public void cancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(@Nullable final T t) { + subscriber.onNext(t); + } + + @Override + public void onError(final Throwable t) { + subscriber.onError(t); + } + + @Override + public void onComplete() { + subscriber.onComplete(); + } + + void decrementDemand() { + if (demand.decrementAndGet() > 0) { + assert subscription != null; + subscription.request(1); + } + } + } + } +}