diff --git a/astcreator/src/main/scala/com/github/plume/oss/passes/IncrementalKeyPool.scala b/astcreator/src/main/scala/com/github/plume/oss/passes/IncrementalKeyPool.scala deleted file mode 100644 index 5ea154b3..00000000 --- a/astcreator/src/main/scala/com/github/plume/oss/passes/IncrementalKeyPool.scala +++ /dev/null @@ -1,38 +0,0 @@ -package com.github.plume.oss.passes - -import io.shiftleft.passes.KeyPool - -import java.util.concurrent.atomic.AtomicLong - -class IncrementalKeyPool(val first: Long, val last: Long, private val usedIds: Set[Long]) extends KeyPool { - - override def next: Long = { - if (!valid) { - throw new IllegalStateException("Call to `next` on invalidated IncrementalKeyPool.") - } - var n = cur.incrementAndGet() - while (n <= last) { - if (!usedIds.contains(n)) return n - else n = cur.incrementAndGet() - } - throw new RuntimeException("Pool exhausted") - } - - def split(numberOfPartitions: Int): Iterator[IncrementalKeyPool] = { - valid = false - if (numberOfPartitions == 0) { - Iterator() - } else { - val curFirst = cur.get() - val k = (last - curFirst) / numberOfPartitions - (1 to numberOfPartitions).map { i => - val poolFirst = curFirst + (i - 1) * k - new IncrementalKeyPool(poolFirst, poolFirst + k - 1, usedIds) - }.iterator - } - } - - private var valid: Boolean = true - private val cur: AtomicLong = new AtomicLong(first - 1) - -} diff --git a/astcreator/src/main/scala/com/github/plume/oss/passes/PlumeConcurrentWriterPass.scala b/astcreator/src/main/scala/com/github/plume/oss/passes/PlumeConcurrentWriterPass.scala deleted file mode 100644 index c82eba25..00000000 --- a/astcreator/src/main/scala/com/github/plume/oss/passes/PlumeConcurrentWriterPass.scala +++ /dev/null @@ -1,58 +0,0 @@ -package com.github.plume.oss.passes - -import com.github.plume.oss.drivers.IDriver -import io.shiftleft.SerializedCpg -import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.utils.ExecutionContextProvider -import overflowdb.BatchedUpdate.DiffGraphBuilder - -import scala.collection.mutable -import scala.concurrent.duration.Duration -import scala.concurrent.{Await, ExecutionContext, Future} - -object PlumeConcurrentWriterPass { - private val writerQueueCapacity = 4 - private val producerQueueCapacity = 2 + 4 * Runtime.getRuntime.availableProcessors() -} - -abstract class PlumeConcurrentWriterPass[T <: AnyRef](driver: IDriver) { - - @volatile var nDiffT = -1 - - def generateParts(): Array[? <: AnyRef] - - // main function: add desired changes to builder - def runOnPart(builder: DiffGraphBuilder, part: T): Unit - - def createAndApply(): Unit = { - import PlumeConcurrentWriterPass.producerQueueCapacity - var nParts = 0 - var nDiff = 0 - nDiffT = -1 - val parts = generateParts() - nParts = parts.length - val partIter = parts.iterator - val completionQueue = mutable.ArrayDeque[Future[overflowdb.BatchedUpdate.DiffGraph]]() - implicit val ec: ExecutionContext = ExecutionContextProvider.getExecutionContext - var done = false - while (!done) { - if (completionQueue.size < producerQueueCapacity && partIter.hasNext) { - val next = partIter.next() - completionQueue.append(Future.apply { - val builder = Cpg.newDiffGraphBuilder - runOnPart(builder, next.asInstanceOf[T]) - val builtGraph = builder.build() - driver.bulkTx(builtGraph) - builtGraph - }) - } else if (completionQueue.nonEmpty) { - val future = completionQueue.removeHead() - val res = Await.result(future, Duration.Inf) - nDiff += res.size - } else { - done = true - } - } - } - -} diff --git a/astcreator/src/main/scala/com/github/plume/oss/passes/PlumeForkJoinParallelCpgPass.scala b/astcreator/src/main/scala/com/github/plume/oss/passes/PlumeForkJoinParallelCpgPass.scala new file mode 100644 index 00000000..8a068282 --- /dev/null +++ b/astcreator/src/main/scala/com/github/plume/oss/passes/PlumeForkJoinParallelCpgPass.scala @@ -0,0 +1,113 @@ +package com.github.plume.oss.passes + +import com.github.plume.oss.drivers.IDriver +import io.shiftleft.SerializedCpg +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.utils.ExecutionContextProvider +import io.shiftleft.codepropertygraph.generated.nodes.AbstractNode +import io.shiftleft.passes.CpgPassBase +import overflowdb.BatchedUpdate.DiffGraphBuilder + +import java.util.function.* +import scala.annotation.nowarn +import scala.collection.mutable +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} + +abstract class PlumeForkJoinParallelCpgPass[T <: AnyRef](driver: IDriver, @nowarn outName: String = "") + extends CpgPassBase { + + // generate Array of parts that can be processed in parallel + def generateParts(): Array[? <: AnyRef] + + // setup large data structures, acquire external resources + def init(): Unit = {} + + // release large data structures and external resources + def finish(): Unit = {} + + // main function: add desired changes to builder + def runOnPart(builder: DiffGraphBuilder, part: T): Unit + + // Override this to disable parallelism of passes. Useful for debugging. + def isParallel: Boolean = true + + override def createAndApply(): Unit = createApplySerializeAndStore(null) + + override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = { + try { + init() + val parts = generateParts() + val nParts = parts.size + nParts match { + case 0 => + case 1 => + runOnPart(externalBuilder, parts(0).asInstanceOf[T]) + case _ => + val stream = + if (!isParallel) + java.util.Arrays + .stream(parts) + .sequential() + else + java.util.Arrays + .stream(parts) + .parallel() + val diff = stream.collect( + new Supplier[DiffGraphBuilder] { + override def get(): DiffGraphBuilder = + Cpg.newDiffGraphBuilder + }, + new BiConsumer[DiffGraphBuilder, AnyRef] { + override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = + runOnPart(builder, part.asInstanceOf[T]) + }, + new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { + override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit = + leftBuilder.absorb(rightBuilder) + } + ) + externalBuilder.absorb(diff) + } + nParts + } finally { + finish() + } + } + + override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = { + baseLogger.info(s"Start of pass: $name") + val nanosStart = System.nanoTime() + var nParts = 0 + var nanosBuilt = -1L + var nDiff = -1 + var nDiffT = -1 + try { + val diffGraph = Cpg.newDiffGraphBuilder + nParts = runWithBuilder(diffGraph) + nanosBuilt = System.nanoTime() + nDiff = diffGraph.size + driver.bulkTx(diffGraph) + } catch { + case exc: Exception => + baseLogger.error(s"Pass ${name} failed", exc) + throw exc + } finally { + try { + finish() + } finally { + // the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish() + // in the reported timings, and we must have our final log message if finish() throws + val nanosStop = System.nanoTime() + val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) + val serializationString = if (serializedCpg != null && !serializedCpg.isEmpty) { + " Diff serialized and stored." + } else "" + baseLogger.info( + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts.${serializationString}%s" + ) + } + } + } + +} diff --git a/astcreator/src/main/scala/com/github/plume/oss/passes/base/AstCreationPass.scala b/astcreator/src/main/scala/com/github/plume/oss/passes/base/AstCreationPass.scala index 77d23b23..708b51c9 100644 --- a/astcreator/src/main/scala/com/github/plume/oss/passes/base/AstCreationPass.scala +++ b/astcreator/src/main/scala/com/github/plume/oss/passes/base/AstCreationPass.scala @@ -3,7 +3,7 @@ package com.github.plume.oss.passes.base import better.files.File import com.github.plume.oss.JimpleAst2Database import com.github.plume.oss.drivers.IDriver -import com.github.plume.oss.passes.PlumeConcurrentWriterPass +import com.github.plume.oss.passes.PlumeForkJoinParallelCpgPass import io.joern.x2cpg.ValidationMode import io.joern.x2cpg.datastructures.Global import org.slf4j.LoggerFactory @@ -15,7 +15,7 @@ import java.nio.file.Paths /** Creates the AST layer from the given class file and stores all types in the given global parameter. */ class AstCreationPass(filenames: List[String], driver: IDriver, unpackingRoot: File) - extends PlumeConcurrentWriterPass[String](driver) { + extends PlumeForkJoinParallelCpgPass[String](driver) { val global: Global = new Global() private val logger = LoggerFactory.getLogger(classOf[AstCreationPass])