Try to process statements of single session concurrently (especially… #39

Open
wants to merge 1 commit into base: v0.6.0-incubating-kubernetes-support
7 changes: 6 additions & 1 deletion repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala
@@ -75,7 +75,9 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
} else {
assert(msg.from != null)
assert(msg.size != null)
if (msg.size == 1) {
if (msg.from == -1) {
Array(new Statement(-1, session.getBufferState(), StatementState.Rejected, null))
} else if (msg.size == 1) {
session.statements.get(msg.from).toArray
} else {
val until = msg.from + msg.size
@@ -86,6 +88,9 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
// Update progress of statements when queried
statements.foreach { s =>
s.updateProgress(session.progressOfStatement(s.id))
if (s.state.get() == StatementState.Available) {
session.markHasRead(s.id)
}
}

new ReplJobResults(statements.sortBy(_.id))
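
For context, a hedged client-side sketch (not part of this diff) of how a caller might consume the rejection convention above: a query with from == -1 is answered with a single Statement whose id is -1, whose state is Rejected, and whose code field (by assumption) carries the getBufferState() message. The import path is also an assumption.

import org.apache.livy.rsc.driver.{Statement, StatementState}

object StatementReport {
  // Print one status line per statement, treating the synthetic id -1 entry
  // as a "buffer full" signal rather than a real statement.
  def report(statements: Array[Statement]): Unit = {
    statements.foreach { s =>
      if (s.state.get() == StatementState.Rejected) {
        // Nothing was queued; the code field explains why and hints when to retry.
        println(s"Submission rejected: ${s.code}")
      } else {
        println(s"Statement ${s.id} is ${s.state.get()}")
      }
    }
  }
}
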
176 changes: 145 additions & 31 deletions repl/src/main/scala/org/apache/livy/repl/Session.scala
@@ -20,9 +20,10 @@ package org.apache.livy.repl
import java.util.{LinkedHashMap => JLinkedHashMap}
import java.util.Map.Entry
import java.util.concurrent.Executors
import java.util.Date
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
@@ -58,7 +59,7 @@ class Session(
import Session._

private val interpreterExecutor = ExecutionContext.fromExecutorService(
Executors.newSingleThreadExecutor())
Executors.newFixedThreadPool(livyConf.getInt(RSCConf.Entry.SESSION_INTERPRETER_THREADS)))

private val cancelExecutor = ExecutionContext.fromExecutorService(
Executors.newSingleThreadExecutor())
@@ -70,11 +71,17 @@
// Number of statements kept in driver's memory
private val numRetainedStatements = livyConf.getInt(RSCConf.Entry.RETAINED_STATEMENTS)

private val _statements = new JLinkedHashMap[Int, Statement] {
protected override def removeEldestEntry(eldest: Entry[Int, Statement]): Boolean = {
size() > numRetainedStatements
}
}.asScala
private val resultRetainedTimeout =
livyConf.getTimeAsMs(RSCConf.Entry.STATEMENT_RESULT_RETAINED_TIMEOUT)

private val resultDiscardTimeout =
livyConf.getTimeAsMs(RSCConf.Entry.STATEMENT_RESULT_DISCARD_TIMEOUT)

private val _statements = mutable.HashMap[Int, Statement]()

private var _expiredTimestamps = mutable.HashMap[Int, Long]()
// cache the entry with the earliest expiry time to avoid scanning _expiredTimestamps every time
private var _recentlyExpiredItem: (Int, Long) = (0, 0)

private val newStatementId = new AtomicInteger(0)

@@ -148,34 +155,48 @@ class Session(
}

def execute(code: String, codeType: String = null): Int = {
val tpe = if (codeType != null) {
Kind(codeType)
} else if (defaultInterpKind != Shared) {
defaultInterpKind
if (isOverload(newStatementId.get())) {
// returning statementId -1 means the current code is rejected
-1
} else {
throw new IllegalArgumentException(s"Code type should be specified if session kind is shared")
}

val statementId = newStatementId.getAndIncrement()
val statement = new Statement(statementId, code, StatementState.Waiting, null)
_statements.synchronized { _statements(statementId) = statement }

Future {
setJobGroup(tpe, statementId)
statement.compareAndTransit(StatementState.Waiting, StatementState.Running)

if (statement.state.get() == StatementState.Running) {
statement.started = System.currentTimeMillis()
statement.output = executeCode(interpreter(tpe), statementId, code)
val tpe = if (codeType != null) {
Kind(codeType)
} else if (defaultInterpKind != Shared) {
defaultInterpKind
} else {
throw new IllegalArgumentException(
s"Code type should be specified if session kind is shared")
}

val statementId = newStatementId.getAndIncrement()
val statement = new Statement(statementId, code, StatementState.Waiting, null)
_statements.synchronized { _statements(statementId) = statement }

Future {
this.synchronized { setJobGroup(tpe, statementId) }
statement.compareAndTransit(StatementState.Waiting, StatementState.Running)

if (statement.state.get() == StatementState.Running) {
statement.started = System.currentTimeMillis()
statement.output = executeCode(interpreter(tpe), statementId, code)
}

if (statement.state.get() == StatementState.Running) {
_expiredTimestamps.synchronized {
_expiredTimestamps(statement.id) = new Date().getTime + resultRetainedTimeout
}
} else {
markHasRead(statement.id)
}

statement.compareAndTransit(StatementState.Running, StatementState.Available)
statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled)
statement.updateProgress(1.0)
statement.completed = System.currentTimeMillis()
}(interpreterExecutor)
statement.compareAndTransit(StatementState.Running, StatementState.Available)
statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled)
statement.updateProgress(1.0)

}(interpreterExecutor)

statementId
statementId
}
}

def complete(code: String, codeType: String, cursor: Int): Array[String] = {
@@ -361,4 +382,97 @@ class Session(
private def statementIdToJobGroup(statementId: Int): String = {
statementId.toString
}

private def snapshot(id: Int, msg: String): Unit = {
debug(s"statement No.${id} ${msg}")
val (running, o) = _statements.values.partition(_.state.get() == StatementState.Running)
debug(s"Buffer Size: ${numRetainedStatements}\tUsed Size: ${_statements.size}\t" +
s"Finished: ${_expiredTimestamps.size} Running: ${running.size} " +
s"Waiting: ${o.size-_expiredTimestamps.size}")
}

/**
* Return false when the size of _statements has not reached the upper limit (numRetainedStatements)
* or some expired statements can be removed from _statements; otherwise return true.
*/
private def isOverload(proposer: Int): Boolean = _statements.synchronized {
if (_statements.size < numRetainedStatements) {
snapshot(proposer, "will be accepted")
false
} else if (checkExpired) {
snapshot(proposer, s"will be accepted after cleanUpExpired")
cleanUpExpired()
false
} else {
snapshot(proposer, "is rejected")
true
}
}

private def checkExpired: Boolean = _expiredTimestamps.synchronized {
if (_expiredTimestamps.size == 0) {
false
} else {
if (_recentlyExpiredItem._2 == 0) {
_recentlyExpiredItem = _expiredTimestamps.toArray.sortBy(_._2).head
}
new Date().getTime > _recentlyExpiredItem._2
}
}

private def cleanUpExpired(): Unit = _statements.synchronized {
_expiredTimestamps.synchronized {
assert(_expiredTimestamps.size > 0)
val sorted = mutable.PriorityQueue[(Int, Long)](_expiredTimestamps.toSeq: _*)
(Ordering.by[(Int, Long), Long](_._2).reverse)
_recentlyExpiredItem = sorted.dequeue()
val now = new Date().getTime
while (_recentlyExpiredItem._2 != 0 && _recentlyExpiredItem._2 <= now) {
_statements.remove(_recentlyExpiredItem._1)
if (sorted.size > 0) {
_recentlyExpiredItem = sorted.dequeue()
} else {
_recentlyExpiredItem = (0, 0)
}
}
if (_recentlyExpiredItem._2 > 0) {
sorted.enqueue(_recentlyExpiredItem)
}
_expiredTimestamps = mutable.HashMap[Int, Long](sorted.toSeq: _*)
}
}

def markHasRead(id: Integer): Unit = {
if (resultDiscardTimeout == 0) {
_statements.synchronized {
_expiredTimestamps.synchronized {
_statements.remove(id)
_expiredTimestamps.remove(id)
if (_recentlyExpiredItem._1 == id) {
_recentlyExpiredItem = (0, 0)
}
}
}
} else {
_expiredTimestamps.synchronized {
val expiredTime = new Date().getTime + resultDiscardTimeout
_expiredTimestamps(id) = expiredTime
if (expiredTime < _recentlyExpiredItem._2) {
_recentlyExpiredItem = (id, expiredTime)
}
}
}
snapshot(id, s"read success, expired after ${resultDiscardTimeout} ms")
}

def getBufferState(): String = _statements.synchronized {
if (_statements.size < numRetainedStatements || checkExpired) {
"buffer is free, please try to resubmit the code"
} else {
"buffer is busy, maybe free after " +
s"${TimeUnit.SECONDS.convert(_recentlyExpiredItem._2
- new Date().getTime, TimeUnit.MILLISECONDS)}s"
}
}

}
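
To make the retention bookkeeping above easier to follow, here is a simplified standalone model (an illustration with hypothetical names, not the PR code): a statement that completed but was never read expires result-retained.timeout after it finishes, while a statement that has been read is dropped immediately when result-discard.timeout is 0 and otherwise expires that long after the read. Expired entries are what isOverload() is allowed to evict to make room for new submissions.

import scala.collection.mutable

object ExpiryModel {
  // Illustrative values matching the defaults added to RSCConf below.
  val retainedTimeoutMs: Long = 60L * 60 * 1000  // result-retained.timeout = 1h
  val discardTimeoutMs: Long = 10L * 60 * 1000   // result-discard.timeout = 10m

  // statement id -> timestamp after which its result may be evicted
  val expiry = mutable.HashMap[Int, Long]()

  // A finished but unread statement stays around for the retained timeout.
  def onCompleted(id: Int, now: Long): Unit =
    expiry(id) = now + retainedTimeoutMs

  // Once read, it is dropped immediately or kept for the shorter discard timeout.
  def onRead(id: Int, now: Long): Unit =
    if (discardTimeoutMs == 0) expiry.remove(id)
    else expiry(id) = now + discardTimeoutMs

  // Ids whose results may be removed to make room for a new statement.
  def evictable(now: Long): Iterable[Int] =
    expiry.collect { case (id, t) if t <= now => id }
}
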
28 changes: 21 additions & 7 deletions repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala
@@ -20,7 +20,11 @@ package org.apache.livy.repl
import java.util.Properties
import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}

import scala.concurrent.Await
import scala.concurrent.duration.Duration

import org.apache.spark.SparkConf
import org.json4s.JsonAST
import org.scalatest.{BeforeAndAfter, FunSpec}
import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
@@ -29,6 +33,7 @@ import org.scalatest.time._
import org.apache.livy.LivyBaseUnitTestSuite
import org.apache.livy.repl.Interpreter.ExecuteResponse
import org.apache.livy.rsc.RSCConf
import org.apache.livy.rsc.driver.SparkEntries
import org.apache.livy.sessions._

class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite with BeforeAndAfter {
@@ -88,10 +93,19 @@ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite wit
}
}

it("should remove old statements when reaching threshold") {
it("should remove expired statements when reaching threshold") {
rscConf.set(RSCConf.Entry.RETAINED_STATEMENTS, 2)
session = new Session(rscConf, new SparkConf())
session.start()
rscConf.set(RSCConf.Entry.STATEMENT_RESULT_RETAINED_TIMEOUT, "1s")
val interpreter = new SparkInterpreter(new SparkConf()) {
override def execute(code: String): ExecuteResponse = {
Interpreter.ExecuteSuccess(new org.json4s.JObject(List[JsonAST.JField]()))
}
override def postStart(): Unit = {
entries = new SparkEntries(conf)
}
}
session = new Session(rscConf, new SparkConf(), Some(interpreter))
Await.result(session.start(), Duration.Inf)

session.statements.size should be (0)
session.execute("")
Expand All @@ -100,10 +114,10 @@ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite wit
session.execute("")
session.statements.size should be (2)
session.statements.map(_._1).toSet should be (Set(0, 1))
session.execute("")
eventually {
session.statements.size should be (2)
session.statements.map(_._1).toSet should be (Set(1, 2))
eventually(timeout(Span(1500, Millis))) {
session.execute("")
session.statements.size should be (1)
session.statements.map(_._1).toSet should be (Set(2))
}

// Continue submitting statements, total statements in memory should be 2.
4 changes: 4 additions & 0 deletions rsc/src/main/java/org/apache/livy/rsc/RSCConf.java
@@ -44,6 +44,7 @@ public static enum Entry implements ConfEntry {
CLIENT_SHUTDOWN_TIMEOUT("client.shutdown-timeout", "10s"),
DRIVER_CLASS("driver-class", null),
SESSION_KIND("session.kind", null),
SESSION_INTERPRETER_THREADS("session.interpreter.threadpool.size", 1),

LIVY_JARS("jars", null),
SPARKR_PACKAGE("sparkr.package", null),
@@ -70,6 +71,9 @@

SASL_MECHANISMS("rpc.sasl.mechanisms", "DIGEST-MD5"),
SASL_QOP("rpc.sasl.qop", null),

STATEMENT_RESULT_RETAINED_TIMEOUT("result-retained.timeout", "1h"),
STATEMENT_RESULT_DISCARD_TIMEOUT("result-discard.timeout", "10m"),

TEST_NO_CODE_COVERAGE_ANALYSIS("test.do-not-use.no-code-coverage-analysis", false),
TEST_STUCK_END_SESSION("test.do-not-use.stuck-end-session", false),
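
A hedged configuration sketch of how the new entries might be combined with the existing RETAINED_STATEMENTS buffer size, in the same programmatic style the updated SessionSpec uses. The values are illustrative, and the RSCConf(Properties) constructor is assumed rather than shown in this diff.

import java.util.Properties

import org.apache.livy.rsc.RSCConf

object ConcurrencyConfSketch {
  def buildConf(): RSCConf = {
    val rscConf = new RSCConf(new Properties())
    rscConf.set(RSCConf.Entry.SESSION_INTERPRETER_THREADS, 4)           // run up to 4 statements concurrently
    rscConf.set(RSCConf.Entry.RETAINED_STATEMENTS, 100)                 // driver-side statement buffer size
    rscConf.set(RSCConf.Entry.STATEMENT_RESULT_RETAINED_TIMEOUT, "2h")  // keep unread results for 2 hours
    rscConf.set(RSCConf.Entry.STATEMENT_RESULT_DISCARD_TIMEOUT, "5m")   // keep read results for 5 more minutes
    rscConf
  }
}
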
@@ -24,6 +24,7 @@
import org.slf4j.LoggerFactory;

public enum StatementState {
Rejected("rejected"),
Waiting("waiting"),
Running("running"),
Available("available"),