diff --git a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala index b805a4db1..108281131 100644 --- a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala +++ b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala @@ -75,7 +75,9 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) } else { assert(msg.from != null) assert(msg.size != null) - if (msg.size == 1) { + if (msg.from == -1) { + Array(new Statement(-1, session.getBufferState(), StatementState.Rejected, null)) + } else if (msg.size == 1) { session.statements.get(msg.from).toArray } else { val until = msg.from + msg.size @@ -86,6 +88,9 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) // Update progress of statements when queried statements.foreach { s => s.updateProgress(session.progressOfStatement(s.id)) + if (s.state.get() == StatementState.Available) { + session.markHasRead(s.id) + } } new ReplJobResults(statements.sortBy(_.id)) diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index 262c811c7..bf231857f 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -20,9 +20,10 @@ package org.apache.livy.repl import java.util.{LinkedHashMap => JLinkedHashMap} import java.util.Map.Entry import java.util.concurrent.Executors +import java.util.Date +import java.util.concurrent.{Executors, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ @@ -58,7 +59,7 @@ class Session( import Session._ private val interpreterExecutor = ExecutionContext.fromExecutorService( - Executors.newSingleThreadExecutor()) + Executors.newFixedThreadPool(livyConf.getInt(RSCConf.Entry.SESSION_INTERPRETER_THREADS))) private val cancelExecutor = ExecutionContext.fromExecutorService( Executors.newSingleThreadExecutor()) @@ -70,11 +71,17 @@ class Session( // Number of statements kept in driver's memory private val numRetainedStatements = livyConf.getInt(RSCConf.Entry.RETAINED_STATEMENTS) - private val _statements = new JLinkedHashMap[Int, Statement] { - protected override def removeEldestEntry(eldest: Entry[Int, Statement]): Boolean = { - size() > numRetainedStatements - } - }.asScala + private val resultRetainedTimeout = + livyConf.getTimeAsMs(RSCConf.Entry.STATEMENT_RESULT_RETAINED_TIMEOUT) + + private val resultDiscardTimeout = + livyConf.getTimeAsMs(RSCConf.Entry.STATEMENT_RESULT_DISCARD_TIMEOUT) + + private val _statements = mutable.HashMap[Int, Statement]() + + private var _expiredTimestamps = mutable.HashMap[Int, Long]() + // record recently expired item to prevent search _expiredTimestamps all the time + private var _recentlyExpiredItem: (Int, Long) = (0, 0) private val newStatementId = new AtomicInteger(0) @@ -148,34 +155,48 @@ class Session( } def execute(code: String, codeType: String = null): Int = { - val tpe = if (codeType != null) { - Kind(codeType) - } else if (defaultInterpKind != Shared) { - defaultInterpKind + if (isOverload(newStatementId.get())) { + // return statementId -1 means reject current code + -1 } else { - throw new IllegalArgumentException(s"Code type should be specified if session kind is shared") - } - - val statementId = newStatementId.getAndIncrement() - val statement = new Statement(statementId, code, StatementState.Waiting, null) - _statements.synchronized { _statements(statementId) = statement } - - Future { - setJobGroup(tpe, statementId) - statement.compareAndTransit(StatementState.Waiting, StatementState.Running) - - if (statement.state.get() == StatementState.Running) { - statement.started = System.currentTimeMillis() - statement.output = executeCode(interpreter(tpe), statementId, code) + val tpe = if (codeType != null) { + Kind(codeType) + } else if (defaultInterpKind != Shared) { + defaultInterpKind + } else { + throw new IllegalArgumentException( + s"Code type should be specified if session kind is shared") } + + val statementId = newStatementId.getAndIncrement() + val statement = new Statement(statementId, code, StatementState.Waiting, null) + _statements.synchronized { _statements(statementId) = statement } + + Future { + this.synchronized { setJobGroup(tpe, statementId) } + statement.compareAndTransit(StatementState.Waiting, StatementState.Running) + + if (statement.state.get() == StatementState.Running) { + statement.started = System.currentTimeMillis() + statement.output = executeCode(interpreter(tpe), statementId, code) + } + + if (statement.state.get() == StatementState.Running) { + _expiredTimestamps.synchronized { + _expiredTimestamps(statement.id) = new Date().getTime + resultRetainedTimeout + } + } else { + markHasRead(statement.id) + } - statement.compareAndTransit(StatementState.Running, StatementState.Available) - statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) - statement.updateProgress(1.0) - statement.completed = System.currentTimeMillis() - }(interpreterExecutor) + statement.compareAndTransit(StatementState.Running, StatementState.Available) + statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) + statement.updateProgress(1.0) + + }(interpreterExecutor) - statementId + statementId + } } def complete(code: String, codeType: String, cursor: Int): Array[String] = { @@ -361,4 +382,97 @@ class Session( private def statementIdToJobGroup(statementId: Int): String = { statementId.toString } + + private def snapshot(id: Int, msg: String): Unit = { + debug(s"statement No.${id} ${msg}") + val (running, o) = _statements.values.partition(_.state.get() == StatementState.Running) + debug(s"Buffer Size: ${numRetainedStatements}\tUsed Size: ${_statements.size}\t" + + s"Finished: ${_expiredTimestamps.size} Running: ${running.size} " + + s"Waiting: ${o.size-_expiredTimestamps.size}") + } + + /** + * return false when _statements size have not reach upper limit (numRetainedStatements) + * or some expired statement can be remove from _statements, otherwise return true + */ + private def isOverload(proposer: Int): Boolean = _statements.synchronized { + if (_statements.size < numRetainedStatements) { + snapshot(proposer, "will be accepted") + false + } else if (checkExpired) { + snapshot(proposer, s"will be accepted after cleanUpExpired") + cleanUpExpired() + false + } else { + snapshot(proposer, "is rejected") + true + } + } + + private def checkExpired: Boolean = _expiredTimestamps.synchronized { + if (_expiredTimestamps.size == 0) { + false + } else { + if (_recentlyExpiredItem._2 == 0) { + _recentlyExpiredItem = _expiredTimestamps.toArray.sortBy(_._2).head + } + new Date().getTime > _recentlyExpiredItem._2 + } + } + + private def cleanUpExpired(): Unit = _statements.synchronized { + _expiredTimestamps.synchronized { + assert(_expiredTimestamps.size > 0) + val sorted = mutable.PriorityQueue[(Int, Long)](_expiredTimestamps.toSeq: _*) + (Ordering.by[(Int, Long), Long](_._2).reverse) + _recentlyExpiredItem = sorted.dequeue() + val now = new Date().getTime + while (_recentlyExpiredItem._2 != 0 && _recentlyExpiredItem._2 <= now ) { + _statements.remove(_recentlyExpiredItem._1) + if (sorted.size > 0) { + _recentlyExpiredItem = sorted.dequeue() + } else { + _recentlyExpiredItem = (0, 0) + } + } + if (_recentlyExpiredItem._2 > 0) { + sorted.enqueue(_recentlyExpiredItem) + } + _expiredTimestamps = mutable.HashMap[Int, Long](sorted.toSeq: _*) + } + } + + def markHasRead(id: Integer): Unit = { + if (resultDiscardTimeout == 0) { + _statements.synchronized { + _expiredTimestamps.synchronized { + _statements.remove(id) + _expiredTimestamps.remove(id) + if (_recentlyExpiredItem._1 == id){ + _recentlyExpiredItem = (0, 0) + } + } + } + } else { + _expiredTimestamps.synchronized { + val expiredTime = new Date().getTime + resultDiscardTimeout + _expiredTimestamps(id) = expiredTime + if (expiredTime < _recentlyExpiredItem._2) { + _recentlyExpiredItem = (id, expiredTime) + } + } + } + snapshot(id, s"read success, expired after ${resultDiscardTimeout} ms") + } + + def getBufferState(): String = _statements.synchronized { + if (_statements.size < numRetainedStatements || checkExpired) { + "buffer is free, please try to resubmit the code" + } else { + "buffer is busy, maybe free after " + + s"${TimeUnit.SECONDS.convert(_recentlyExpiredItem._2 + - new Date().getTime, TimeUnit.MILLISECONDS)}s" + } + } + } diff --git a/repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala index 34019071d..967171ddd 100644 --- a/repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala @@ -20,7 +20,11 @@ package org.apache.livy.repl import java.util.Properties import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} +import scala.concurrent.Await +import scala.concurrent.duration.Duration + import org.apache.spark.SparkConf +import org.json4s.JsonAST import org.scalatest.{BeforeAndAfter, FunSpec} import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually @@ -29,6 +33,7 @@ import org.scalatest.time._ import org.apache.livy.LivyBaseUnitTestSuite import org.apache.livy.repl.Interpreter.ExecuteResponse import org.apache.livy.rsc.RSCConf +import org.apache.livy.rsc.driver.SparkEntries import org.apache.livy.sessions._ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite with BeforeAndAfter { @@ -88,10 +93,19 @@ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite wit } } - it("should remove old statements when reaching threshold") { + it("should remove expired statements when reaching threshold") { rscConf.set(RSCConf.Entry.RETAINED_STATEMENTS, 2) - session = new Session(rscConf, new SparkConf()) - session.start() + rscConf.set(RSCConf.Entry.STATEMENT_RESULT_RETAINED_TIMEOUT, "1s") + val interpreter = new SparkInterpreter(new SparkConf()) { + override def execute(code: String): ExecuteResponse = { + Interpreter.ExecuteSuccess(new org.json4s.JObject(List[JsonAST.JField]())) + } + override def postStart(): Unit = { + entries = new SparkEntries(conf) + } + } + session = new Session(rscConf, new SparkConf(), Some(interpreter)) + Await.result(session.start(), Duration.Inf) session.statements.size should be (0) session.execute("") @@ -100,10 +114,10 @@ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite wit session.execute("") session.statements.size should be (2) session.statements.map(_._1).toSet should be (Set(0, 1)) - session.execute("") - eventually { - session.statements.size should be (2) - session.statements.map(_._1).toSet should be (Set(1, 2)) + eventually(timeout(Span(1500, Millis))) { + session.execute("") + session.statements.size should be (1) + session.statements.map(_._1).toSet should be (Set(2)) } // Continue submitting statements, total statements in memory should be 2. diff --git a/rsc/src/main/java/org/apache/livy/rsc/RSCConf.java b/rsc/src/main/java/org/apache/livy/rsc/RSCConf.java index 4c45956d7..856029d9c 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/RSCConf.java +++ b/rsc/src/main/java/org/apache/livy/rsc/RSCConf.java @@ -44,6 +44,7 @@ public static enum Entry implements ConfEntry { CLIENT_SHUTDOWN_TIMEOUT("client.shutdown-timeout", "10s"), DRIVER_CLASS("driver-class", null), SESSION_KIND("session.kind", null), + SESSION_INTERPRETER_THREADS("session.interpreter.threadpool.size", 1), LIVY_JARS("jars", null), SPARKR_PACKAGE("sparkr.package", null), @@ -70,6 +71,9 @@ public static enum Entry implements ConfEntry { SASL_MECHANISMS("rpc.sasl.mechanisms", "DIGEST-MD5"), SASL_QOP("rpc.sasl.qop", null), + + STATEMENT_RESULT_RETAINED_TIMEOUT("result-retained.timeout", "1h"), + STATEMENT_RESULT_DISCARD_TIMEOUT("result-discard.timeout", "10m"), TEST_NO_CODE_COVERAGE_ANALYSIS("test.do-not-use.no-code-coverage-analysis", false), TEST_STUCK_END_SESSION("test.do-not-use.stuck-end-session", false), diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/StatementState.java b/rsc/src/main/java/org/apache/livy/rsc/driver/StatementState.java index 787fc7793..f857f1752 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/driver/StatementState.java +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/StatementState.java @@ -24,6 +24,7 @@ import org.slf4j.LoggerFactory; public enum StatementState { + Rejected("rejected"), Waiting("waiting"), Running("running"), Available("available"),