diff --git a/core/src/main/scala/ox/fork.scala b/core/src/main/scala/ox/fork.scala index 3cce920f..9f6a3042 100644 --- a/core/src/main/scala/ox/fork.scala +++ b/core/src/main/scala/ox/fork.scala @@ -25,8 +25,10 @@ def fork[T](f: => T)(using Ox): Fork[T] = supervisor.forkSuccess() catch case e: Throwable => + // we notify the supervisor first, so that if this is the first failing fork in the scope, the supervisor will + // get first notified of the exception by the "original" (this) fork + supervisor.forkError(e) result.completeExceptionally(e) - supervisor.forkError(e) } newForkUsingResult(result) diff --git a/core/src/main/scala/ox/supervised.scala b/core/src/main/scala/ox/supervised.scala index a8fd6a0c..e69049cd 100644 --- a/core/src/main/scala/ox/supervised.scala +++ b/core/src/main/scala/ox/supervised.scala @@ -27,8 +27,8 @@ def supervised[T](f: Ox ?=> T): T = catch case e: Throwable => // all forks are guaranteed to have finished: some might have ended up throwing exceptions (InterruptedException or - // others), but only the first one is propagated. That's wait, adding the others as suppressed. - s.addSuppressed(e) + // others), but only the first one is propagated below. That's why we add all the other exceptions as suppressed. + s.addOtherExceptionsAsSuppressedTo(e) throw e trait Supervisor: @@ -66,8 +66,8 @@ class DefaultSupervisor() extends Supervisor: override def join(): Unit = unwrapExecutionException(result.get()) - def addSuppressed(e: Throwable): Throwable = - otherExceptions.forEach(e.addSuppressed) + def addOtherExceptionsAsSuppressedTo(e: Throwable): Throwable = + otherExceptions.forEach(e2 => if e != e2 then e.addSuppressed(e2)) e /** Change the supervisor that is being used when running `f`. Doesn't affect existing usages of the current supervisor, or forks ran diff --git a/core/src/test/scala/ox/ExceptionTest.scala b/core/src/test/scala/ox/ExceptionTest.scala index 36d7a348..3589a5f5 100644 --- a/core/src/test/scala/ox/ExceptionTest.scala +++ b/core/src/test/scala/ox/ExceptionTest.scala @@ -9,6 +9,7 @@ import java.util.concurrent.Semaphore class ExceptionTest extends AnyFlatSpec with Matchers { class CustomException extends RuntimeException class CustomException2 extends RuntimeException + class CustomException3(e: Exception) extends RuntimeException(e) "scoped" should "throw the exception thrown by a joined fork" in { val trail = Trail() @@ -72,14 +73,41 @@ class ExceptionTest extends AnyFlatSpec with Matchers { throw CustomException() } } - catch - case e: Exception => - val suppressed = e.getSuppressed.map(_.getClass.getSimpleName) - trail.add(s"${e.getClass.getSimpleName}(suppressed=${suppressed.mkString(",")})") + catch case e: Exception => addExceptionWithSuppressedTo(trail, e) trail.get shouldBe Vector("CustomException(suppressed=CustomException2)") } + it should "not add the original exception as suppressed" in { + val trail = Trail() + try + supervised { + val f = fork { + throw new CustomException() + } + f.join() + } + catch case e: Exception => addExceptionWithSuppressedTo(trail, e) + + trail.get shouldBe Vector("CustomException(suppressed=)") + } + + it should "add an exception as suppressed, even if it wraps the original exception" in { + val trail = Trail() + try + supervised { + val f = fork { + throw new CustomException() + } + try f.join() catch { + case e: Exception => throw new CustomException3(e) + } + } + catch case e: Exception => addExceptionWithSuppressedTo(trail, e) + + trail.get shouldBe Vector("CustomException(suppressed=CustomException3)") + } + "joinEither" should "catch the exception with which a fork ends" in { val r = supervised { val f = forkUnsupervised { @@ -90,4 +118,9 @@ class ExceptionTest extends AnyFlatSpec with Matchers { r should matchPattern { case Left(e: CustomException) => } } + + def addExceptionWithSuppressedTo(t: Trail, e: Throwable): Unit = { + val suppressed = e.getSuppressed.map(_.getClass.getSimpleName) + t.add(s"${e.getClass.getSimpleName}(suppressed=${suppressed.mkString(",")})") + } }