diff --git a/core/shared/src/main/scala/fs2/concurrent/Channel.scala b/core/shared/src/main/scala/fs2/concurrent/Channel.scala index 68e71bcaa9..9807a525c4 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Channel.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Channel.scala @@ -117,6 +117,18 @@ sealed trait Channel[F[_], A] { */ def closeWithElement(a: A): F[Either[Channel.Closed, Unit]] + /** Raises an error, closing the channel with an error state. + * + * No-op if the channel is closed, see [[close]] for further info. + */ + def raiseError(e: Throwable): F[Either[Channel.Closed, Unit]] + + /** Cancels the channel, closing it with a canceled state. + * + * No-op if the channel is closed, see [[close]] for further info. + */ + def cancel: F[Either[Channel.Closed, Unit]] + /** Returns true if this channel is closed */ def isClosed: F[Boolean] @@ -216,6 +228,12 @@ object Channel { ) } + def raiseError(e: Throwable): F[Either[Closed, Unit]] = + closeWithExitCase(ExitCase.Errored(e)) + + def cancel: F[Either[Closed, Unit]] = + closeWithExitCase(ExitCase.Canceled) + def isClosed = closedGate.tryGet.map(_.isDefined) def closed = closedGate.get diff --git a/core/shared/src/main/scala/fs2/concurrent/Topic.scala b/core/shared/src/main/scala/fs2/concurrent/Topic.scala index 2061070b16..f123e29659 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Topic.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Topic.scala @@ -23,6 +23,7 @@ package fs2 package concurrent import cats.effect._ +import cats.effect.Resource.ExitCase import cats.effect.implicits._ import cats.syntax.all._ import scala.collection.immutable.LongMap @@ -208,7 +209,8 @@ object Topic { } def publish: Pipe[F, A, Nothing] = { in => - in.onFinalize(close.void) + in + .onFinalizeCase(closeWithExitCase(_).void) .evalMap(publish1) .takeWhile(_.isRight) .drain @@ -223,13 +225,24 @@ object Topic { def subscribers: Stream[F, Int] = subscriberCount.discrete def close: F[Either[Topic.Closed, Unit]] = + closeWithExitCase(ExitCase.Succeeded) + + def closeWithExitCase(exitCase: ExitCase): F[Either[Closed, Unit]] = signalClosure .complete(()) .flatMap { completedNow => val result = if (completedNow) Topic.rightUnit else Topic.closed state.get - .flatMap { case (subs, _) => foreach(subs)(_.close.void) } + .flatMap { case (subs, _) => + foreach(subs)(channel => + exitCase match { + case ExitCase.Succeeded => channel.close.void + case ExitCase.Errored(e) => channel.raiseError(e).void + case ExitCase.Canceled => channel.cancel.void + } + ) + } .as(result) } .uncancelable diff --git a/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala b/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala index c26fd73dd7..dbbfe64296 100644 --- a/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala +++ b/core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala @@ -22,11 +22,14 @@ package fs2 package concurrent -import cats.syntax.all._ +import cats.syntax.all.* import cats.effect.IO -import scala.concurrent.duration._ + +import scala.concurrent.duration.* import cats.effect.testkit.TestControl +import scala.concurrent.CancellationException + class TopicSuite extends Fs2Suite { test("subscribers see all elements published") { Topic[IO, Int].flatMap { topic => @@ -204,6 +207,8 @@ class TopicSuite extends Fs2Suite { .drain } - TestControl.executeEmbed(program) // will fail if program is deadlocked + TestControl + .executeEmbed(program) + .intercept[CancellationException] } }