diff --git a/core/src/main/scala/ox/channels/SourceCompanionIOOps.scala b/core/src/main/scala/ox/channels/SourceCompanionIOOps.scala index f856ac02..f931f873 100644 --- a/core/src/main/scala/ox/channels/SourceCompanionIOOps.scala +++ b/core/src/main/scala/ox/channels/SourceCompanionIOOps.scala @@ -9,7 +9,6 @@ import java.nio.channels.FileChannel import java.nio.file.Files import java.nio.file.Path import java.nio.file.StandardOpenOption -import scala.util.control.NonFatal trait SourceCompanionIOOps: @@ -25,7 +24,7 @@ trait SourceCompanionIOOps: */ def fromInputStream(is: InputStream, chunkSize: Int = 1024)(using Ox, StageCapacity, IO): Source[Chunk[Byte]] = val chunks = StageCapacity.newChannel[Chunk[Byte]] - fork { + forkPropagate(chunks) { try repeatWhile { val buf = new Array[Byte](chunkSize) @@ -37,14 +36,7 @@ trait SourceCompanionIOOps: if readBytes > 0 then chunks.send(if readBytes == chunkSize then Chunk.fromArray(buf) else Chunk.fromArray(buf.take(readBytes))) true } - catch - case t: Throwable => - chunks.errorOrClosed(t).discard - finally - try is.close() - catch - case t: Throwable => - chunks.errorOrClosed(t).discard + finally is.close() } chunks @@ -72,8 +64,8 @@ trait SourceCompanionIOOps: // Some file systems don't support file channels Files.newByteChannel(path, StandardOpenOption.READ) - fork { - try { + forkPropagate(chunks) { + try repeatWhile { val buf = ByteBuffer.allocate(chunkSize) val readBytes = jFileChannel.read(buf) @@ -84,11 +76,6 @@ trait SourceCompanionIOOps: if readBytes > 0 then chunks.send(Chunk.fromArray(if readBytes == chunkSize then buf.array else buf.array.take(readBytes))) true } - } catch case t: Throwable => chunks.errorOrClosed(t).discard - finally - try jFileChannel.close() - catch - case NonFatal(closeException) => - chunks.errorOrClosed(closeException).discard + finally jFileChannel.close() } chunks diff --git a/core/src/main/scala/ox/channels/SourceCompanionOps.scala b/core/src/main/scala/ox/channels/SourceCompanionOps.scala index 3d9f95f1..8090d512 100644 --- a/core/src/main/scala/ox/channels/SourceCompanionOps.scala +++ b/core/src/main/scala/ox/channels/SourceCompanionOps.scala @@ -19,70 +19,60 @@ trait SourceCompanionOps: def fromIterator[T](it: => Iterator[T])(using Ox, StageCapacity): Source[T] = val c = StageCapacity.newChannel[T] - fork { + forkPropagate(c) { val theIt = it - try - while theIt.hasNext do c.sendOrClosed(theIt.next()).discard - c.doneOrClosed() - catch case t: Throwable => c.errorOrClosed(t) + while theIt.hasNext do c.sendOrClosed(theIt.next()).discard + c.doneOrClosed().discard } c def fromFork[T](f: Fork[T])(using Ox, StageCapacity): Source[T] = val c = StageCapacity.newChannel[T] - fork { - try - c.sendOrClosed(f.join()) - c.doneOrClosed() - catch case t: Throwable => c.errorOrClosed(t) + forkPropagate(c) { + c.sendOrClosed(f.join()) + c.doneOrClosed().discard } c def iterate[T](zero: T)(f: T => T)(using Ox, StageCapacity): Source[T] = val c = StageCapacity.newChannel[T] - fork { + forkPropagate(c) { var t = zero - try - forever { - c.sendOrClosed(t) - t = f(t) - } - catch case t: Throwable => c.errorOrClosed(t) + forever { + c.sendOrClosed(t) + t = f(t) + } } c /** A range of number, from `from`, to `to` (inclusive), stepped by `step`. */ def range(from: Int, to: Int, step: Int)(using Ox, StageCapacity): Source[Int] = val c = StageCapacity.newChannel[Int] - fork { + forkPropagate(c) { var t = from - try - repeatWhile { - c.sendOrClosed(t) - t = t + step - t <= to - } - c.doneOrClosed() - catch case t: Throwable => c.errorOrClosed(t) + repeatWhile { + c.sendOrClosed(t) + t = t + step + t <= to + } + c.doneOrClosed().discard } c def unfold[S, T](initial: S)(f: S => Option[(T, S)])(using Ox, StageCapacity): Source[T] = val c = StageCapacity.newChannel[T] - fork { + forkPropagate(c) { var s = initial - try - repeatWhile { - f(s) match - case Some((value, next)) => - c.sendOrClosed(value) - s = next - true - case None => - c.doneOrClosed() - false - } - catch case t: Throwable => c.errorOrClosed(t) + repeatWhile { + f(s) match + case Some((value, next)) => + c.sendOrClosed(value) + s = next + true + case None => + c.doneOrClosed() + false + } } c @@ -147,12 +137,10 @@ trait SourceCompanionOps: */ def repeatEval[T](f: => T)(using Ox, StageCapacity): Source[T] = val c = StageCapacity.newChannel[T] - fork { - try - forever { - c.sendOrClosed(f).discard - } - catch case t: Throwable => c.errorOrClosed(t) + forkPropagate(c) { + forever { + c.sendOrClosed(f).discard + } } c @@ -168,14 +156,12 @@ trait SourceCompanionOps: */ def repeatEvalWhileDefined[T](f: => Option[T])(using Ox, StageCapacity): Source[T] = val c = StageCapacity.newChannel[T] - fork { - try - repeatWhile { - f match - case Some(value) => c.sendOrClosed(value); true - case None => c.doneOrClosed(); false - } - catch case t: Throwable => c.errorOrClosed(t) + forkPropagate(c) { + repeatWhile { + f match + case Some(value) => c.sendOrClosed(value); true + case None => c.doneOrClosed(); false + } } c @@ -190,27 +176,25 @@ trait SourceCompanionOps: def concat[T](sources: Seq[() => Source[T]])(using Ox, StageCapacity): Source[T] = val c = StageCapacity.newChannel[T] - fork { + forkPropagate(c) { var currentSource: Option[Source[T]] = None val sourcesIterator = sources.iterator var continue = true - try - while continue do - currentSource match - case None if sourcesIterator.hasNext => currentSource = Some(sourcesIterator.next()()) - case None => - c.doneOrClosed() - continue = false - case Some(source) => - source.receiveOrClosed() match - case ChannelClosed.Done => - currentSource = None - case ChannelClosed.Error(r) => - c.errorOrClosed(r) - continue = false - case t: T @unchecked => - c.sendOrClosed(t).discard - catch case t: Throwable => c.errorOrClosed(t) + while continue do + currentSource match + case None if sourcesIterator.hasNext => currentSource = Some(sourcesIterator.next()()) + case None => + c.doneOrClosed() + continue = false + case Some(source) => + source.receiveOrClosed() match + case ChannelClosed.Done => + currentSource = None + case ChannelClosed.Error(r) => + c.errorOrClosed(r) + continue = false + case t: T @unchecked => + c.sendOrClosed(t).discard } c diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index ace4c855..2835a3e6 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -98,19 +98,12 @@ trait SourceOps[+T] { outer: Source[T] => */ def map[U](f: T => U)(using Ox, StageCapacity): Source[U] = val c2 = StageCapacity.newChannel[U] - fork { + forkPropagate(c2) { repeatWhile { receiveOrClosed() match case ChannelClosed.Done => c2.doneOrClosed(); false case ChannelClosed.Error(r) => c2.errorOrClosed(r); false - case t: T @unchecked => - try - val u = f(t) - c2.sendOrClosed(u).isValue - catch - case t: Throwable => - c2.errorOrClosed(t) - false + case t: T @unchecked => c2.send(f(t)); true } } c2 @@ -320,26 +313,19 @@ trait SourceOps[+T] { outer: Source[T] => */ def takeWhile(f: T => Boolean, includeFirstFailing: Boolean = false)(using Ox, StageCapacity): Source[T] = val c = StageCapacity.newChannel[T] - fork { + forkPropagate(c) { repeatWhile { receiveOrClosed() match - case ChannelClosed.Done => - c.doneOrClosed().discard - false - case ChannelClosed.Error(reason) => - c.errorOrClosed(reason).discard - false + case ChannelClosed.Done => c.done(); false + case ChannelClosed.Error(reason) => c.error(reason); false case t: T @unchecked => - try - if f(t) then c.sendOrClosed(t).isValue - else - if includeFirstFailing then c.sendOrClosed(t).discard - c.doneOrClosed().discard - false - catch - case t: Throwable => - c.errorOrClosed(t).discard - false + if f(t) then + c.send(t) + true + else + if includeFirstFailing then c.send(t) + c.done() + false } } c @@ -615,28 +601,22 @@ trait SourceOps[+T] { outer: Source[T] => initializeState: () => S )(f: (S, T) => (S, IterableOnce[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = val c = StageCapacity.newChannel[U] - fork { + forkPropagate(c) { var state = initializeState() repeatWhile { receiveOrClosed() match case ChannelClosed.Done => - try - onComplete(state).foreach(c.sendOrClosed) - c.doneOrClosed() - catch case t: Throwable => c.errorOrClosed(t) + onComplete(state).foreach(c.send) + c.done() false case ChannelClosed.Error(r) => - c.errorOrClosed(r) + c.error(r) false case t: T @unchecked => - try - val (nextState, result) = f(state, t) - state = nextState - result.iterator.map(c.sendOrClosed).forall(_.isValue) - catch - case t: Throwable => - c.errorOrClosed(t) - false + val (nextState, result) = f(state, t) + state = nextState + result.iterator.foreach(c.send) + true } } c @@ -665,24 +645,19 @@ trait SourceOps[+T] { outer: Source[T] => */ def mapConcat[U](f: T => IterableOnce[U])(using Ox, StageCapacity): Source[U] = val c = StageCapacity.newChannel[U] - fork { + forkPropagate(c) { repeatWhile { receiveOrClosed() match case ChannelClosed.Done => - c.doneOrClosed() + c.done() false case ChannelClosed.Error(r) => - c.errorOrClosed(r) + c.error(r) false case t: T @unchecked => - try - val results: IterableOnce[U] = f(t) - results.iterator.foreach(c.send) - true - catch - case t: Throwable => - c.errorOrClosed(t) - false + val results: IterableOnce[U] = f(t) + results.iterator.foreach(c.send) + true } } c @@ -826,36 +801,29 @@ trait SourceOps[+T] { outer: Source[T] => def groupedWeighted(minWeight: Long)(costFn: T => Long)(using Ox, StageCapacity): Source[Seq[T]] = require(minWeight > 0, "minWeight must be > 0") val c2 = StageCapacity.newChannel[Seq[T]] - fork { + forkPropagate(c2) { var buffer = Vector.empty[T] var accumulatedCost = 0L repeatWhile { receiveOrClosed() match case ChannelClosed.Done => - if buffer.nonEmpty then c2.sendOrClosed(buffer).discard - c2.doneOrClosed() + if buffer.nonEmpty then c2.send(buffer) + c2.done() false case ChannelClosed.Error(r) => - c2.errorOrClosed(r) + c2.error(r) false case t: T @unchecked => buffer = buffer :+ t - val wasCostEvaluationSuccessful = - try - accumulatedCost += costFn(t) - true - catch - case t: Throwable => - c2.errorOrClosed(t).discard - false + accumulatedCost += costFn(t) - if wasCostEvaluationSuccessful && accumulatedCost >= minWeight then - val isValue = c2.sendOrClosed(buffer).isValue + if accumulatedCost >= minWeight then + c2.send(buffer) buffer = Vector.empty accumulatedCost = 0 - isValue - else wasCostEvaluationSuccessful + + true } } c2 @@ -935,7 +903,7 @@ trait SourceOps[+T] { outer: Source[T] => require(duration > 0.seconds, "duration must be > 0") val c2 = StageCapacity.newChannel[Seq[T]] val timerChannel = StageCapacity.newChannel[GroupingTimeout.type] - fork { + forkPropagate(c2) { var buffer = Vector.empty[T] var accumulatedCost: Long = 0 @@ -945,45 +913,38 @@ trait SourceOps[+T] { outer: Source[T] => } var timeoutFork: Option[CancellableFork[Unit]] = Some(forkTimeout()) - def sendBufferAndForkNewTimeout(): Boolean = - val isValue = c2.sendOrClosed(buffer).isValue + def sendBufferAndForkNewTimeout(): Unit = + c2.send(buffer) buffer = Vector.empty accumulatedCost = 0 timeoutFork.foreach(_.cancelNow()) - if isValue then timeoutFork = Some(forkTimeout()) // start a new timeout only if channel was not closed - isValue + timeoutFork = Some(forkTimeout()) repeatWhile { selectOrClosed(receiveClause, timerChannel.receiveClause) match case ChannelClosed.Done => timeoutFork.foreach(_.cancelNow()) - if buffer.nonEmpty then c2.sendOrClosed(buffer).discard - c2.doneOrClosed() + if buffer.nonEmpty then c2.send(buffer) + c2.done() false case ChannelClosed.Error(r) => timeoutFork.foreach(_.cancelNow()) - c2.errorOrClosed(r) + c2.error(r) false case timerChannel.Received(GroupingTimeout) => timeoutFork = None // enter 'timed out state', may stay in this state if buffer is empty if buffer.nonEmpty then sendBufferAndForkNewTimeout() - else true + true case Received(t) => buffer = buffer :+ t - val wasCostEvaluationSuccessful = - try - accumulatedCost += costFn(t) - true - catch - case t: Throwable => - c2.errorOrClosed(t).discard - timeoutFork.foreach(_.cancelNow()) - false - if wasCostEvaluationSuccessful && (timeoutFork.isEmpty || accumulatedCost >= minWeight) then + accumulatedCost += costFn(t).tapException(_ => timeoutFork.foreach(_.cancelNow())) + + if (timeoutFork.isEmpty || accumulatedCost >= minWeight) then // timeout passed when buffer was empty or buffer full sendBufferAndForkNewTimeout() - else wasCostEvaluationSuccessful + + true } } c2 diff --git a/core/src/main/scala/ox/channels/forkPropagate.scala b/core/src/main/scala/ox/channels/forkPropagate.scala new file mode 100644 index 00000000..fd067c20 --- /dev/null +++ b/core/src/main/scala/ox/channels/forkPropagate.scala @@ -0,0 +1,16 @@ +package ox.channels + +import ox.* + +/** Fork the given computation, propagating any exceptions to the given sink. The propagated exceptions are not rethrown. + * + * Designed to be used in stream operators. + * + * @see + * ADR#1, ADR#3, implementation note in [[SourceOps]]. + */ +def forkPropagate[T](propagateExceptionsTo: Sink[_])(f: => Unit)(using Ox): Fork[Unit] = + fork { + try f + catch case t: Throwable => propagateExceptionsTo.errorOrClosed(t).discard + }