diff --git a/server/src/main/scala/com/cloudera/livy/server/recovery/BlackholeStateStore.scala b/server/src/main/scala/com/cloudera/livy/server/recovery/BlackholeStateStore.scala index e0a33371b..b62e342d1 100644 --- a/server/src/main/scala/com/cloudera/livy/server/recovery/BlackholeStateStore.scala +++ b/server/src/main/scala/com/cloudera/livy/server/recovery/BlackholeStateStore.scala @@ -18,6 +18,8 @@ package com.cloudera.livy.server.recovery +import java.util.concurrent.atomic.AtomicLong + import scala.reflect.ClassTag import com.cloudera.livy.LivyConf @@ -27,6 +29,9 @@ import com.cloudera.livy.LivyConf * Livy will use this when session recovery is disabled. */ class BlackholeStateStore(livyConf: LivyConf) extends StateStore(livyConf) { + + val atomicLong: AtomicLong = new AtomicLong(-1L) + def set(key: String, value: Object): Unit = {} def get[T: ClassTag](key: String): Option[T] = None @@ -34,4 +39,6 @@ class BlackholeStateStore(livyConf: LivyConf) extends StateStore(livyConf) { def getChildren(key: String): Seq[String] = List.empty[String] def remove(key: String): Unit = {} + + override def increment(key: String): Long = atomicLong.incrementAndGet() } diff --git a/server/src/main/scala/com/cloudera/livy/server/recovery/FileSystemStateStore.scala b/server/src/main/scala/com/cloudera/livy/server/recovery/FileSystemStateStore.scala index d841c6328..9a692ccb6 100644 --- a/server/src/main/scala/com/cloudera/livy/server/recovery/FileSystemStateStore.scala +++ b/server/src/main/scala/com/cloudera/livy/server/recovery/FileSystemStateStore.scala @@ -120,4 +120,10 @@ class FileSystemStateStore( } private def absPath(key: String): Path = new Path(fsUri.getPath(), key) + + override def increment(key: String): Long = synchronized { + val incrementedValue = get[Long](key).getOrElse(-1L) + 1 + set(key, incrementedValue.asInstanceOf[Object]) + incrementedValue + } } diff --git a/server/src/main/scala/com/cloudera/livy/server/recovery/SessionStore.scala b/server/src/main/scala/com/cloudera/livy/server/recovery/SessionStore.scala index c6f2692b6..2b3e1b61b 100644 --- a/server/src/main/scala/com/cloudera/livy/server/recovery/SessionStore.scala +++ b/server/src/main/scala/com/cloudera/livy/server/recovery/SessionStore.scala @@ -27,8 +27,6 @@ import scala.util.control.NonFatal import com.cloudera.livy.{LivyConf, Logging} import com.cloudera.livy.sessions.Session.RecoveryMetadata -private[recovery] case class SessionManagerState(nextSessionId: Int) - /** * SessionStore provides high level functions to get/save session state from/to StateStore. */ @@ -64,18 +62,14 @@ class SessionStore( } /** - * Return the next unused session id with specified session type. - * It checks the SessionManagerState stored and returns the next free session id. - * If no SessionManagerState is stored, it returns 0. - * It saves the new session ID to the session store. + * Return the next unused session ID from state store with the specified session type. + * If no value is stored state store, it returns 0. + * It saves the next unused session ID to the session store before returning the current value. * - * @throws Exception If SessionManagerState stored is corrupted, it throws an error. + * @throws Exception If session store is corrupted or unreachable, it throws an error. */ - def getNextSessionId(sessionType: String): Int = synchronized { - val nextSessionId = store.get[SessionManagerState](sessionManagerPath(sessionType)) - .map(_.nextSessionId).getOrElse(0) - store.set(sessionManagerPath(sessionType), SessionManagerState(nextSessionId + 1)) - nextSessionId + def getNextSessionId(sessionType: String): Int = { + store.increment(sessionManagerPath(sessionType)).toInt } /** diff --git a/server/src/main/scala/com/cloudera/livy/server/recovery/StateStore.scala b/server/src/main/scala/com/cloudera/livy/server/recovery/StateStore.scala index 18cf6ade4..676dfcf7e 100644 --- a/server/src/main/scala/com/cloudera/livy/server/recovery/StateStore.scala +++ b/server/src/main/scala/com/cloudera/livy/server/recovery/StateStore.scala @@ -71,6 +71,13 @@ abstract class StateStore(livyConf: LivyConf) extends JsonMapper { * @throws Exception Throw when persisting the state store fails. */ def remove(key: String): Unit + + /** + * Gets the Long value for the given key, increments the value, and stores the new value before + * returning the value. + * @return incremented value + */ + def increment(key: String): Long } /** diff --git a/server/src/main/scala/com/cloudera/livy/server/recovery/ZooKeeperStateStore.scala b/server/src/main/scala/com/cloudera/livy/server/recovery/ZooKeeperStateStore.scala index 883383590..2adf869ab 100644 --- a/server/src/main/scala/com/cloudera/livy/server/recovery/ZooKeeperStateStore.scala +++ b/server/src/main/scala/com/cloudera/livy/server/recovery/ZooKeeperStateStore.scala @@ -19,9 +19,13 @@ package com.cloudera.livy.server.recovery import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.util.Try +import scala.util.matching.Regex +import org.apache.curator.RetryPolicy import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} import org.apache.curator.framework.api.UnhandledErrorListener +import org.apache.curator.framework.recipes.atomic.{DistributedAtomicLong => DistributedLong} import org.apache.curator.retry.RetryNTimes import org.apache.zookeeper.KeeperException.NoNodeException @@ -46,18 +50,22 @@ class ZooKeeperStateStore( } private val zkAddress = livyConf.get(LivyConf.RECOVERY_STATE_STORE_URL) + require(!zkAddress.isEmpty, s"Please config ${LivyConf.RECOVERY_STATE_STORE_URL.key}.") - private val zkKeyPrefix = livyConf.get(ZK_KEY_PREFIX_CONF) - private val curatorClient = mockCuratorClient.getOrElse { - val retryValue = livyConf.get(ZK_RETRY_CONF) + + private val retryValue = livyConf.get(ZK_RETRY_CONF) + private val retryPolicy = Try { + // a regex to match patterns like "m, n" where m and m both are integer values val retryPattern = """\s*(\d+)\s*,\s*(\d+)\s*""".r - val retryPolicy = retryValue match { - case retryPattern(n, sleepMs) => new RetryNTimes(5, 100) - case _ => throw new IllegalArgumentException( - s"$ZK_KEY_PREFIX_CONF contains bad value: $retryValue. " + - "Correct format is ,. e.g. 5,100") - } + val retryPattern(retryTimes, sleepMsBetweenRetries) = retryValue + new RetryNTimes(retryTimes.toInt, sleepMsBetweenRetries.toInt) + }.getOrElse { throw new IllegalArgumentException( + s"$ZK_RETRY_CONF contains bad value: $retryValue. " + + "Correct format is ,. e.g. 5,100") + } + private val zkKeyPrefix = livyConf.get(ZK_KEY_PREFIX_CONF) + private val curatorClient = mockCuratorClient.getOrElse { CuratorFrameworkFactory.newClient(zkAddress, retryPolicy) } @@ -113,5 +121,15 @@ class ZooKeeperStateStore( } } + override def increment(key: String): Long = { + val distributedSessionId = new DistributedLong(curatorClient, key, retryPolicy) + distributedSessionId.increment() match { + case atomicValue if atomicValue.succeeded() => + atomicValue.postValue() + case _ => + throw new java.io.IOException(s"Failed to atomically increment the value for $key") + } + } + private def prefixKey(key: String) = s"/$zkKeyPrefix/$key" } diff --git a/server/src/test/scala/com/cloudera/livy/server/recovery/SessionStoreSpec.scala b/server/src/test/scala/com/cloudera/livy/server/recovery/SessionStoreSpec.scala index 25c0a1b95..9b3a88dd2 100644 --- a/server/src/test/scala/com/cloudera/livy/server/recovery/SessionStoreSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/recovery/SessionStoreSpec.scala @@ -89,12 +89,11 @@ class SessionStoreSpec extends FunSpec with LivyBaseUnitTestSuite { val stateStore = mock[StateStore] val sessionStore = new SessionStore(conf, stateStore) - when(stateStore.get[SessionManagerState](sessionManagerPath)).thenReturn(None) + when(stateStore.increment(sessionManagerPath)).thenReturn(0L) sessionStore.getNextSessionId(sessionType) shouldBe 0 - val sms = SessionManagerState(100) - when(stateStore.get[SessionManagerState](sessionManagerPath)).thenReturn(Some(sms)) - sessionStore.getNextSessionId(sessionType) shouldBe sms.nextSessionId + when(stateStore.increment(sessionManagerPath)).thenReturn(100) + sessionStore.getNextSessionId(sessionType) shouldBe 100 } it("should remove session") {