Skip to content

Commit

Permalink
Handle error propagation / cancelation in Channel
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitriho committed Dec 1, 2024
1 parent bc334e9 commit 09f86e3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
50 changes: 29 additions & 21 deletions core/shared/src/main/scala/fs2/concurrent/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package concurrent

import cats.effect._
import cats.effect.implicits._
import cats.effect.Resource.ExitCase
import cats.syntax.all._

/** Stream aware, multiple producer, single consumer closeable channel.
Expand Down Expand Up @@ -138,76 +139,79 @@ object Channel {
size: Int,
waiting: Option[Deferred[F, Unit]],
producers: List[(A, Deferred[F, Unit])],
closed: Boolean
closed: Option[ExitCase]
)

val open = State(List.empty, 0, None, List.empty, closed = false)
val open = State(List.empty, 0, None, List.empty, closed = None)

def empty(isClosed: Boolean): State =
if (isClosed) State(List.empty, 0, None, List.empty, closed = true)
def empty(close: Option[ExitCase]): State =
if (close.nonEmpty) State(List.empty, 0, None, List.empty, closed = close)
else open

(F.ref(open), F.deferred[Unit]).mapN { (state, closedGate) =>
new Channel[F, A] {

def sendAll: Pipe[F, A, Nothing] = { in =>
in.onFinalize(close.void)
in.onFinalizeCase(closeWithExitCase(_).void)
.evalMap(send)
.takeWhile(_.isRight)
.drain
}

def sendImpl(a: A, close: Boolean) =
def sendImpl(a: A, close: Option[ExitCase]) =
F.deferred[Unit].flatMap { producer =>
state.flatModifyFull { case (poll, state) =>
state match {
case s @ State(_, _, _, _, closed @ true) =>
case s @ State(_, _, _, _, Some(_)) =>
(s, Channel.closed[Unit].pure[F])

case State(values, size, waiting, producers, closed @ false) =>
case State(values, size, waiting, producers, None) =>
if (size < capacity)
(
State(a :: values, size + 1, None, producers, close),
signalClosure.whenA(close) *> notifyStream(waiting).as(rightUnit)
signalClosure.whenA(close.nonEmpty) *> notifyStream(waiting).as(rightUnit)
)
else
(
State(values, size, None, (a, producer) :: producers, close),
signalClosure.whenA(close) *>
signalClosure.whenA(close.nonEmpty) *>
notifyStream(waiting).as(rightUnit) <*
waitOnBound(producer, poll).unlessA(close)
waitOnBound(producer, poll).unlessA(close.nonEmpty)
)
}
}
}

def send(a: A) = sendImpl(a, false)
def send(a: A) = sendImpl(a, None)

def closeWithElement(a: A) = sendImpl(a, true)
def closeWithElement(a: A) = sendImpl(a, Some(ExitCase.Succeeded))

def trySend(a: A) =
state.flatModify {
case s @ State(_, _, _, _, closed @ true) =>
case s @ State(_, _, _, _, Some(_)) =>
(s, Channel.closed[Boolean].pure[F])

case s @ State(values, size, waiting, producers, closed @ false) =>
case s @ State(values, size, waiting, producers, None) =>
if (size < capacity)
(
State(a :: values, size + 1, None, producers, false),
State(a :: values, size + 1, None, producers, None),
notifyStream(waiting).as(rightTrue)
)
else
(s, rightFalse.pure[F])
}

def close =
closeWithExitCase(ExitCase.Succeeded)

def closeWithExitCase(exitCase: ExitCase): F[Either[Closed, Unit]] =
state.flatModify {
case s @ State(_, _, _, _, closed @ true) =>
case s @ State(_, _, _, _, Some(_)) =>
(s, Channel.closed[Unit].pure[F])

case State(values, size, waiting, producers, closed @ false) =>
case State(values, size, waiting, producers, None) =>
(
State(values, size, None, producers, true),
State(values, size, None, producers, Some(exitCase)),
notifyStream(waiting).as(rightUnit) <* signalClosure
)
}
Expand Down Expand Up @@ -250,8 +254,12 @@ object Channel {
unblock.as(Pull.output(toEmit) >> consumeLoop)
} else {
F.pure(
if (closed) Pull.done
else Pull.eval(waiting.get) >> consumeLoop
closed match {
case Some(ExitCase.Succeeded) => Pull.done
case Some(ExitCase.Errored(e)) => Pull.raiseError(e)
case Some(ExitCase.Canceled) => Pull.eval(F.canceled)
case None => Pull.eval(waiting.get) >> consumeLoop
}
)
}
}
Expand Down
6 changes: 5 additions & 1 deletion core/shared/src/test/scala/fs2/concurrent/ChannelSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import scala.concurrent.duration._

import org.scalacheck.effect.PropF.forAllF

import scala.concurrent.CancellationException

class ChannelSuite extends Fs2Suite {

test("receives some simple elements above capacity and closes") {
Expand Down Expand Up @@ -336,6 +338,8 @@ class ChannelSuite extends Fs2Suite {
ch.stream.concurrently(producer).compile.drain
}

TestControl.executeEmbed(program) // will fail if program is deadlocked
TestControl
.executeEmbed(program)
.intercept[CancellationException]
}
}

0 comments on commit 09f86e3

Please sign in to comment.