Skip to content

Commit

Permalink
Merge pull request #52 from softwaremill/mapPar
Browse files Browse the repository at this point in the history
mapPar for collections
  • Loading branch information
adamw authored Nov 23, 2023
2 parents f28aaba + 4aa8a6a commit 299e03f
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 0 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ val result: (Int, String) = par(computation1)(computation2)

If one of the computations fails, the other is interrupted, and `par` waits until both branches complete.

## Parallelize collection transformation

```scala
import ox.mapPar

val input: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

val result: List[Int] = mapPar(input)(4)(_ + 1)
// (2, 3, 4, 5, 6, 7, 8, 9, 10)
```

If any transformation fails, others are interrupted and `mapPar` rethrows exception that was
thrown by the transformation. Parallelism
limits how many concurrent forks are going to process the collection.

## Race two computations

```scala
Expand Down
25 changes: 25 additions & 0 deletions core/src/main/scala/ox/mapPar.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package ox

import java.util.concurrent.Semaphore
import scala.collection.IterableFactory

/** Runs parallel transformations on `iterable`. Using not more than `parallelism` forks concurrently.
*
* @param parallelism maximum number of concurrent forks
* @param iterable collection to transform
* @param transform transformation to apply to each element of `iterable`
*/
def mapPar[I, O, C[E] <: Iterable[E]](parallelism: Int)(iterable: => C[I])(transform: I => O): C[O] =
val s = Semaphore(parallelism)

supervised {
val forks = iterable.map { elem =>
s.acquire()
fork {
val o = transform(elem)
s.release()
o
}
}
forks.toSeq.map(f => f.join()).to(iterable.iterableFactory.asInstanceOf[IterableFactory[C]])
}
5 changes: 5 additions & 0 deletions core/src/main/scala/ox/syntax.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ object syntax:
def forkDaemon: Fork[T] = ox.forkDaemon(f)
def forkUnsupervised: Fork[T] = ox.forkUnsupervised(f)
def forkCancellable: CancellableFork[T] = ox.forkCancellable(f)

extension [T](f: => T)
def timeout(duration: FiniteDuration): T = ox.timeout(duration)(f)
def timeoutOption(duration: FiniteDuration): Option[T] = ox.timeoutOption(duration)(f)
def scopedWhere[U](fl: ForkLocal[U], u: U): T = fl.scopedWhere(u)(f)
Expand All @@ -24,3 +26,6 @@ object syntax:
def useInScope: T = ox.useCloseableInScope(f)
def useScoped[U](p: T => U): U = ox.useScoped(f)(p)
def useSupervised[U](p: T => U): U = ox.useSupervised(f)(p)

extension [I, C[E] <: Iterable[E]](f: => C[I])
def mapPar[O](parallelism: Int)(transform: I => O) = ox.mapPar(parallelism)(f)(transform)
100 changes: 100 additions & 0 deletions core/src/test/scala/ox/MapParTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package ox

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.syntax.mapPar
import ox.util.Trail

import java.util.concurrent.atomic.AtomicInteger
import scala.collection.IterableFactory
import scala.collection.immutable.Iterable
import scala.List

class MapParTest extends AnyFlatSpec with Matchers {
"mapPar" should "output the same type as input" in {
val input = List(1, 2, 3)
val result = input.mapPar(1)(identity)
result shouldBe a[List[_]]
}

it should "run computations in parallel" in {
val InputElements = 17
val TransformationMillis: Long = 100

val input = (0 to InputElements)
def transformation(i: Int) = {
Thread.sleep(TransformationMillis)
i + 1
}

val start = System.currentTimeMillis()
val result = input.to(Iterable).mapPar(5)(transformation)
val end = System.currentTimeMillis()

result.toList should contain theSameElementsInOrderAs (input.map(_ + 1))
(end - start) should be < (InputElements * TransformationMillis)
}

it should "run not more computations than limit" in {
val Parallelism = 5

val input = (1 to 158)

class MaxCounter {
val counter = new AtomicInteger(0)
var max = 0
def increment() = {
counter.updateAndGet { c =>
val inc = c + 1
max = if (inc > max) inc else max
inc
}
}
def decrement() = {
counter.decrementAndGet()
}
}

val maxCounter = new MaxCounter

def transformation(i: Int) = {
maxCounter.increment()
Thread.sleep(10)
maxCounter.decrement()
}

input.to(Iterable).mapPar(Parallelism)(transformation)

maxCounter.max should be <= Parallelism
}

it should "interrupt other computations in one fails" in {
val InputElements = 18
val TransformationMillis: Long = 100
val trail = Trail()

val input = (0 to InputElements)

def transformation(i: Int) = {
if (i == 4) {
trail.add("exception")
throw new Exception("boom")
} else {
Thread.sleep(TransformationMillis)
trail.add("transformation")
i + 1
}
}

try {
input.to(Iterable).mapPar(5)(transformation)
} catch {
case e: Exception if e.getMessage == "boom" => trail.add("catch")
}

Thread.sleep(300)
trail.add("all done")

trail.get shouldBe Vector("exception", "catch", "all done")
}
}

0 comments on commit 299e03f

Please sign in to comment.