Add feature_counts function to Spark SQL #168

Open
wants to merge 2 commits into base: master
Binary file added data/NA12878.slice.bam
Binary file added data/NA12878.slice.bam.bai
2 changes: 2 additions & 0 deletions data/NA12878.slice.fasta

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions data/test.bed
@@ -0,0 +1,4 @@
GeneId Chr Start End Strand
1 chr1 11 99 sss
2 chr1 222 257 a
3 chr1 800 2200 test
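
The annotation file above is a tab-delimited table with a header row (GeneId, Chr, Start, End, Strand). As a minimal sketch, assuming the same read options used by the apps in this PR, it can be loaded and registered like this (the path is illustrative):

val targets = spark.read
  .option("header", "true")
  .option("delimiter", "\t")
  .csv("data/test.bed")
targets.createOrReplaceTempView("targets")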
@@ -6,8 +6,8 @@ import org.apache.spark.sql.SparkSession
import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim
import org.biodatageeks.sequila.utils.{Columns, InternalParams}
import org.rogach.scallop.ScallopConf
import org.seqdoop.hadoop_bam.{BAMInputFormat, SAMRecordWritable}
import org.seqdoop.hadoop_bam.util.SAMHeaderReader
import org.seqdoop.hadoop_bam.{BAMInputFormat, SAMRecordWritable}

object FeatureCounts {
case class Region(contig:String, pos_start:Int, pos_end:Int)
@@ -25,6 +25,7 @@ object FeatureCounts {
val spark = SparkSession
.builder()
.appName("SeQuiLa-FC")
.config("spark.master", "local[4]")
.getOrCreate()

spark.sqlContext.setConf(InternalParams.useJoinOrder,"true")
@@ -41,11 +42,11 @@
count(*) AS Counts
FROM reads JOIN targets
|ON (
| targets.Chr=reads.contigName
| targets.Chr=reads.contig
| AND
| reads.end >= CAST(targets.Start AS INTEGER)
| reads.pos_end >= CAST(targets.Start AS INTEGER)
| AND
| reads.start <= CAST(targets.End AS INTEGER)
| reads.pos_start <= CAST(targets.End AS INTEGER)
|)
|GROUP BY targets.GeneId,targets.Chr,targets.Start,targets.End,targets.Strand""".stripMargin
spark
@@ -70,7 +71,7 @@
.option("delimiter", "\t")
.csv(runConf.annotations())
targets
.withColumnRenamed("contigName", Columns.CONTIG)
.withColumnRenamed("contig", Columns.CONTIG)
.createOrReplaceTempView("targets")

spark.sql(query)
@@ -0,0 +1,44 @@
package org.biodatageeks.sequila.apps

import htsjdk.samtools.ValidationStringency
import org.apache.spark.sql.{SequilaSession, SparkSession}
import org.rogach.scallop.ScallopConf
import org.seqdoop.hadoop_bam.util.SAMHeaderReader

object FeatureCountsFunc {

class RunConf(args:Array[String]) extends ScallopConf(args){

val output = opt[String](required = true)
val annotations = opt[String](required = true)
val readsFile = trailArg[String](required = true)
verify()
}

def main(args: Array[String]): Unit = {
val runConf = new RunConf(args)
val spark = SparkSession
.builder()
.appName("SeQuiLa-DoC")
.config("spark.master", "local[4]")
.getOrCreate()

spark
.sparkContext
.setLogLevel("ERROR")

spark
.sparkContext
.hadoopConfiguration.set(SAMHeaderReader.VALIDATION_STRINGENCY_PROPERTY, ValidationStringency.SILENT.toString)

val ss = SequilaSession(spark)

val query = s"Select fc.*" +
s" FROM feature_counts('${runConf.readsFile()}', '${runConf.annotations()}') fc "

ss.sql(query)
.orderBy("sample_id")
.coalesce(1)
.show()
}
}
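
A hedged usage sketch of the new table-valued function from a Scala shell, assuming feature_counts is registered by SequilaSession exactly as in the object above; the paths point at the test data added in this PR and are for illustration only:

val ss = SequilaSession(spark)
ss.sql(
  """SELECT fc.*
    |FROM feature_counts('data/NA12878.slice.bam', 'data/test.bed') fc""".stripMargin)
  .orderBy("sample_id")
  .show()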
@@ -0,0 +1,166 @@
package org.biodatageeks.sequila.rangejoins.methods.IntervalTree

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SequilaSession.logger
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.SizeEstimator
import org.biodatageeks.sequila.datasources.BAM.{BAMFileReader, BDGAlignFileReaderWriter}
import org.biodatageeks.sequila.rangejoins.IntervalTree.Interval
import org.biodatageeks.sequila.rangejoins.optimizer.{JoinOptimizerChromosome, RangeJoinMethod}
import org.openjdk.jol.info.GraphLayout
import org.seqdoop.hadoop_bam.BAMBDGInputFormat

import scala.collection.JavaConversions._

case class FeatureCountsPlan(spark: SparkSession,
readsPath: String,
genesPath: String,
output: Seq[Attribute],
minOverlap: Int,
maxGap: Int,
intervalHolderClassName: String)
extends SparkPlan with Serializable with BDGAlignFileReaderWriter[BAMBDGInputFormat] {

override protected def doExecute(): RDD[InternalRow] = {
val reads = new BAMFileReader[BAMBDGInputFormat](spark, readsPath, None).readFile
val genesRdd = spark
.read
.option("header", "true")
.option("delimiter", "\t")
.csv(genesPath)
.rdd

val optimizer = new JoinOptimizerChromosome(spark, genesRdd, genesRdd.count())
logger.info(optimizer.debugInfo)

if (optimizer.getRangeJoinMethod == RangeJoinMethod.JoinWithRowBroadcast) {
val localIntervals = {
if (maxGap != 0) {
genesRdd
.map(r => (r.getString(1), toInterval(r, maxGap), toInternalRow(r)))
} else {
genesRdd
.map(r => (r.getString(1), toInterval(r), toInternalRow(r)))
}
}
.collect()

val intervalTree = {
val tree = new IntervalHolderChromosome[InternalRow](localIntervals, intervalHolderClassName)
try {
val treeSize = GraphLayout.parseInstance(tree).totalSize()
logger.info(s"Real broadcast size of the interval structure is ${treeSize} bytes")
}
catch {
case e@(_: NoClassDefFoundError | _: ExceptionInInitializerError) =>
logger.error("Cannot get broadcast size, method ObjectSizeCalculator.getObjectSize not available falling back to Spark method")
val treeSize = SizeEstimator.estimate(tree)
logger.info(s"Real broadcast size of the interval structure is ${treeSize} bytes")

}
spark.sparkContext.broadcast(tree)
}

val v3 = reads.mapPartitions(readIterator => {
readIterator.map(read => {
intervalTree.value.getIntervalTreeByChromosome(read.getContig) match {
case Some(t) => {
val record = t.overlappers(read.getStart, read.getEnd)
if (minOverlap != 1) {
record
.filter(r => calcOverlap(read.getStart, read.getEnd, r.getStart, r.getEnd) >= minOverlap)
.flatMap(k => k.getValue)
} else {
record.flatMap(k => k.getValue)
}
}
case _ => Iterator.empty
}
})
})
.flatMap(r => r)
.groupBy(x => x)
.mapValues(_.size)
.map(entry => {
entry._1.setInt(6, entry._2)
toUnsafeRow(entry._1)
})

v3
} else {
val genesRddWithIndex = genesRdd.zipWithIndex()
val localIntervals = genesRddWithIndex
.map(r=>(r._1.getString(1), toInterval(r._1) ,r._2))
.collect()

val intervalTree = {
val tree = new IntervalHolderChromosome[Long](localIntervals, intervalHolderClassName)
spark.sparkContext.broadcast(tree)
}

val v3 = reads.mapPartitions(readIterator => {
readIterator.map(read => {
intervalTree.value.getIntervalTreeByChromosome(read.getContig) match {
case Some(t) => {
val record = t.overlappers(read.getStart, read.getEnd)
if (minOverlap != 1) {
record
.filter(r => calcOverlap(read.getStart, read.getEnd, r.getStart, r.getEnd) >= minOverlap)
.flatMap(k => k.getValue.map(s => (s, s)))
} else {
record.flatMap(k => k.getValue.map(s => (s, s)))
}
}
case _ => Iterator.empty
}
})
})
.flatMap(r => r)

val intGenesRdd = genesRddWithIndex.map(r => (r._2, toInternalRow(r._1)))
val result = v3.
join(intGenesRdd)
.map(l => l._2._2)
.groupBy(x => x)
.mapValues(_.size)
.map(entry => {
entry._1.setInt(6, entry._2)
toUnsafeRow(entry._1)
})
result
}
}

private def toUnsafeRow(r: InternalRow): InternalRow = {
val proj = UnsafeProjection.create(schema)
proj.apply(r)
}

private def toInternalRow(r: Row): InternalRow = {
InternalRow.fromSeq(Seq(
UTF8String.fromString(r.getString(0)), //Sample
UTF8String.fromString(r.getString(1)), //Contig
r.getString(2).toInt, //Start
r.getString(3).toInt, //End
UTF8String.fromString(r.getString(4)), //Strand
r.getString(3).toInt - r.getString(2).toInt, //Length
0 //count
))
}

private def calcOverlap(start1: Int, end1: Int, start2: Int, end2: Int) = (math.min(end1, end2) - math.max(start1, start2) + 1)

private def toInterval(r: Row) : Interval[Int] = {
Interval(r.getString(2).toInt, r.getString(3).toInt)
}

private def toInterval(r: Row, maxGap: Int) : Interval[Int] = {
Interval(r.getString(2).toInt - maxGap, r.getString(3).toInt + maxGap)
}

override def children: Seq[SparkPlan] = Nil
}
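
For reference, calcOverlap above returns the inclusive overlap length of two closed intervals, and reads are discarded when that length falls below minOverlap. A small worked example with illustrative coordinates:

// read [100, 150] vs. gene [140, 300]:
// min(150, 300) - max(100, 140) + 1 = 150 - 140 + 1 = 11 overlapping bases
def overlap(start1: Int, end1: Int, start2: Int, end2: Int): Int =
  math.min(end1, end2) - math.max(start1, start2) + 1

assert(overlap(100, 150, 140, 300) == 11)
// with the default minOverlap = 1, any read with a positive overlap is counted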
@@ -1,15 +1,13 @@
package org.biodatageeks.sequila.rangejoins.IntervalTree

import org.biodatageeks.sequila.rangejoins.common.{ExtractRangeJoinKeys, ExtractRangeJoinKeysWithEquality}
import org.biodatageeks.sequila.rangejoins.methods.IntervalTree.IntervalTreeJoinOptimChromosome
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.{FeatureCountsTemplate, SparkSession, Strategy}
import org.biodatageeks.sequila.rangejoins.common.{ExtractRangeJoinKeys, ExtractRangeJoinKeysWithEquality}
import org.biodatageeks.sequila.rangejoins.methods.IntervalTree.{FeatureCountsPlan, IntervalTreeJoinOptimChromosome}
import org.biodatageeks.sequila.utils.InternalParams

import scala.annotation.tailrec

/**
* Created by marek on 27/01/2018.
*/
@@ -42,6 +40,17 @@ class IntervalTreeJoinStrategyOptim(spark: SparkSession) extends Strategy with S
useJoinOrder.toBoolean,
intervalHolderClassName) :: Nil
}
case FeatureCountsTemplate(reads, genes, output) =>
val minOverlap = spark.sqlContext.getConf(InternalParams.minOverlap,"1")
val maxGap = spark.sqlContext.getConf(InternalParams.maxGap,"0")
FeatureCountsPlan(
spark,
reads,
genes,
output,
minOverlap.toInt,
maxGap.toInt,
intervalHolderClassName) :: Nil
case _ =>
Nil
}
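
The new strategy case reads minOverlap and maxGap from the SQL conf, defaulting to 1 and 0. A hedged sketch of overriding them before running a feature_counts query, assuming these keys are honored the same way as the other InternalParams settings used in this code base:

spark.sqlContext.setConf(InternalParams.minOverlap, "10") // require at least 10 overlapping bases per read
spark.sqlContext.setConf(InternalParams.maxGap, "100")    // pad each gene interval by 100 bases on both sides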
@@ -1,29 +1,24 @@
package org.biodatageeks.sequila.rangejoins.optimizer

import org.apache.log4j.Logger
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.util.SizeEstimator
import org.biodatageeks.sequila.rangejoins.IntervalTree.{Interval, IntervalWithRow}
import org.biodatageeks.sequila.rangejoins.optimizer.RangeJoinMethod.RangeJoinMethod
import org.biodatageeks.sequila.utils.InternalParams
import org.openjdk.jol.info.GraphLayout


class JoinOptimizerChromosome(spark: SparkSession, rdd: RDD[(String,Interval[Int],InternalRow)], rddCount : Long) {
class JoinOptimizerChromosome [T <: Object] (spark: SparkSession, rdd: RDD[T], rddCount : Long) {

val logger = Logger.getLogger(this.getClass.getCanonicalName)
val maxBroadcastSize = spark.sqlContext
val logger: Logger = Logger.getLogger(this.getClass.getCanonicalName)
val maxBroadcastSize: Double = spark.sqlContext
.getConf(InternalParams.maxBroadCastSize,"0") match {
case "0" => 0.1*scala.math.max((spark.sparkContext.getConf.getSizeAsBytes("spark.driver.memory","0")),1024*(1024*1024)) //defaults 128MB or 0.1 * Spark Driver's memory
case _ => spark.sqlContext.getConf(InternalParams.maxBroadCastSize).toLong }
val estBroadcastSize = estimateBroadcastSize(rdd,rddCount)
val estBroadcastSize: Long = estimateBroadcastSize(rdd,rddCount)


private def estimateBroadcastSize(rdd: RDD[(String,Interval[Int],InternalRow)], rddCount: Long): Long = {
private def estimateBroadcastSize(rdd: RDD[T], rddCount: Long): Long = {
try{
(GraphLayout.parseInstance(rdd.first()).totalSize() * rddCount)
}
@@ -36,32 +31,24 @@ class JoinOptimizerChromosome(spark: SparkSession, rdd: RDD[(String,Interval[Int
//FIXME: Unclear why the estimate is ~10x the actual size - Spark row representation or object size reported in bits?
}

def debugInfo = {
def debugInfo: String = {
s"""
|Estimated broadcast structure size is ~ ${math.rint(100*estBroadcastSize/1024.0)/100} kb
|${InternalParams.maxBroadCastSize} is set to ${(maxBroadcastSize/1024).toInt} kb
|Using ${getRangeJoinMethod.toString} join method
""".stripMargin
}

private def estimateRDDSizeSpark(rdd: RDD[(String,Interval[Int],InternalRow)]): Long = {
math.round(SizeEstimator.estimate(rdd)/1024.0)
}

/**
* Choose the range join method based on the estimated size of the underlying data structure for broadcast
* @param rdd
* @return
*/
def getRangeJoinMethod : RangeJoinMethod ={

if (estimateBroadcastSize(rdd, rddCount) <= maxBroadcastSize)
def getRangeJoinMethod : RangeJoinMethod = {
if (estBroadcastSize <= maxBroadcastSize)
RangeJoinMethod.JoinWithRowBroadcast
else
RangeJoinMethod.TwoPhaseJoin

}



}
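
The generified optimizer still picks the join method by comparing the estimated broadcast size with maxBroadcastSize, which defaults to one tenth of the larger of spark.driver.memory and 1 GiB unless InternalParams.maxBroadCastSize is set explicitly. A minimal sketch of that decision with illustrative numbers:

// with a 4 GiB driver, the threshold is 0.1 * 4 GiB ≈ 429 MB
val driverMemoryBytes = 4L * 1024 * 1024 * 1024
val maxBroadcastSize = 0.1 * math.max(driverMemoryBytes, 1024L * 1024 * 1024)

val estBroadcastSize = 50L * 1024 * 1024 // a ~50 MB interval structure
val method =
  if (estBroadcastSize <= maxBroadcastSize) RangeJoinMethod.JoinWithRowBroadcast
  else RangeJoinMethod.TwoPhaseJoin // broadcast only row indices and join back to the rows afterwards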
@@ -32,6 +32,7 @@ object Columns {
final val COUNT_REF="countRef"
final val COUNT_NONREF="countNonRef"
final val QUALS="quals"
final val LENGTH = "Length"

private val sequencedFragmentColumns = ScalaFuncs
.classAccessors[SequencedFragment]