Skip to content

Commit

Permalink
Abstract scope extension with extendScopeThrough
Browse files Browse the repository at this point in the history
Make extendScopeTo cancellation safe (see typelevel#3474)
  • Loading branch information
reardonj committed Dec 26, 2024
1 parent 810af8a commit e229f82
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 86 deletions.
5 changes: 1 addition & 4 deletions core/shared/src/main/scala/fs2/Pull.scala
Original file line number Diff line number Diff line change
Expand Up @@ -464,10 +464,7 @@ object Pull extends PullLowPriority {
def extendScopeTo[F[_], O](
s: Stream[F, O]
)(implicit F: MonadError[F, Throwable]): Pull[F, Nothing, Stream[F, O]] =
for {
scope <- Pull.getScope[F]
lease <- Pull.eval(scope.lease)
} yield s.onFinalize(lease.cancel.redeemWith(F.raiseError(_), _ => F.unit))
Pull.getScope[F].map(scope => Stream.bracket(scope.lease)(_.cancel.rethrow) *> s)

/** Repeatedly uses the output of the pull as input for the next step of the
* pull. Halts when a step terminates with `None` or `Pull.raiseError`.
Expand Down
173 changes: 91 additions & 82 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -238,39 +238,33 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
*/
def broadcastThrough[F2[x] >: F[x]: Concurrent, O2](pipes: Pipe[F2, O, O2]*): Stream[F2, O2] = {
assert(pipes.nonEmpty, s"pipes should not be empty")
underlying.uncons.flatMap {
case Some((hd, tl)) =>
extendScopeThrough { source =>
Stream.force {
for {
// topic: contains the chunk that the pipes are processing at one point.
// until and unless all pipes are finished with it, won't move to next one
topic <- Pull.eval(Topic[F2, Chunk[O]])
topic <- Topic[F2, Chunk[O]]
// Coordination: neither the producer nor any consumer starts
// until and unless all consumers are subscribed to topic.
allReady <- Pull.eval(CountDownLatch[F2](pipes.length))

checkIn = allReady.release >> allReady.await
allReady <- CountDownLatch[F2](pipes.length)
} yield {
val checkIn = allReady.release >> allReady.await

dump = (pipe: Pipe[F2, O, O2]) =>
def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] =
Stream.resource(topic.subscribeAwait(1)).flatMap { sub =>
// Wait until all pipes are ready before consuming.
// Crucial: checkin is not passed to the pipe,
// so pipe cannot interrupt it and alter the latch count
Stream.exec(checkIn) ++ pipe(sub.unchunks)
}

dumpAll: Stream[F2, O2] <-
Pull.extendScopeTo(Stream(pipes: _*).map(dump).parJoinUnbounded)

chunksStream = Stream.chunk(hd).append(tl.stream).chunks

val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded
// Wait until all pipes are checked in before pulling
pump = Stream.exec(allReady.await) ++ topic.publish(chunksStream)

_ <- dumpAll.concurrently(pump).underlying
} yield ()

case None => Pull.done
}.stream
val pump = Stream.exec(allReady.await) ++ topic.publish(source.chunks)
dumpAll.concurrently(pump)
}
}
}
}

