diff --git a/core/src/main/scala/ox/race.scala b/core/src/main/scala/ox/race.scala index 3e91852b..6a57d91e 100644 --- a/core/src/main/scala/ox/race.scala +++ b/core/src/main/scala/ox/race.scala @@ -6,18 +6,23 @@ import scala.concurrent.TimeoutException import scala.concurrent.duration.FiniteDuration import scala.util.{Failure, Success, Try} -/** A `Some` if the computation `t` took less than `duration`, and `None` otherwise. */ +/** A `Some` if the computation `t` took less than `duration`, and `None` otherwise. if the computation `t` throws an exception, it is + * propagated. + */ def timeoutOption[T](duration: FiniteDuration)(t: => T): Option[T] = - race(Some(t), { sleep(duration); None }) + raceResult(Some(t), { sleep(duration); None }) -/** The result of computation `t`, if it took less than `duration`, and a [[TimeoutException]] otherwise. +/** The result of computation `t`, if it took less than `duration`, and a [[TimeoutException]] otherwise. if the computation `t` throws an + * exception, it is propagated. * @throws TimeoutException * If `t` took more than `duration`. */ def timeout[T](duration: FiniteDuration)(t: => T): T = timeoutOption(duration)(t).getOrElse(throw new TimeoutException(s"Timed out after $duration")) -/** Result of the computation `t` if it took less than `duration`, and `Left(timeoutValue)` otherwise. */ +/** Result of the computation `t` if it took less than `duration`, and `Left(timeoutValue)` otherwise. if the computation `t` throws an + * exception, it is propagated. + */ def timeoutEither[E, T](duration: FiniteDuration, timeoutValue: E)(t: => Either[E, T]): Either[E, T] = timeoutOption(duration)(t).getOrElse(Left(timeoutValue)) diff --git a/core/src/test/scala/ox/ControlTest.scala b/core/src/test/scala/ox/ControlTest.scala index 12020548..d5c85d35 100644 --- a/core/src/test/scala/ox/ControlTest.scala +++ b/core/src/test/scala/ox/ControlTest.scala @@ -7,6 +7,8 @@ import ox.util.Trail import scala.concurrent.TimeoutException import scala.concurrent.duration.DurationInt +import scala.util.Failure +import scala.util.Try class ControlTest extends AnyFlatSpec with Matchers { "timeout" should "short-circuit a long computation" in { @@ -26,6 +28,14 @@ class ControlTest extends AnyFlatSpec with Matchers { trail.get shouldBe Vector("timeout", "done") } + it should "pass through the exception of failed computation" in { + val myException = new Throwable("failed computation") + + Try { + timeout(1.second)(throw myException) + } shouldBe Failure(myException) + } + it should "not interrupt a short computation" in { val trail = Trail() unsupervised { @@ -57,4 +67,21 @@ class ControlTest extends AnyFlatSpec with Matchers { trail.get shouldBe Vector("done") } + + "timeoutOption" should "pass through the exception of failed computation" in { + val myException = new Throwable("failed computation") + + Try { + timeoutOption(1.second)(throw myException) + } shouldBe Failure(myException) + } + + "timeoutEither" should "pass through the exception of failed computation" in { + val myException = new Throwable("failed computation") + + Try { + timeoutEither(1.second, new TimeoutException(s"Timed out after 1 seconds"))(throw myException) + } shouldBe Failure(myException) + } + }