diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientActor.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientActor.scala new file mode 100644 index 0000000..531ff35 --- /dev/null +++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientActor.scala @@ -0,0 +1,277 @@ +/** + * Copyright (C) 2015 Stratio (http://stratio.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.stratio.datasource.mongodb.client + +import javax.net.ssl.SSLSocketFactory + +import akka.actor.Actor +import com.mongodb.ServerAddress +import com.mongodb.casbah.Imports._ +import com.mongodb.casbah.{MongoClient, MongoClientOptions} +import com.stratio.datasource.mongodb.client.MongodbClientActor._ +import com.stratio.datasource.mongodb.client.MongodbClientFactory._ +import com.stratio.datasource.mongodb.config.MongodbConfig.{ReadPreference => ProviderReadPreference, _} +import com.stratio.datasource.mongodb.config.MongodbSSLOptions + +import scala.annotation.tailrec +import scala.util.Try + + +class MongodbClientActor extends Actor { + + private val KeySeparator = "-" + + private val CloseSleepTime = 1000 + + private val mongoClient: scala.collection.mutable.Map[String, MongodbConnection] = + scala.collection.mutable.Map.empty[String, MongodbConnection] + + override def receive = { + case CheckConnections => doCheckConnections() + case GetClient(host) => doGetClient(host) + case GetClientWithMongoDbConfig(hostPort, credentials, optionSSLOptions, clientOptions) => + doGetClientWithMongoDbConfig(hostPort, credentials, optionSSLOptions, clientOptions) + case GetClientWithUser(host, port, user, database, password) => + doGetClientWithUser(host, port, user, database, password) + case SetFreeConnectionByKey(clientKey, extendedTime) => doSetFreeConnectionByKey(clientKey, extendedTime) + case SetFreeConnectionsByClient(client, extendedTime) => doSetFreeConnectionByClient(client, extendedTime) + case CloseAll(gracefully, attempts) => doCloseAll(gracefully, attempts) + case CloseByClient(client, gracefully) => doCloseByClient(client, gracefully) + case CloseByKey(clientKey, gracefully) => doCloseByKey(clientKey, gracefully) + case GetSize => doGetSize + } + + private def doCheckConnections(): Unit = { + val currentTime = System.currentTimeMillis() + mongoClient.foreach { case (key, connection) => + if ((connection.status == ConnectionStatus.Free) && (connection.timeOut <= currentTime)) { + connection.client.close() + mongoClient.remove(key) + } + } + } + + private def doGetClientWithMongoDbConfig(hostPort: List[ServerAddress], + credentials: List[MongoCredential], + optionSSLOptions: Option[MongodbSSLOptions], + clientOptions: Map[String, Any]): Unit = { + val connKey = connectionKey(0, hostPort, credentials, clientOptions) + val (finalKey, connection) = mongoClient.get(connKey) match { + case Some(client) => + if (client.status == ConnectionStatus.Free) (connKey, client) + else createClient(1 + connKey, hostPort, credentials, optionSSLOptions, clientOptions) + case None => createClient(connKey, hostPort, credentials, optionSSLOptions, clientOptions) + } + + mongoClient.update(finalKey, connection.copy( + timeOut = System.currentTimeMillis() + + extractValue[String](clientOptions, ConnectionsTime).map(_.toLong).getOrElse(DefaultConnectionsTime), + status = ConnectionStatus.Busy)) + + sender ! ClientResponse(finalKey, connection.client) + } + + private def doGetClientWithUser(host: String, port: Int, user: String, database: String, password: String): Unit = { + val credentials = List(MongoCredential.createCredential(user, database, password.toCharArray)) + val hostPort = new ServerAddress(host, port) + val connKey = connectionKey(0, List(hostPort), credentials) + + val (finalKey, connection) = mongoClient.get(connKey) match { + case Some(client) => + if (client.status == ConnectionStatus.Free) (connKey, client) + else createClient(1 + connKey, List(hostPort), credentials) + case None => createClient(connKey, List(hostPort), credentials) + } + + mongoClient.update(finalKey, connection.copy( + timeOut = System.currentTimeMillis() + DefaultConnectionsTime, + status = ConnectionStatus.Busy)) + + sender ! ClientResponse(finalKey, connection.client) + } + + private def doGetClient(host: String): Unit = { + val hostPort = new ServerAddress(host) + val connKey = connectionKey(0, List(hostPort)) + val (finalKey, connection) = mongoClient.get(connKey) match { + case Some(client) => + if (client.status == ConnectionStatus.Free) (connKey, client) + else createClient(connKey, host) + case None => createClient(connKey, host) + } + + mongoClient.update(finalKey, connection.copy( + timeOut = System.currentTimeMillis() + DefaultConnectionsTime, + status = ConnectionStatus.Busy)) + + sender ! ClientResponse(finalKey, connection.client) + } + + private def doSetFreeConnectionByKey(clientKey: String, extendedTime: Option[Long]): Unit = { + mongoClient.get(clientKey).foreach(clientFound => { + mongoClient.update(clientKey, clientFound.copy(status = ConnectionStatus.Free, + timeOut = System.currentTimeMillis() + extendedTime.getOrElse(DefaultConnectionsTime))) + }) + } + + private def doSetFreeConnectionByClient(client: Client, extendedTime: Option[Long]): Unit = { + mongoClient.find { case (key, clientSearch) => clientSearch.client == client } + .foreach { case (key, clientFound) => + mongoClient.update(key, clientFound.copy(status = ConnectionStatus.Free, + timeOut = System.currentTimeMillis() + extendedTime.getOrElse(DefaultConnectionsTime))) + } + } + + private def doCloseAll(gracefully: Boolean, attempts: Int): Unit = { + mongoClient.foreach { case (key, connection) => + if (!gracefully || connection.status == ConnectionStatus.Free) { + connection.client.close() + mongoClient.remove(key) + } + } + if (mongoClient.nonEmpty && attempts > 0) { + Thread.sleep(CloseSleepTime) + doCloseAll(gracefully, attempts - 1) + } + } + + private def doCloseByClient(client: Client, gracefully: Boolean): Unit = { + mongoClient.find { case (key, clientSearch) => clientSearch.client == client } + .foreach { case (key, clientFound) => + if (!gracefully || clientFound.status == ConnectionStatus.Free) { + clientFound.client.close() + mongoClient.remove(key) + } + } + } + + private def doCloseByKey(clientKey: String, gracefully: Boolean): Unit = { + mongoClient.get(clientKey).foreach(clientFound => { + if (!gracefully || clientFound.status == ConnectionStatus.Free) { + clientFound.client.close() + mongoClient.remove(clientKey) + } + }) + } + + private def doGetSize() : Unit = sender ! mongoClient.size + + private def createClient(key: String, host: String): (String, MongodbConnection) = { + saveConnection(key, MongodbConnection(MongoClient(host))) + } + + private def createClient(key: String, + hostPort: List[ServerAddress], + credentials: List[MongoCredential] = List(), + optionSSLOptions: Option[MongodbSSLOptions] = None, + clientOptions: Map[String, Any] = Map()): (String, MongodbConnection) = { + + val options = { + + val builder = new MongoClientOptions.Builder() + .readPreference(extractValue[String](clientOptions, ProviderReadPreference) match { + case Some(preference) => parseReadPreference(preference) + case None => DefaultReadPreference + }) + .connectTimeout(extractValue[String](clientOptions, ConnectTimeout).map(_.toInt) + .getOrElse(DefaultConnectTimeout)) + .connectionsPerHost(extractValue[String](clientOptions, ConnectionsPerHost).map(_.toInt) + .getOrElse(DefaultConnectionsPerHost)) + .maxWaitTime(extractValue[String](clientOptions, MaxWaitTime).map(_.toInt) + .getOrElse(DefaultMaxWaitTime)) + .threadsAllowedToBlockForConnectionMultiplier(extractValue[String](clientOptions, ThreadsAllowedToBlockForConnectionMultiplier).map(_.toInt) + .getOrElse(DefaultThreadsAllowedToBlockForConnectionMultiplier)) + + if (sslBuilder(optionSSLOptions)) builder.socketFactory(SSLSocketFactory.getDefault()) + + builder.build() + } + + saveConnection(key, MongodbConnection(MongoClient(hostPort, credentials, options))) + } + + @tailrec + private def saveConnection(key: String, mongoDbConnection: MongodbConnection): (String, MongodbConnection) = { + mongoClient.put(key, mongoDbConnection) match { + case Some(_) => + val splittedKey = key.split(KeySeparator) + val index = splittedKey.headOption match { + case Some(indexNumber) => Try(indexNumber.toInt + 1).getOrElse(0) + case None => 0 + } + saveConnection(s"$index${splittedKey.drop(1).mkString(KeySeparator)}", mongoDbConnection) + case None => (key, mongoDbConnection) + } + } + + /** + * Create the connection string for the concurrent hashMap, the params make the unique key + * @param index Index for the same concurrent connections to the same database with the same options + * @param hostPort List of servers addresses + * @param credentials Credentials for connect + * @param clientOptions All options for the client connections + * @return The calculated string + */ + @tailrec + private def connectionKey(index: Int, + hostPort: List[ServerAddress], + credentials: List[MongoCredential] = List(), + clientOptions: Map[String, Any] = Map()): String = { + val key = if (clientOptions.nonEmpty) + s"$index-${clientOptions.mkString(KeySeparator)}" + else s"$index$KeySeparator${hostPort.mkString(KeySeparator)}$KeySeparator${credentials.mkString(KeySeparator)}" + + val clientFound = mongoClient.find { case (clientKey, connection) => + clientKey == key && connection.status == ConnectionStatus.Busy + } + + clientFound match { + case Some(client) => connectionKey(index + 1, hostPort, credentials, clientOptions) + case None => key + } + } + +} + +object MongodbClientActor { + + case object CheckConnections + + case class GetClient(host: String) + + case class GetClientWithUser(host: String, port: Int, user: String, database: String, password: String) + + case class GetClientWithMongoDbConfig(hostPort: List[ServerAddress], + credentials: List[MongoCredential], + optionSSLOptions: Option[MongodbSSLOptions], + clientOptions: Map[String, Any]) + + case class ClientResponse(key: String, clientConnection: Client) + + case class SetFreeConnectionsByClient(client: Client, extendedTime: Option[Long]) + + case class SetFreeConnectionByKey(clientKey: String, extendedTime: Option[Long]) + + case class CloseAll(gracefully: Boolean, attempts: Int) + + case class CloseByClient(client: Client, gracefully: Boolean) + + case class CloseByKey(clientKey: String, gracefully: Boolean) + + case object GetSize + +} diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientFactory.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientFactory.scala index 03482de..7329e86 100644 --- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientFactory.scala +++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientFactory.scala @@ -17,19 +17,19 @@ package com.stratio.datasource.mongodb.client import java.util.concurrent._ -import javax.net.ssl.SSLSocketFactory -import akka.actor.ActorSystem +import akka.actor.{ActorSystem, Props} +import akka.pattern.ask +import akka.util.Timeout import com.mongodb.ServerAddress import com.mongodb.casbah.Imports._ -import com.mongodb.casbah.{MongoClient, MongoClientOptions} -import com.stratio.datasource.mongodb.config.MongodbConfig.{ReadPreference => ProviderReadPreference, _} -import com.stratio.datasource.mongodb.config.{MongodbConfig, MongodbSSLOptions} +import com.mongodb.casbah.MongoClient +import com.stratio.datasource.mongodb.client.MongodbClientActor._ +import com.stratio.datasource.mongodb.config.MongodbSSLOptions -import scala.annotation.tailrec -import scala.collection.JavaConversions._ +import scala.concurrent.Await import scala.concurrent.duration._ -import scala.util.Try + /** * Different client configurations to Mongodb database @@ -38,65 +38,34 @@ object MongodbClientFactory { type Client = MongoClient - val KeySeparator = "-" - - val CloseSleepTime = 1000 - - val CloseAttempts = 180 - - /** - * MongoDb Client connections are saved in a Concurrent hashMap, is used to have only one concurrent connection for - * each operation, when this are finished the connection are reused - */ - var mongoClient: scala.collection.concurrent.Map[String, MongodbConnection] = - new ConcurrentHashMap[String, MongodbConnection]() + private val CloseAttempts = 120 /** * Scheduler that close connections automatically when the timeout was expired */ - val actorSystem = ActorSystem() - val scheduler = actorSystem.scheduler - val SecondsToCheckConnections = 60 - val task = new Runnable { - def run() { - synchronized { - val currentTime = System.currentTimeMillis() - mongoClient.foreach { case (key, connection) => - if((connection.status == ConnectionStatus.Free) && (connection.timeOut <= currentTime)) { - connection.client.close() - mongoClient.remove(key) - } - } - } - } - } - implicit val executor = actorSystem.dispatcher + private val actorSystem = ActorSystem() + private val scheduler = actorSystem.scheduler + private val SecondsToCheckConnections = 60 + private val mongoConnectionsActor = actorSystem.actorOf(Props(new MongodbClientActor), "mongoConnectionActor") + + private implicit val executor = actorSystem.dispatcher + private implicit val timeout: Timeout = Timeout(3.seconds) + scheduler.schedule( initialDelay = Duration(SecondsToCheckConnections, TimeUnit.SECONDS), interval = Duration(SecondsToCheckConnections, TimeUnit.SECONDS), - runnable = task) + mongoConnectionsActor, + CheckConnections) /** * Get or Create one client connection to MongoDb * @param host Ip or Dns to connect - * @return Client connection + * @return Client connection with identifier */ - def getClient(host: String): (String, Client) = { - synchronized { - val hostPort = new ServerAddress(host) - val connKey = connectionKey(0, List(hostPort)) - val (finalKey, connection) = mongoClient.get(connKey) match { - case Some(client) => - if (client.status == ConnectionStatus.Free) (connKey, client) - else createClient(connKey, host) - case None => createClient(connKey, host) - } - - mongoClient.update(finalKey, connection.copy( - timeOut = System.currentTimeMillis() + DefaultConnectionsTime, - status = ConnectionStatus.Busy)) - - (finalKey, connection.client) + def getClient(host: String): ClientResponse = { + val futureResult = mongoConnectionsActor ? GetClient(host) + Await.result(futureResult, timeout.duration) match { + case ClientResponse(key, clientConnection) => ClientResponse(key, clientConnection) } } @@ -107,26 +76,12 @@ object MongodbClientFactory { * @param user User for credentials * @param database Database for credentials * @param password Password for credentials - * @return Client connection + * @return Client connection with identifier */ - def getClient(host: String, port: Int, user: String, database: String, password: String): (String, Client) = { - synchronized { - val credentials = List(MongoCredential.createCredential(user, database, password.toCharArray)) - val hostPort = new ServerAddress(host, port) - val connKey = connectionKey(0, List(hostPort), credentials) - - val (finalKey, connection) = mongoClient.get(connKey) match { - case Some(client) => - if (client.status == ConnectionStatus.Free) (connKey, client) - else createClient(1 + connKey, List(hostPort), credentials) - case None => createClient(connKey, List(hostPort), credentials) - } - - mongoClient.update(finalKey, connection.copy( - timeOut = System.currentTimeMillis() + DefaultConnectionsTime, - status = ConnectionStatus.Busy)) - - (finalKey, connection.client) + def getClient(host: String, port: Int, user: String, database: String, password: String): ClientResponse = { + val futureResult = mongoConnectionsActor ? GetClientWithUser(host, port, user, database, password) + Await.result(futureResult, timeout.duration) match { + case ClientResponse(key, clientConnection) => ClientResponse(key, clientConnection) } } @@ -136,118 +91,26 @@ object MongodbClientFactory { * @param credentials Credentials to connect * @param optionSSLOptions SSL options for secure connections * @param clientOptions All options for the client connections - * @return Client connection + * @return Client connection with identifier */ def getClient(hostPort: List[ServerAddress], credentials: List[MongoCredential] = List(), optionSSLOptions: Option[MongodbSSLOptions] = None, - clientOptions: Map[String, Any] = Map()): (String, Client) = { - synchronized { - val connKey = connectionKey(0, hostPort, credentials, clientOptions) - val (finalKey, connection) = mongoClient.get(connKey) match { - case Some(client) => - if (client.status == ConnectionStatus.Free) (connKey,client) - else createClient(1 + connKey, hostPort, credentials, optionSSLOptions, clientOptions) - case None => createClient(connKey, hostPort, credentials, optionSSLOptions, clientOptions) - } - - mongoClient.update(finalKey, connection.copy( - timeOut = System.currentTimeMillis() + - extractValue[String](clientOptions, ConnectionsTime).map(_.toLong).getOrElse(DefaultConnectionsTime), - status = ConnectionStatus.Busy)) - - (finalKey, connection.client) + clientOptions: Map[String, Any] = Map()): ClientResponse = { + val futureResult = + mongoConnectionsActor ? GetClientWithMongoDbConfig(hostPort, credentials, optionSSLOptions, clientOptions) + Await.result(futureResult, timeout.duration) match { + case ClientResponse(key, clientConnection) => ClientResponse(key, clientConnection) } } - private def createClient(key: String , host: String): (String ,MongodbConnection) = { - saveConnection(key, MongodbConnection(MongoClient(host))) - } - - private def createClient(key: String, - hostPort: List[ServerAddress], - credentials: List[MongoCredential] = List(), - optionSSLOptions: Option[MongodbSSLOptions] = None, - clientOptions: Map[String, Any] = Map()): (String ,MongodbConnection) = { - - val options = { - - val builder = new MongoClientOptions.Builder() - .readPreference(extractValue[String](clientOptions, ProviderReadPreference) match { - case Some(preference) => parseReadPreference(preference) - case None => DefaultReadPreference - }) - .connectTimeout(extractValue[String](clientOptions, ConnectTimeout).map(_.toInt).getOrElse(DefaultConnectTimeout)) - .connectionsPerHost(extractValue[String](clientOptions, ConnectionsPerHost).map(_.toInt).getOrElse(DefaultConnectionsPerHost)) - .maxWaitTime(extractValue[String](clientOptions, MaxWaitTime).map(_.toInt).getOrElse(DefaultMaxWaitTime)) - .threadsAllowedToBlockForConnectionMultiplier(extractValue[String](clientOptions, ThreadsAllowedToBlockForConnectionMultiplier).map(_.toInt).getOrElse(DefaultThreadsAllowedToBlockForConnectionMultiplier)) - - if (sslBuilder(optionSSLOptions)) builder.socketFactory(SSLSocketFactory.getDefault()) - - builder.build() - } - - saveConnection(key, MongodbConnection(MongoClient(hostPort, credentials, options))) - } - - @tailrec - private def saveConnection(key: String, mongoDbConnection : MongodbConnection) : (String, MongodbConnection) = { - mongoClient.putIfAbsent(key, mongoDbConnection) match { - case Some(_) => - val splittedKey = key.split(KeySeparator) - val index = splittedKey.headOption match { - case Some(indexNumber) => Try(indexNumber.toInt + 1).getOrElse(0) - case None => 0 - } - saveConnection(s"$index${splittedKey.drop(1).mkString(KeySeparator)}", mongoDbConnection) - case None => (key, mongoDbConnection) - } - } - - /** - * Create the connection string for the concurrent hashMap, the params make the unique key - * @param index Index for the same concurrent connections to the same database with the same options - * @param hostPort List of servers addresses - * @param credentials Credentials for connect - * @param clientOptions All options for the client connections - * @return The calculated string - */ - @tailrec - private def connectionKey(index : Int, - hostPort: List[ServerAddress], - credentials: List[MongoCredential] = List(), - clientOptions: Map[String, Any] = Map()): String = { - val key = if (clientOptions.nonEmpty) - s"$index-${clientOptions.mkString(KeySeparator)}" - else s"$index$KeySeparator${hostPort.mkString(KeySeparator)}$KeySeparator${credentials.mkString(KeySeparator)}" - - val clientFound = mongoClient.find { case(clientKey, connection) => - clientKey == key && connection.status == ConnectionStatus.Busy - } - - clientFound match { - case Some(client) => connectionKey(index + 1, hostPort, credentials, clientOptions) - case None => key - } - } /** * Close all client connections on the concurrent map * @param gracefully Close the connections if is free */ - def closeAll(gracefully : Boolean = true, attempts : Int = CloseAttempts): Unit = { - synchronized { - mongoClient.foreach { case (key, connection) => - if (!gracefully || connection.status == ConnectionStatus.Free) { - connection.client.close() - mongoClient.remove(key) - } - } - if (mongoClient.nonEmpty && attempts > 0) { - Thread.sleep(CloseSleepTime) - closeAll(gracefully, attempts - 1) - } - } + def closeAll(gracefully: Boolean = true, attempts: Int = CloseAttempts): Unit = { + mongoConnectionsActor ! CloseAll(gracefully, attempts) } /** @@ -255,16 +118,8 @@ object MongodbClientFactory { * @param client client value for connect to MongoDb * @param gracefully Close the connection if is free */ - def close(client: Client, gracefully: Boolean = true): Unit = { - synchronized { - mongoClient.find { case (key, clientSearch) => clientSearch.client == client } - .foreach { case (key, clientFound) => - if (!gracefully || clientFound.status == ConnectionStatus.Free) { - clientFound.client.close() - mongoClient.remove(key) - } - } - } + def closeByClient(client: Client, gracefully: Boolean = true): Unit = { + mongoConnectionsActor ! CloseByClient(client, gracefully) } /** @@ -273,46 +128,36 @@ object MongodbClientFactory { * @param gracefully Close the connection if is free */ def closeByKey(clientKey: String, gracefully: Boolean = true): Unit = { - synchronized { - mongoClient.get(clientKey).foreach(clientFound => { - if (!gracefully || clientFound.status == ConnectionStatus.Free) { - clientFound.client.close() - mongoClient.remove(clientKey) - } - }) - } + mongoConnectionsActor ! CloseByKey(clientKey, gracefully) } /** * Set Free the connection that have the same client as the client param * @param client client value for connect to MongoDb */ - def setFreeConnection(client: Client, extendedTime : Option[Long] = None): Unit = { - synchronized { - mongoClient.find { case (key, clientSearch) => clientSearch.client == client } - .foreach { case (key, clientFound) => - mongoClient.update(key, clientFound.copy(status = ConnectionStatus.Free, - timeOut = System.currentTimeMillis() + extendedTime.getOrElse(DefaultConnectionsTime))) - } - } + def setFreeConnectionByClient(client: Client, extendedTime: Option[Long] = None): Unit = { + mongoConnectionsActor ! SetFreeConnectionsByClient(client, extendedTime) } /** * Set Free the connection that have the same key as the clientKey param * @param clientKey key pre calculated with the connection options */ - def setFreeConnectionByKey(clientKey: String, extendedTime : Option[Long] = None): Unit = { - synchronized { - mongoClient.get(clientKey).foreach(clientFound => { - mongoClient.update(clientKey, clientFound.copy(status = ConnectionStatus.Free, - timeOut = System.currentTimeMillis() + extendedTime.getOrElse(DefaultConnectionsTime))) - }) + def setFreeConnectionByKey(clientKey: String, extendedTime: Option[Long] = None): Unit = { + mongoConnectionsActor ! SetFreeConnectionByKey(clientKey, extendedTime) + } + + def getClientPoolSize: Int = { + val futureResult = mongoConnectionsActor ? GetSize + Await.result(futureResult, timeout.duration) match { + case size: Int => size } } - private def extractValue[T](options: Map[String, Any], key: String): Option[T] = options.get(key).map(_.asInstanceOf[T]) + def extractValue[T](options: Map[String, Any], key: String): Option[T] = + options.get(key).map(_.asInstanceOf[T]) - private def sslBuilder(optionSSLOptions: Option[MongodbSSLOptions]): Boolean = + def sslBuilder(optionSSLOptions: Option[MongodbSSLOptions]): Boolean = optionSSLOptions.exists(sslOptions => { if (sslOptions.keyStore.nonEmpty) { System.setProperty("javax.net.ssl.keyStore", sslOptions.keyStore.get) @@ -326,4 +171,5 @@ object MongodbClientFactory { } true }) + } diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/partitioner/MongodbPartitioner.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/partitioner/MongodbPartitioner.scala index cd4f35f..0e28e1e 100644 --- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/partitioner/MongodbPartitioner.scala +++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/partitioner/MongodbPartitioner.scala @@ -58,12 +58,12 @@ class MongodbPartitioner(config: Config) extends Partitioner[MongodbPartition] { private val cursorBatchSize = config.getOrElse[Int](MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize) override def computePartitions(): Array[MongodbPartition] = { - val mongoClient = MongodbClientFactory.getClient(hosts, credentials, ssloptions, clientOptions)._2 + val mongoClient = MongodbClientFactory.getClient(hosts, credentials, ssloptions, clientOptions) - val result = if (isShardedCollection(mongoClient)) - computeShardedChunkPartitions(mongoClient) + val result = if (isShardedCollection(mongoClient.clientConnection)) + computeShardedChunkPartitions(mongoClient.clientConnection) else - computeNotShardedPartitions(mongoClient) + computeNotShardedPartitions(mongoClient.clientConnection) result } @@ -158,14 +158,14 @@ class MongodbPartitioner(config: Config) extends Partitioner[MongodbPartition] { .find(MongoDBObject("_id" -> stats.getString("primary"))).batchSize(cursorBatchSize) val shard = shards.next() val shardHost: String = shard.as[String]("host").replace(shard.get("_id") + "/", "") - val (shardClientKey, shardClient) = MongodbClientFactory.getClient(shardHost) - val data = shardClient.getDB("admin").command(cmd) + val shardClient = MongodbClientFactory.getClient(shardHost) + val data = shardClient.clientConnection.getDB("admin").command(cmd) val splitKeys = data.as[List[DBObject]]("splitKeys").map(Option(_)) val ranges = (None +: splitKeys) zip (splitKeys :+ None) shards.close() - MongodbClientFactory.setFreeConnectionByKey(shardClientKey, connectionsTime) + MongodbClientFactory.setFreeConnectionByKey(shardClient.key, connectionsTime) ranges.toSeq }.getOrElse(Seq((None, None))) diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala index 7aa1c98..4bc05db 100644 --- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala +++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala @@ -57,8 +57,8 @@ class MongodbReader(config: Config, mongoClient.fold(ifEmpty = ()) { client => mongoClientKey.fold({ - MongodbClientFactory.setFreeConnection(client, connectionsTime) - MongodbClientFactory.close(client) + MongodbClientFactory.setFreeConnectionByClient(client, connectionsTime) + MongodbClientFactory.closeByClient(client) }) {key => MongodbClientFactory.setFreeConnectionByKey(key, connectionsTime) MongodbClientFactory.closeByKey(key) @@ -91,9 +91,9 @@ class MongodbReader(config: Config, val sslOptions = config.get[MongodbSSLOptions](MongodbConfig.SSLOptions) val clientOptions = config.properties.filterKeys(_.contains(MongodbConfig.ListMongoClientOptions)) - val (clientKey, client) = MongodbClientFactory.getClient(hosts, credentials, sslOptions, clientOptions) - mongoClient = Option(client) - mongoClientKey = Option(clientKey) + val mongoClientResponse = MongodbClientFactory.getClient(hosts, credentials, sslOptions, clientOptions) + mongoClient = Option(mongoClientResponse.clientConnection) + mongoClientKey = Option(mongoClientResponse.key) val emptyFilter = MongoDBObject(List()) val filter = Try(queryPartition(filters)).getOrElse(emptyFilter) diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbWriter.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbWriter.scala index b054c97..057a204 100644 --- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbWriter.scala +++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbWriter.scala @@ -56,14 +56,14 @@ abstract class MongodbWriter(config: Config) extends Serializable { private val connectionsTime = config.get[String](MongodbConfig.ConnectionsTime).map(_.toLong) - protected val (clientKey, mongoClient) = + protected val mongoClient = MongodbClientFactory.getClient(hosts, credentials, sslOptions, clientOptions) /** * A MongoDB collection created from the specified database and collection. */ protected val dbCollection: MongoCollection = - mongoClient(config(MongodbConfig.Database))(config(MongodbConfig.Collection)) + mongoClient.clientConnection(config(MongodbConfig.Database))(config(MongodbConfig.Collection)) /** * Abstract method that checks if a primary key exists in provided configuration @@ -104,7 +104,7 @@ abstract class MongodbWriter(config: Config) extends Serializable { * Free current MongoDB client. */ def freeConnection(): Unit = { - MongodbClientFactory.setFreeConnection(mongoClient, connectionsTime) + MongodbClientFactory.setFreeConnectionByKey(mongoClient.key, connectionsTime) } } \ No newline at end of file diff --git a/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/client/MongodbClientFactoryTest.scala b/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/client/MongodbClientFactoryTest.scala index b748c9c..5eeabb1 100644 --- a/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/client/MongodbClientFactoryTest.scala +++ b/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/client/MongodbClientFactoryTest.scala @@ -21,16 +21,20 @@ import com.stratio.datasource.MongodbTestConstants import com.stratio.datasource.mongodb.config.MongodbSSLOptions import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner -import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfter, FlatSpec, Matchers} @RunWith(classOf[JUnitRunner]) -class MongodbClientFactoryTest extends FlatSpec with Matchers with MongodbTestConstants with BeforeAndAfter { +class MongodbClientFactoryTest extends FlatSpec +with Matchers +with MongodbTestConstants +with BeforeAndAfter +with BeforeAndAfterAll { type Client = MongoClient - val hostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val hostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection - val hostPortCredentialsClient = MongodbClientFactory.getClient("127.0.0.1", 27017, "user", "database", "password")._2 + val hostPortCredentialsClient = MongodbClientFactory.getClient("127.0.0.1", 27017, "user", "database", "password").clientConnection val fullClient = MongodbClientFactory.getClient( List(new ServerAddress("127.0.0.1:27017")), @@ -44,7 +48,7 @@ class MongodbClientFactoryTest extends FlatSpec with Matchers with MongodbTestCo "connectionsPerHost" -> "20", "threadsAllowedToBlockForConnectionMultiplier" -> "5" ) - )._2 + ).clientConnection val gracefully = true @@ -58,82 +62,90 @@ class MongodbClientFactoryTest extends FlatSpec with Matchers with MongodbTestCo hostClient shouldBe a [Client] hostPortCredentialsClient shouldBe a [Client] fullClient shouldBe a [Client] + + MongodbClientFactory.closeAll(notGracefully) } it should "Valid clients size when getting the same client " in { - val sameHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val sameHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection - MongodbClientFactory.mongoClient.size should be (1) + MongodbClientFactory.getClientPoolSize should be (1) - val otherHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val otherHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection - MongodbClientFactory.mongoClient.size should be (2) + MongodbClientFactory.getClientPoolSize should be (2) + + MongodbClientFactory.closeAll(notGracefully) } it should "Valid clients size when getting the same client and set free " in { - val sameHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val sameHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection - MongodbClientFactory.mongoClient.size should be (1) + MongodbClientFactory.getClientPoolSize should be (1) - MongodbClientFactory.setFreeConnection(sameHostClient) + MongodbClientFactory.setFreeConnectionByClient(sameHostClient) - val otherHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val otherHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection - MongodbClientFactory.mongoClient.size should be (1) + MongodbClientFactory.getClientPoolSize should be (1) + + MongodbClientFactory.closeAll(notGracefully) } it should "Valid clients size when closing one client gracefully " in { - val sameHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val sameHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection - MongodbClientFactory.mongoClient.size should be (1) + MongodbClientFactory.getClientPoolSize should be (1) - MongodbClientFactory.close(sameHostClient) + MongodbClientFactory.closeByClient(sameHostClient) - MongodbClientFactory.mongoClient.size should be (1) + MongodbClientFactory.getClientPoolSize should be (1) + + MongodbClientFactory.closeAll(notGracefully) } it should "Valid clients size when closing one client not gracefully " in { - val sameHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val sameHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection - MongodbClientFactory.mongoClient.size should be (1) + MongodbClientFactory.getClientPoolSize should be (1) - MongodbClientFactory.close(sameHostClient, notGracefully) + MongodbClientFactory.closeByClient(sameHostClient, notGracefully) - MongodbClientFactory.mongoClient.size should be (0) + MongodbClientFactory.getClientPoolSize should be (0) + + MongodbClientFactory.closeAll(notGracefully) } it should "Valid clients size when closing all clients gracefully " in { - val sameHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 - val otherHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val sameHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection + val otherHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection - MongodbClientFactory.mongoClient.size should be (2) + MongodbClientFactory.getClientPoolSize should be (2) MongodbClientFactory.closeAll(gracefully, 1) - MongodbClientFactory.mongoClient.size should be (2) + MongodbClientFactory.getClientPoolSize should be (2) - MongodbClientFactory.setFreeConnection(sameHostClient) + MongodbClientFactory.setFreeConnectionByClient(sameHostClient) MongodbClientFactory.closeAll(gracefully, 1) - MongodbClientFactory.mongoClient.size should be (1) + MongodbClientFactory.getClientPoolSize should be (1) + + MongodbClientFactory.closeAll(notGracefully) } it should "Valid clients size when closing all clients not gracefully " in { - val sameHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 - val otherHostClient = MongodbClientFactory.getClient("127.0.0.1")._2 + val sameHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection + val otherHostClient = MongodbClientFactory.getClient("127.0.0.1").clientConnection val gracefully = false - MongodbClientFactory.mongoClient.size should be (2) + MongodbClientFactory.getClientPoolSize should be (2) MongodbClientFactory.closeAll(notGracefully) - MongodbClientFactory.mongoClient.size should be (0) - } - + MongodbClientFactory.getClientPoolSize should be (0) - after { MongodbClientFactory.closeAll(notGracefully) } - }