diff --git a/core/shared/src/main/scala/fs2/Pull.scala b/core/shared/src/main/scala/fs2/Pull.scala
index c583263b15..6a241b8fb5 100644
--- a/core/shared/src/main/scala/fs2/Pull.scala
+++ b/core/shared/src/main/scala/fs2/Pull.scala
@@ -464,10 +464,7 @@ object Pull extends PullLowPriority {
   def extendScopeTo[F[_], O](
       s: Stream[F, O]
   )(implicit F: MonadError[F, Throwable]): Pull[F, Nothing, Stream[F, O]] =
-    for {
-      scope <- Pull.getScope[F]
-      lease <- Pull.eval(scope.lease)
-    } yield s.onFinalize(lease.cancel.redeemWith(F.raiseError(_), _ => F.unit))
+    Pull.getScope[F].map(scope => Stream.bracket(scope.lease)(_.cancel.rethrow) *> s)
 
   /** Repeatedly uses the output of the pull as input for the next step of the
     * pull. Halts when a step terminates with `None` or `Pull.raiseError`.
diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala
index 0b052f2c5f..7f22372608 100644
--- a/core/shared/src/main/scala/fs2/Stream.scala
+++ b/core/shared/src/main/scala/fs2/Stream.scala
@@ -238,19 +238,19 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
     */
   def broadcastThrough[F2[x] >: F[x]: Concurrent, O2](pipes: Pipe[F2, O, O2]*): Stream[F2, O2] = {
     assert(pipes.nonEmpty, s"pipes should not be empty")
-    underlying.uncons.flatMap {
-      case Some((hd, tl)) =>
+    extendScopeThrough { source =>
+      Stream.force {
         for {
           // topic: contains the chunk that the pipes are processing at one point.
           // until and unless all pipes are finished with it, won't move to next one
-          topic <- Pull.eval(Topic[F2, Chunk[O]])
+          topic <- Topic[F2, Chunk[O]]
           // Coordination: neither the producer nor any consumer starts
           // until and unless all consumers are subscribed to topic.
-          allReady <- Pull.eval(CountDownLatch[F2](pipes.length))
-
-          checkIn = allReady.release >> allReady.await
+          allReady <- CountDownLatch[F2](pipes.length)
+        } yield {
+          val checkIn = allReady.release >> allReady.await
 
-          dump = (pipe: Pipe[F2, O, O2]) =>
+          def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] =
             Stream.resource(topic.subscribeAwait(1)).flatMap { sub =>
               // Wait until all pipes are ready before consuming.
               // Crucial: checkin is not passed to the pipe,
@@ -258,19 +258,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
               Stream.exec(checkIn) ++ pipe(sub.unchunks)
             }
 
-          dumpAll: Stream[F2, O2] <-
-            Pull.extendScopeTo(Stream(pipes: _*).map(dump).parJoinUnbounded)
-
-          chunksStream = Stream.chunk(hd).append(tl.stream).chunks
-
+          val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded
           // Wait until all pipes are checked in before pulling
-          pump = Stream.exec(allReady.await) ++ topic.publish(chunksStream)
-
-          _ <- dumpAll.concurrently(pump).underlying
-        } yield ()
-
-      case None => Pull.done
-    }.stream
+          val pump = Stream.exec(allReady.await) ++ topic.publish(source.chunks)
+          dumpAll.concurrently(pump)
+        }
+      }
+    }
   }
 
   /** Behaves like the identity function, but requests `n` elements at a time from the input.
@@ -548,6 +542,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
   )(implicit F: Concurrent[F2]): Stream[F2, O] =
     concurrentlyAux(that).flatMap { case (startBack, fore) => startBack >> fore }
 
+  def concurrentlyExtendingThatScope[F2[x] >: F[x], O2](
+      that: Stream[F2, O2]
+  )(implicit F: Concurrent[F2]): Stream[F2, O] =
+    that.extendScopeThrough(that =>
+      concurrentlyAux(that).flatMap { case (startBack, fore) => startBack >> fore }
+    )
+
   private def concurrentlyAux[F2[x] >: F[x], O2](
       that: Stream[F2, O2]
   )(implicit
@@ -2331,75 +2332,65 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
       channel: F2[Channel[F2, F2[Either[Throwable, O2]]]],
       isOrdered: Boolean,
       f: O => F2[O2]
-  )(implicit F: Concurrent[F2]): Stream[F2, O2] = {
-    val action =
-      (
-        Semaphore[F2](concurrency),
-        channel,
-        Deferred[F2, Unit],
-        Deferred[F2, Unit]
-      ).mapN { (semaphore, channel, stop, end) =>
-        def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = {
-          def ordered: F2[Either[Throwable, O2] => F2[Unit]] = {
-            def send(v: Deferred[F2, Either[Throwable, O2]]) =
-              (el: Either[Throwable, O2]) => v.complete(el).void
-
-            Deferred[F2, Either[Throwable, O2]]
-              .flatTap(value => channel.send(release *> value.get))
-              .map(send)
-          }
+  )(implicit F: Concurrent[F2]): Stream[F2, O2] =
+    extendScopeThrough { source =>
+      Stream.force {
+        (
+          Semaphore[F2](concurrency),
+          channel,
+          Deferred[F2, Unit],
+          Deferred[F2, Unit]
+        ).mapN { (semaphore, channel, stop, end) =>
+          def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = {
+            def ordered: F2[Either[Throwable, O2] => F2[Unit]] = {
+              def send(v: Deferred[F2, Either[Throwable, O2]]) =
+                (el: Either[Throwable, O2]) => v.complete(el).void
+
+              Deferred[F2, Either[Throwable, O2]]
+                .flatTap(value => channel.send(release *> value.get))
+                .map(send)
+            }
 
-          def unordered: Either[Throwable, O2] => F2[Unit] =
-            (el: Either[Throwable, O2]) => release <* channel.send(F.pure(el))
+            def unordered: Either[Throwable, O2] => F2[Unit] =
+              (el: Either[Throwable, O2]) => release <* channel.send(F.pure(el))
 
-          if (isOrdered) ordered else F.pure(unordered)
-        }
-
-        val releaseAndCheckCompletion =
-          semaphore.release *>
-            semaphore.available.flatMap {
-              case `concurrency` => channel.close *> end.complete(()).void
-              case _ => F.unit
-            }
+            if (isOrdered) ordered else F.pure(unordered)
+          }
 
-        def forkOnElem(el: O): F2[Unit] =
-          F.uncancelable { poll =>
-            poll(semaphore.acquire) <*
-              Deferred[F2, Unit].flatMap { pushed =>
-                val init = initFork(pushed.complete(()).void)
-                poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
-                  val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get
-                  F.start(stop.get.race(action) *> releaseAndCheckCompletion)
-                }
+          val releaseAndCheckCompletion =
+            semaphore.release *>
+              semaphore.available.flatMap {
+                case `concurrency` => channel.close *> end.complete(()).void
+                case _ => F.unit
               }
-          }
-
-        underlying.uncons.flatMap {
-          case Some((hd, tl)) =>
-            for {
-              foreground <- Pull.extendScopeTo(
-                channel.stream.evalMap(_.rethrow).onFinalize(stop.complete(()) *> end.get)
-              )
-              background = Stream
-                .exec(semaphore.acquire) ++
-                Stream
-                  .chunk(hd)
-                  .append(tl.stream)
-                  .interruptWhen(stop.get.map(_.asRight[Throwable]))
-                  .foreach(forkOnElem)
-                  .onFinalizeCase {
-                    case ExitCase.Succeeded => releaseAndCheckCompletion
-                    case _ => stop.complete(()) *> releaseAndCheckCompletion
+
+          def forkOnElem(el: O): F2[Unit] =
+            F.uncancelable { poll =>
+              poll(semaphore.acquire) <*
+                Deferred[F2, Unit].flatMap { pushed =>
+                  val init = initFork(pushed.complete(()).void)
+                  poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
+                    val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get
+                    F.start(stop.get.race(action) *> releaseAndCheckCompletion)
                   }
-              _ <- foreground.concurrently(background).underlying
-            } yield ()
+                }
+            }
 
-          case None => Pull.done
-        }.stream
-      }
+          val background =
+            Stream.exec(semaphore.acquire) ++
+              source
+                .interruptWhen(stop.get.map(_.asRight[Throwable]))
+                .foreach(forkOnElem)
+                .onFinalizeCase {
+                  case ExitCase.Succeeded => releaseAndCheckCompletion
+                  case _ => stop.complete(()) *> releaseAndCheckCompletion
+                }
 
-    Stream.force(action)
-  }
+          val foreground = channel.stream.evalMap(_.rethrow)
+          foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background)
+        }
+      }
+    }
 
   /** Concurrent zip.
     *
@@ -2474,12 +2465,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
     */
   def prefetchN[F2[x] >: F[x]: Concurrent](
       n: Int
-  ): Stream[F2, O] =
+  ): Stream[F2, O] = extendScopeThrough { source =>
     Stream.eval(Channel.bounded[F2, Chunk[O]](n)).flatMap { chan =>
       chan.stream.unchunks.concurrently {
-        chunks.through(chan.sendAll)
+        source.chunks.through(chan.sendAll)
      }
    }
+  }
 
   /** Prints each element of this stream to standard out, converting each element to a `String` via `Show`. */
   def printlns[F2[x] >: F[x], O2 >: O](implicit
@@ -2940,6 +2932,23 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
   )(f: (Stream[F, O], Stream[F2, O2]) => Stream[F2, O3]): Stream[F2, O3] =
     f(this, s2)
 
+  /** Transforms this stream, explicitly extending the current scope through the given pipe.
+    *
+    * Use this when implementing a pipe where the resulting stream is not directly constructed from
+    * the source stream, e.g. when sending the source stream through a Channel and returning the
+    * channel's stream.
+    */
+  def extendScopeThrough[F2[x] >: F[x], O2](
+      f: Stream[F, O] => Stream[F2, O2]
+  )(implicit F: MonadError[F2, Throwable]): Stream[F2, O2] =
+    this.pull.peek
+      .flatMap {
+        case Some((_, tl)) => Pull.extendScopeTo(f(tl))
+        case None => Pull.extendScopeTo(f(Stream.empty))
+      }
+      .flatMap(_.underlying)
+      .stream
+
   /** Fails this stream with a `TimeoutException` if it does not complete within given `timeout`. */
   def timeout[F2[x] >: F[x]: Temporal](
       timeout: FiniteDuration
diff --git a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
index d7cc21d93b..fe27fc5492 100644
--- a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
+++ b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
@@ -1328,6 +1328,20 @@ class StreamCombinatorsSuite extends Fs2Suite {
         )
         .assertEquals(4.seconds)
     }
+
+    test("scope propagation") {
+      Deferred[IO, Unit]
+        .flatMap { d =>
+          Stream
+            .bracket(IO.unit)(_ => d.complete(()).void)
+            .prefetch
+            .evalMap(_ => IO.sleep(1.second) >> d.complete(()))
+            .timeout(5.seconds)
+            .compile
+            .last
+        }
+        .assertEquals(Some(true))
+    }
   }
 
   test("range") {
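
Usage sketch, assuming `extendScopeThrough` lands as defined above: a user-written, channel-based pipe can apply the same pattern as the `prefetchN` change in this diff. The pipe name `throughChannel` and its shape are hypothetical; only `extendScopeThrough`, `Channel`, and the other fs2/cats-effect calls come from this diff or the published fs2 API.

import cats.effect.IO
import fs2.{Chunk, Stream}
import fs2.concurrent.Channel

// Hypothetical user-land pipe: route the source through a bounded Channel.
// The stream returned to the caller is chan.stream, not a direct
// transformation of `in`, so the source's scope is explicitly extended
// through the pipe with extendScopeThrough.
def throughChannel[O](n: Int)(in: Stream[IO, O]): Stream[IO, O] =
  in.extendScopeThrough { source =>
    Stream.eval(Channel.bounded[IO, Chunk[O]](n)).flatMap { chan =>
      chan.stream.unchunks.concurrently(source.chunks.through(chan.sendAll))
    }
  }

Without the `extendScopeThrough` wrapper, a resource opened upstream (e.g. via `Stream.bracket`) could be finalized before the channel's consumer has seen all of its output, which is the situation exercised by the new "scope propagation" test above.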