/** Behaves like the identity function, but requests `n` elements at a time from the input.
Expand Down Expand Up @@ -548,6 +542,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
)(implicit F: Concurrent[F2]): Stream[F2, O] =
concurrentlyAux(that).flatMap { case (startBack, fore) => startBack >> fore }

def concurrentlyExtendingThatScope[F2[x] >: F[x], O2](
that: Stream[F2, O2]
)(implicit F: Concurrent[F2]): Stream[F2, O] =
that.extendScopeThrough(that =>
concurrentlyAux(that).flatMap { case (startBack, fore) => startBack >> fore }
)

private def concurrentlyAux[F2[x] >: F[x], O2](
that: Stream[F2, O2]
)(implicit
Expand Down Expand Up @@ -2331,75 +2332,65 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
channel: F2[Channel[F2, F2[Either[Throwable, O2]]]],
isOrdered: Boolean,
f: O => F2[O2]
)(implicit F: Concurrent[F2]): Stream[F2, O2] = {
val action =
(
Semaphore[F2](concurrency),
channel,
Deferred[F2, Unit],
Deferred[F2, Unit]
).mapN { (semaphore, channel, stop, end) =>
def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = {
def ordered: F2[Either[Throwable, O2] => F2[Unit]] = {
def send(v: Deferred[F2, Either[Throwable, O2]]) =
(el: Either[Throwable, O2]) => v.complete(el).void

Deferred[F2, Either[Throwable, O2]]
.flatTap(value => channel.send(release *> value.get))
.map(send)
}
)(implicit F: Concurrent[F2]): Stream[F2, O2] =
extendScopeThrough { source =>
Stream.force {
(
Semaphore[F2](concurrency),
channel,
Deferred[F2, Unit],
Deferred[F2, Unit]
).mapN { (semaphore, channel, stop, end) =>
def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = {
def ordered: F2[Either[Throwable, O2] => F2[Unit]] = {
def send(v: Deferred[F2, Either[Throwable, O2]]) =
(el: Either[Throwable, O2]) => v.complete(el).void

Deferred[F2, Either[Throwable, O2]]
.flatTap(value => channel.send(release *> value.get))
.map(send)
}

def unordered: Either[Throwable, O2] => F2[Unit] =
(el: Either[Throwable, O2]) => release <* channel.send(F.pure(el))
def unordered: Either[Throwable, O2] => F2[Unit] =
(el: Either[Throwable, O2]) => release <* channel.send(F.pure(el))

if (isOrdered) ordered else F.pure(unordered)
}

val releaseAndCheckCompletion =
semaphore.release *>
semaphore.available.flatMap {
case `concurrency` => channel.close *> end.complete(()).void
case _ => F.unit
}
if (isOrdered) ordered else F.pure(unordered)
}

def forkOnElem(el: O): F2[Unit] =
F.uncancelable { poll =>
poll(semaphore.acquire) <*
Deferred[F2, Unit].flatMap { pushed =>
val init = initFork(pushed.complete(()).void)
poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get
F.start(stop.get.race(action) *> releaseAndCheckCompletion)
}
val releaseAndCheckCompletion =
semaphore.release *>
semaphore.available.flatMap {
case `concurrency` => channel.close *> end.complete(()).void
case _ => F.unit
}
}

underlying.uncons.flatMap {
case Some((hd, tl)) =>
for {
foreground <- Pull.extendScopeTo(
channel.stream.evalMap(_.rethrow).onFinalize(stop.complete(()) *> end.get)
)
background = Stream
.exec(semaphore.acquire) ++
Stream
.chunk(hd)
.append(tl.stream)
.interruptWhen(stop.get.map(_.asRight[Throwable]))
.foreach(forkOnElem)
.onFinalizeCase {
case ExitCase.Succeeded => releaseAndCheckCompletion
case _ => stop.complete(()) *> releaseAndCheckCompletion
def forkOnElem(el: O): F2[Unit] =
F.uncancelable { poll =>
poll(semaphore.acquire) <*
Deferred[F2, Unit].flatMap { pushed =>
val init = initFork(pushed.complete(()).void)
poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get
F.start(stop.get.race(action) *> releaseAndCheckCompletion)
}
_ <- foreground.concurrently(background).underlying
} yield ()
}
}

case None => Pull.done
}.stream
}
val background =
Stream.exec(semaphore.acquire) ++
source
.interruptWhen(stop.get.map(_.asRight[Throwable]))
.foreach(forkOnElem)
.onFinalizeCase {
case ExitCase.Succeeded => releaseAndCheckCompletion
case _ => stop.complete(()) *> releaseAndCheckCompletion
}

Stream.force(action)
}
val foreground = channel.stream.evalMap(_.rethrow)
foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background)
}
}
}

/** Concurrent zip.
*
Expand Down Expand Up @@ -2474,12 +2465,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
*/
def prefetchN[F2[x] >: F[x]: Concurrent](
n: Int
): Stream[F2, O] =
): Stream[F2, O] = extendScopeThrough { source =>
Stream.eval(Channel.bounded[F2, Chunk[O]](n)).flatMap { chan =>
chan.stream.unchunks.concurrently {
chunks.through(chan.sendAll)
source.chunks.through(chan.sendAll)
}
}
}

/** Prints each element of this stream to standard out, converting each element to a `String` via `Show`. */
def printlns[F2[x] >: F[x], O2 >: O](implicit
Expand Down Expand Up @@ -2940,6 +2932,23 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
)(f: (Stream[F, O], Stream[F2, O2]) => Stream[F2, O3]): Stream[F2, O3] =
f(this, s2)

/** Transforms this stream, explicitly extending the current scope through the given pipe.
*
* Use this when implementing a pipe where the resulting stream is not directly constructed from
* the source stream, e.g. when sending the source stream through a Channel and returning the
* channel's stream.
*/
def extendScopeThrough[F2[x] >: F[x], O2](
f: Stream[F, O] => Stream[F2, O2]
)(implicit F: MonadError[F2, Throwable]): Stream[F2, O2] =
this.pull.peek
.flatMap {
case Some((_, tl)) => Pull.extendScopeTo(f(tl))
case None => Pull.extendScopeTo(f(Stream.empty))
}
.flatMap(_.underlying)
.stream

/** Fails this stream with a `TimeoutException` if it does not complete within given `timeout`. */
def timeout[F2[x] >: F[x]: Temporal](
timeout: FiniteDuration
Expand Down
14 changes: 14 additions & 0 deletions core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,20 @@ class StreamCombinatorsSuite extends Fs2Suite {
)
.assertEquals(4.seconds)
}

test("scope propagation") {
Deferred[IO, Unit]
.flatMap { d =>
Stream
.bracket(IO.unit)(_ => d.complete(()).void)
.prefetch
.evalMap(_ => IO.sleep(1.second) >> d.complete(()))
.timeout(5.seconds)
.compile
.last
}
.assertEquals(Some(true))
}
}

test("range") {
Expand Down

0 comments on commit e229f82

Please sign in to comment.