diff --git a/README.md b/README.md index ae9b9762..69b746f6 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,21 @@ val result: (Int, String) = par(computation1)(computation2) If one of the computations fails, the other is interrupted, and `par` waits until both branches complete. +## Parallelize collection transformation + +```scala +import ox.mapPar + +val input: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + +val result: List[Int] = mapPar(input)(4)(_ + 1) +// (2, 3, 4, 5, 6, 7, 8, 9, 10) +``` + +If any transformation fails, others are interrupted and `mapPar` rethrows exception that was +thrown by the transformation. Parallelism +limits how many concurrent forks are going to process the collection. + ## Race two computations ```scala diff --git a/core/src/main/scala/ox/mapPar.scala b/core/src/main/scala/ox/mapPar.scala new file mode 100644 index 00000000..14e55eaa --- /dev/null +++ b/core/src/main/scala/ox/mapPar.scala @@ -0,0 +1,25 @@ +package ox + +import java.util.concurrent.Semaphore +import scala.collection.IterableFactory + +/** Runs parallel transformations on `iterable`. Using not more than `parallelism` forks concurrently. + * + * @param parallelism maximum number of concurrent forks + * @param iterable collection to transform + * @param transform transformation to apply to each element of `iterable` + */ +def mapPar[I, O, C[E] <: Iterable[E]](parallelism: Int)(iterable: => C[I])(transform: I => O): C[O] = + val s = Semaphore(parallelism) + + supervised { + val forks = iterable.map { elem => + s.acquire() + fork { + val o = transform(elem) + s.release() + o + } + } + forks.toSeq.map(f => f.join()).to(iterable.iterableFactory.asInstanceOf[IterableFactory[C]]) + } diff --git a/core/src/main/scala/ox/syntax.scala b/core/src/main/scala/ox/syntax.scala index 087d1dc5..074645ab 100644 --- a/core/src/main/scala/ox/syntax.scala +++ b/core/src/main/scala/ox/syntax.scala @@ -12,6 +12,8 @@ object syntax: def forkDaemon: Fork[T] = ox.forkDaemon(f) def forkUnsupervised: Fork[T] = ox.forkUnsupervised(f) def forkCancellable: CancellableFork[T] = ox.forkCancellable(f) + + extension [T](f: => T) def timeout(duration: FiniteDuration): T = ox.timeout(duration)(f) def timeoutOption(duration: FiniteDuration): Option[T] = ox.timeoutOption(duration)(f) def scopedWhere[U](fl: ForkLocal[U], u: U): T = fl.scopedWhere(u)(f) @@ -24,3 +26,6 @@ object syntax: def useInScope: T = ox.useCloseableInScope(f) def useScoped[U](p: T => U): U = ox.useScoped(f)(p) def useSupervised[U](p: T => U): U = ox.useSupervised(f)(p) + + extension [I, C[E] <: Iterable[E]](f: => C[I]) + def mapPar[O](parallelism: Int)(transform: I => O) = ox.mapPar(parallelism)(f)(transform) diff --git a/core/src/test/scala/ox/MapParTest.scala b/core/src/test/scala/ox/MapParTest.scala new file mode 100644 index 00000000..97d554d7 --- /dev/null +++ b/core/src/test/scala/ox/MapParTest.scala @@ -0,0 +1,100 @@ +package ox + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.syntax.mapPar +import ox.util.Trail + +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.IterableFactory +import scala.collection.immutable.Iterable +import scala.List + +class MapParTest extends AnyFlatSpec with Matchers { + "mapPar" should "output the same type as input" in { + val input = List(1, 2, 3) + val result = input.mapPar(1)(identity) + result shouldBe a[List[_]] + } + + it should "run computations in parallel" in { + val InputElements = 17 + val TransformationMillis: Long = 100 + + val input = (0 to InputElements) + def transformation(i: Int) = { + Thread.sleep(TransformationMillis) + i + 1 + } + + val start = System.currentTimeMillis() + val result = input.to(Iterable).mapPar(5)(transformation) + val end = System.currentTimeMillis() + + result.toList should contain theSameElementsInOrderAs (input.map(_ + 1)) + (end - start) should be < (InputElements * TransformationMillis) + } + + it should "run not more computations than limit" in { + val Parallelism = 5 + + val input = (1 to 158) + + class MaxCounter { + val counter = new AtomicInteger(0) + var max = 0 + def increment() = { + counter.updateAndGet { c => + val inc = c + 1 + max = if (inc > max) inc else max + inc + } + } + def decrement() = { + counter.decrementAndGet() + } + } + + val maxCounter = new MaxCounter + + def transformation(i: Int) = { + maxCounter.increment() + Thread.sleep(10) + maxCounter.decrement() + } + + input.to(Iterable).mapPar(Parallelism)(transformation) + + maxCounter.max should be <= Parallelism + } + + it should "interrupt other computations in one fails" in { + val InputElements = 18 + val TransformationMillis: Long = 100 + val trail = Trail() + + val input = (0 to InputElements) + + def transformation(i: Int) = { + if (i == 4) { + trail.add("exception") + throw new Exception("boom") + } else { + Thread.sleep(TransformationMillis) + trail.add("transformation") + i + 1 + } + } + + try { + input.to(Iterable).mapPar(5)(transformation) + } catch { + case e: Exception if e.getMessage == "boom" => trail.add("catch") + } + + Thread.sleep(300) + trail.add("all done") + + trail.get shouldBe Vector("exception", "catch", "all done") + } +}