diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 08377d6fb7a..e44b8ba856a 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -605,14 +605,15 @@ private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
     StatusCode respStatus = Utils.toStatusCode(response.getStatus());
     if (StatusCode.SUCCESS.equals(respStatus)) {
       ConcurrentHashMap<Integer, PartitionLocation> result = JavaUtils.newConcurrentHashMap();
-      for (int i = 0; i < response.getPartitionLocationsList().size(); i++) {
-        PartitionLocation partitionLoc =
-            PbSerDeUtils.fromPbPartitionLocation(response.getPartitionLocationsList().get(i));
-        pushExcludedWorkers.remove(partitionLoc.hostAndPushPort());
-        if (partitionLoc.hasPeer()) {
-          pushExcludedWorkers.remove(partitionLoc.getPeer().hostAndPushPort());
+      Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations =
+          PbSerDeUtils.fromPbPackedPartitionLocationsPair(
+              response.getPackedPartitionLocationsPair());
+      for (PartitionLocation location : locations._1) {
+        pushExcludedWorkers.remove(location.hostAndPushPort());
+        if (location.hasPeer()) {
+          pushExcludedWorkers.remove(location.getPeer().hostAndPushPort());
         }
-        result.put(partitionLoc.getId(), partitionLoc);
+        result.put(location.getId(), location);
       }
       return result;
     } else if (StatusCode.SLOT_NOT_AVAILABLE.equals(respStatus)) {
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index a8a9bca86b5..e3745b43134 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -585,16 +585,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       case _ => Option.empty
     }
 
+    val locations = PbSerDeUtils.fromPbPackedPartitionLocationsPair(
+      response.getPackedPartitionLocationsPair)._1.asScala
+
     registeringShuffleRequest.asScala
       .get(shuffleId)
       .foreach(_.asScala.foreach(context => {
         partitionType match {
           case PartitionType.MAP =>
             if (response.getStatus == StatusCode.SUCCESS.getValue) {
-              val partitionLocations =
-                response.getPartitionLocationsList.asScala.filter(
-                  _.getId == context.partitionId).map(r =>
-                  PbSerDeUtils.fromPbPartitionLocation(r)).toArray
+              val partitionLocations = locations.filter(_.getId == context.partitionId).toArray
               processMapTaskReply(
                 shuffleId,
                 context.context,
@@ -1540,7 +1540,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
         userIdentifier,
         slotsAssignMaxWorkers,
         availableStorageTypes,
-        excludedWorkerSet)
+        excludedWorkerSet,
+        true)
       val res = requestMasterRequestSlots(req)
       if (res.status != StatusCode.SUCCESS) {
         requestMasterRequestSlots(req)
       }
@@ -1556,7 +1557,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
     } catch {
       case e: Exception =>
        logError(s"AskSync RegisterShuffle for $shuffleKey failed.", e)
-        RequestSlotsResponse(StatusCode.REQUEST_FAILED, new WorkerResource())
+        RequestSlotsResponse(StatusCode.REQUEST_FAILED, new WorkerResource(), message.packed)
     }
   }
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index e86bc231738..23d6a7b8df7 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -275,7 +275,8 @@ class ReducePartitionCommitHandler(
         GetReducerFileGroupResponse(
           StatusCode.SHUFFLE_DATA_LOST,
           JavaUtils.newConcurrentHashMap(),
-          Array.empty))
+          Array.empty,
+          new util.HashSet[Integer]()))
     } else {
       // LocalNettyRpcCallContext is for the UTs
       if (context.isInstanceOf[LocalNettyRpcCallContext]) {
diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto
index 7f190a3bbba..7c457bdf78d 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -173,6 +173,7 @@ message PbWorkerInfo {
 
 message PbFileGroup {
   repeated PbPartitionLocation locations = 1;
+  PbPackedPartitionLocationsPair partitionLocationsPair = 2;
 }
 
 message PbRegisterWorker {
@@ -257,6 +258,7 @@ message PbRegisterMapPartitionTask {
 message PbRegisterShuffleResponse {
   int32 status = 1;
   repeated PbPartitionLocation partitionLocations = 2;
+  PbPackedPartitionLocationsPair packedPartitionLocationsPair = 3;
 }
 
 message PbRequestSlots {
@@ -272,6 +274,7 @@ message PbRequestSlots {
   int32 maxWorkers = 10;
   int32 availableStorageTypes = 11;
   repeated PbWorkerInfo excludedWorkerSet = 12;
+  bool packed = 13;
 }
 
 message PbSlotInfo {
@@ -296,6 +299,7 @@ message PbReleaseSlotsResponse {
 message PbRequestSlotsResponse {
   int32 status = 1;
   map<string, PbWorkerResource> workerResource = 2;
+  map<string, PbPackedWorkerResource> packedWorkerResource = 3;
 }
 
 message PbRevivePartitionInfo {
@@ -451,6 +455,7 @@ message PbReserveSlots {
   int64 pushDataTimeout = 10;
   bool partitionSplitEnabled = 11;
   int32 availableStorageTypes = 12;
+  PbPackedPartitionLocationsPair partitionLocationsPair = 13;
 }
 
 message PbReserveSlotsResponse {
@@ -758,3 +763,29 @@ message PbApplicationMeta {
 message PbApplicationMetaRequest {
   string appId = 1;
 }
+
+message PbPackedPartitionLocations {
+  repeated int32 ids = 1;
+  repeated int32 epoches = 2;
+  repeated int32 workerIds = 3;
+  repeated string workerIdsSet = 4;
+  repeated bytes mapIdBitMap = 5;
+  repeated int32 types = 6;
+  repeated int32 mountPoints = 7;
+  repeated string mountPointsSet = 8;
+  repeated bool finalResult = 9;
+  repeated string filePaths = 10;
+  repeated int32 availableStorageTypes = 11;
+  repeated int32 modes = 12;
+}
+
+message PbPackedPartitionLocationsPair {
+  PbPackedPartitionLocations locations = 1;
+  repeated int32 peerIndexes = 2;
+  int32 inputLocationSize = 3;
+}
+
+message PbPackedWorkerResource {
+  PbPackedPartitionLocationsPair locationPairs = 1;
+  string networkLocation = 2;
+}
\ No newline at end of file
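// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the three new messages above use
// a columnar, dictionary-encoded layout. Every scalar field of the i-th
// location lands at index i of a parallel repeated array, while worker IDs and
// mount points are deduplicated into workerIdsSet/mountPointsSet and referenced
// by int32 index. The builder calls below follow from the protobuf code
// generated for these messages; the remaining parallel arrays (types, modes,
// bitmaps, file paths, ...) are omitted for brevity, so this value is not
// decodable as-is.
import java.util.Arrays

val packed = PbPackedPartitionLocations.newBuilder()
  .addAllWorkerIdsSet(Arrays.asList("host1:1111:2222:3333:4444")) // dictionary entry 0
  .addAllMountPointsSet(Arrays.asList("/mnt/disk1/celeborn/"))    // dictionary entry 0
  .addIds(0).addEpoches(0).addWorkerIds(0).addMountPoints(0)      // location #0: indexes only
  .addIds(1).addEpoches(0).addWorkerIds(0).addMountPoints(0)      // location #1 reuses entry 0
  .build()
// PbPackedPartitionLocationsPair then flattens primaries and replicas into one
// such array and records peers as peerIndexes (Integer.MAX_VALUE = "no peer"),
// so a peer costs one varint instead of a second embedded PbPartitionLocation.
// ---------------------------------------------------------------------------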
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 9bd9469b074..9491ba29d3b 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -18,7 +18,7 @@ package org.apache.celeborn.common.protocol.message
 
 import java.util
-import java.util.UUID
+import java.util.{Collections, UUID}
 
 import scala.collection.JavaConverters._
 
@@ -160,12 +160,13 @@ object ControlMessages extends Logging {
   object RegisterShuffleResponse {
     def apply(
         status: StatusCode,
-        partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse =
-      PbRegisterShuffleResponse.newBuilder()
+        partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse = {
+      val builder = PbRegisterShuffleResponse.newBuilder()
         .setStatus(status.getValue)
-        .addAllPartitionLocations(
-          partitionLocations.map(PbSerDeUtils.toPbPartitionLocation).toSeq.asJava)
-        .build()
+      builder.setPackedPartitionLocationsPair(
+        PbSerDeUtils.toPbPackedPartitionLocationsPair(partitionLocations.toList))
+      builder.build()
+    }
   }
 
   case class RequestSlots(
@@ -179,6 +180,7 @@ object ControlMessages extends Logging {
       maxWorkers: Int,
       availableStorageTypes: Int,
       excludedWorkerSet: Set[WorkerInfo] = Set.empty,
+      packed: Boolean = false,
       override var requestId: String = ZERO_UUID)
     extends MasterRequestMessage
 
@@ -199,7 +201,8 @@ object ControlMessages extends Logging {
 
   case class RequestSlotsResponse(
       status: StatusCode,
-      workerResource: WorkerResource)
+      workerResource: WorkerResource,
+      packed: Boolean = false)
     extends MasterMessage
 
   object Revive {
@@ -279,7 +282,7 @@
       status: StatusCode,
       fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
       attempts: Array[Int],
-      partitionIds: util.Set[Integer] = new util.HashSet[Integer]())
+      partitionIds: util.Set[Integer] = Collections.emptySet[Integer]())
     extends MasterMessage
 
   object WorkerExclude {
@@ -583,6 +586,7 @@ object ControlMessages extends Logging {
           maxWorkers,
           availableStorageTypes,
           excludedWorkerSet,
+          packed,
           requestId) =>
       val payload = PbRequestSlots.newBuilder()
         .setApplicationId(applicationId)
@@ -596,6 +600,7 @@ object ControlMessages extends Logging {
         .setAvailableStorageTypes(availableStorageTypes)
         .setUserIdentifier(PbSerDeUtils.toPbUserIdentifier(userIdentifier))
         .addAllExcludedWorkerSet(excludedWorkerSet.map(PbSerDeUtils.toPbWorkerInfo(_, true)).asJava)
+        .setPacked(packed)
         .build().toByteArray
       new TransportMessage(MessageType.REQUEST_SLOTS, payload)
 
@@ -616,12 +621,16 @@ object ControlMessages extends Logging {
         .setStatus(status.getValue).build().toByteArray
       new TransportMessage(MessageType.RELEASE_SLOTS_RESPONSE, payload)
 
-    case RequestSlotsResponse(status, workerResource) =>
+    case RequestSlotsResponse(status, workerResource, packed) =>
      val builder = PbRequestSlotsResponse.newBuilder()
         .setStatus(status.getValue)
       if (!workerResource.isEmpty) {
-        builder.putAllWorkerResource(
-          PbSerDeUtils.toPbWorkerResource(workerResource))
+        if (packed) {
+          builder.putAllPackedWorkerResource(PbSerDeUtils.toPbPackedWorkerResource(workerResource))
+        } else {
+          builder.putAllWorkerResource(
+            PbSerDeUtils.toPbWorkerResource(workerResource))
+        }
       }
       val payload = builder.build().toByteArray
       new TransportMessage(MessageType.REQUEST_SLOTS_RESPONSE, payload)
@@ -660,10 +669,10 @@ object ControlMessages extends Logging {
         .setStatus(status.getValue)
       builder.putAllFileGroups(
         fileGroup.asScala.map { case (partitionId, fileGroup) =>
-          (
-            partitionId,
-            PbFileGroup.newBuilder().addAllLocations(fileGroup.asScala.map(PbSerDeUtils
-              .toPbPartitionLocation).toList.asJava).build())
+          val pbFileGroupBuilder = PbFileGroup.newBuilder()
+          pbFileGroupBuilder.setPartitionLocationsPair(
+            PbSerDeUtils.toPbPackedPartitionLocationsPair(fileGroup.asScala.toList))
+          (partitionId, pbFileGroupBuilder.build())
         }.asJava)
       builder.addAllAttempts(attempts.map(Integer.valueOf).toIterable.asJava)
       builder.addAllPartitionIds(partitionIds)
@@ -795,10 +804,8 @@ object ControlMessages extends Logging {
       val payload = PbReserveSlots.newBuilder()
         .setApplicationId(applicationId)
         .setShuffleId(shuffleId)
-        .addAllPrimaryLocations(primaryLocations.asScala
-          .map(PbSerDeUtils.toPbPartitionLocation).toList.asJava)
-        .addAllReplicaLocations(replicaLocations.asScala
-          .map(PbSerDeUtils.toPbPartitionLocation).toList.asJava)
+        .setPartitionLocationsPair(PbSerDeUtils.toPbPackedPartitionLocationsPair(
+          primaryLocations.asScala.toList ++ replicaLocations.asScala.toList))
         .setSplitThreshold(splitThreshold)
         .setSplitMode(splitMode.getValue)
         .setPartitionType(partType.getValue)
@@ -1001,14 +1008,22 @@ object ControlMessages extends Logging {
         pbRequestSlots.getMaxWorkers,
         pbRequestSlots.getAvailableStorageTypes,
         excludedWorkerInfoSet,
+        pbRequestSlots.getPacked,
         pbRequestSlots.getRequestId)
 
     case REQUEST_SLOTS_RESPONSE_VALUE =>
       val pbRequestSlotsResponse = PbRequestSlotsResponse.parseFrom(message.getPayload)
+      val workerResource =
+        if (pbRequestSlotsResponse.getWorkerResourceMap.isEmpty) {
+          PbSerDeUtils.fromPbPackedWorkerResource(
+            pbRequestSlotsResponse.getPackedWorkerResourceMap)
+        } else {
+          PbSerDeUtils.fromPbWorkerResource(
+            pbRequestSlotsResponse.getWorkerResourceMap)
+        }
       RequestSlotsResponse(
         Utils.toStatusCode(pbRequestSlotsResponse.getStatus),
-        PbSerDeUtils.fromPbWorkerResource(
-          pbRequestSlotsResponse.getWorkerResourceMap))
+        workerResource)
 
     case CHANGE_LOCATION_VALUE =>
       PbRevive.parseFrom(message.getPayload)
@@ -1041,8 +1056,8 @@
           case (partitionId, fileGroup) =>
             (
               partitionId,
-              fileGroup.getLocationsList.asScala.map(
-                PbSerDeUtils.fromPbPartitionLocation).toSet.asJava)
+              PbSerDeUtils.fromPbPackedPartitionLocationsPair(
+                fileGroup.getPartitionLocationsPair)._1.asScala.toSet.asJava)
         }.asJava
       val attempts = pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray
 
@@ -1136,13 +1151,22 @@ object ControlMessages extends Logging {
     case RESERVE_SLOTS_VALUE =>
       val pbReserveSlots = PbReserveSlots.parseFrom(message.getPayload)
       val userIdentifier = PbSerDeUtils.fromPbUserIdentifier(pbReserveSlots.getUserIdentifier)
+      val (primaryLocations, replicateLocations) =
+        if (pbReserveSlots.getPrimaryLocationsList.isEmpty) {
+          PbSerDeUtils.fromPbPackedPartitionLocationsPair(
+            pbReserveSlots.getPartitionLocationsPair)
+        } else {
+          (
+            new util.ArrayList[PartitionLocation](pbReserveSlots.getPrimaryLocationsList.asScala
+              .map(PbSerDeUtils.fromPbPartitionLocation).toList.asJava),
+            new util.ArrayList[PartitionLocation](pbReserveSlots.getReplicaLocationsList.asScala
+              .map(PbSerDeUtils.fromPbPartitionLocation).toList.asJava))
+        }
       ReserveSlots(
         pbReserveSlots.getApplicationId,
         pbReserveSlots.getShuffleId,
-        new util.ArrayList[PartitionLocation](pbReserveSlots.getPrimaryLocationsList.asScala
-          .map(PbSerDeUtils.fromPbPartitionLocation).toList.asJava),
-        new util.ArrayList[PartitionLocation](pbReserveSlots.getReplicaLocationsList.asScala
-          .map(PbSerDeUtils.fromPbPartitionLocation).toList.asJava),
+        primaryLocations,
+        replicateLocations,
         pbReserveSlots.getSplitThreshold,
         Utils.toShuffleSplitMode(pbReserveSlots.getSplitMode),
         Utils.toPartitionType(pbReserveSlots.getPartitionType),
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
index 1ca5759d2f4..e8298843342 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConverters._
 
 import com.google.protobuf.InvalidProtocolBufferException
 
 import org.apache.celeborn.common.identity.UserIdentifier
-import org.apache.celeborn.common.meta.{AppDiskUsage, AppDiskUsageSnapShot, ApplicationMeta, DiskFileInfo, DiskInfo, FileInfo, MapFileMeta, ReduceFileMeta, WorkerEventInfo, WorkerInfo, WorkerStatus}
+import org.apache.celeborn.common.meta.{AppDiskUsage, AppDiskUsageSnapShot, ApplicationMeta, DiskFileInfo, DiskInfo, MapFileMeta, ReduceFileMeta, WorkerEventInfo, WorkerInfo, WorkerStatus}
 import org.apache.celeborn.common.protocol._
 import org.apache.celeborn.common.protocol.PartitionLocation.Mode
 import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
@@ -490,4 +490,170 @@
       pbWorkerEventInfo.getWorkerEventType.getNumber,
       pbWorkerEventInfo.getEventStartTime())
   }
+
+  private def toPackedPartitionLocation(
+      pbPackedLocationsBuilder: PbPackedPartitionLocations.Builder,
+      workerIdIndex: Map[String, Int],
+      mountPointsIndex: Map[String, Int],
+      location: PartitionLocation): PbPackedPartitionLocations.Builder = {
+    pbPackedLocationsBuilder.addIds(location.getId)
+    pbPackedLocationsBuilder.addEpoches(location.getEpoch)
+    pbPackedLocationsBuilder.addWorkerIds(workerIdIndex(location.getWorker.toUniqueId()))
+    pbPackedLocationsBuilder.addMapIdBitMap(
+      Utils.roaringBitmapToByteString(location.getMapIdBitMap))
+    pbPackedLocationsBuilder.addTypes(location.getStorageInfo.getType.getValue)
+    pbPackedLocationsBuilder.addMountPoints(
+      mountPointsIndex(location.getStorageInfo.getMountPoint))
+    pbPackedLocationsBuilder.addFinalResult(location.getStorageInfo.isFinalResult)
+    if (location.getStorageInfo.getFilePath != null && location.getStorageInfo.getFilePath.nonEmpty) {
+      pbPackedLocationsBuilder.addFilePaths(location.getStorageInfo.getFilePath
+        .substring(location.getStorageInfo.getMountPoint.length))
+    } else {
+      pbPackedLocationsBuilder.addFilePaths("")
+    }
+    pbPackedLocationsBuilder.addAvailableStorageTypes(location.getStorageInfo.availableStorageTypes)
+    pbPackedLocationsBuilder.addModes(location.getMode.mode())
+  }
+
+  def toPbPackedPartitionLocationsPair(inputLocations: List[PartitionLocation])
+      : PbPackedPartitionLocationsPair = {
+    val packedLocationPairsBuilder = PbPackedPartitionLocationsPair.newBuilder()
+    val packedLocationsBuilder = PbPackedPartitionLocations.newBuilder()
+
+    val implicateLocations = inputLocations.map(_.getPeer).filterNot(_ == null)
+
+    val allLocations = (inputLocations ++ implicateLocations)
+    val workerIdList = new util.ArrayList[String](
+      allLocations.map(_.getWorker.toUniqueId()).toSet.asJava)
+    val workerIdIndex = workerIdList.asScala.zipWithIndex.toMap
+    val mountPointsList = new util.ArrayList[String](
+      allLocations.map(
+        _.getStorageInfo.getMountPoint).toSet.asJava)
+    val mountPointsIndex = mountPointsList.asScala.zipWithIndex.toMap
+
+    packedLocationsBuilder.addAllWorkerIdsSet(workerIdList)
+    packedLocationsBuilder.addAllMountPointsSet(mountPointsList)
+
+    val locationIndexes = allLocations.zipWithIndex.toMap
+
+    for (location <- allLocations) {
+      toPackedPartitionLocation(
+        packedLocationsBuilder,
+        workerIdIndex,
+        mountPointsIndex,
+        location)
+      if (location.getPeer != null) {
+        packedLocationPairsBuilder.addPeerIndexes(
+          locationIndexes(location.getPeer))
+      } else {
+        packedLocationPairsBuilder.addPeerIndexes(Integer.MAX_VALUE)
+      }
+    }
+
+    packedLocationPairsBuilder.setInputLocationSize(inputLocations.size)
+    packedLocationPairsBuilder.setLocations(packedLocationsBuilder.build()).build()
+  }
+
+  def fromPbPackedPartitionLocationsPair(pbPartitionLocationsPair: PbPackedPartitionLocationsPair)
+      : (util.List[PartitionLocation], util.List[PartitionLocation]) = {
+    val primaryLocations = new util.ArrayList[PartitionLocation]()
+    val replicateLocations = new util.ArrayList[PartitionLocation]()
+    val pbPackedPartitionLocations = pbPartitionLocationsPair.getLocations
+    val inputLocationSize = pbPartitionLocationsPair.getInputLocationSize
+    val idList = pbPackedPartitionLocations.getIdsList
+    val locationCount = idList.size()
+    var index = 0
+
+    val locations = new util.ArrayList[PartitionLocation]()
+    while (index < locationCount) {
+      val loc =
+        fromPackedPartitionLocations(pbPackedPartitionLocations, index)
+      if (index < inputLocationSize) {
+        if (loc.getMode == Mode.PRIMARY) {
+          primaryLocations.add(loc)
+        } else {
+          replicateLocations.add(loc)
+        }
+      }
+      locations.add(loc)
+      index = index + 1
+    }
+
+    index = 0
+    while (index < locationCount) {
+      val replicateIndex = pbPartitionLocationsPair.getPeerIndexes(index)
+      if (replicateIndex != Integer.MAX_VALUE) {
+        locations.get(index).setPeer(locations.get(replicateIndex))
+      }
+      index = index + 1
+    }
+
+    (primaryLocations, replicateLocations)
+  }
+
+  private def fromPackedPartitionLocations(
+      pbPackedPartitionLocations: PbPackedPartitionLocations,
+      index: Int): PartitionLocation = {
+    val workerIdParts = pbPackedPartitionLocations.getWorkerIdsSet(
+      pbPackedPartitionLocations.getWorkerIds(index)).split(":").map(_.trim)
+    var filePath = pbPackedPartitionLocations.getFilePaths(index)
+    if (filePath != "") {
+      filePath = pbPackedPartitionLocations.getMountPointsSet(
+        pbPackedPartitionLocations.getMountPoints(index)) +
+        pbPackedPartitionLocations.getFilePaths(index)
+    }
+
+    val mode =
+      if (pbPackedPartitionLocations.getModes(index) == Mode.PRIMARY.mode()) {
+        Mode.PRIMARY
+      } else {
+        Mode.REPLICA
+      }
+
+    new PartitionLocation(
+      pbPackedPartitionLocations.getIds(index),
+      pbPackedPartitionLocations.getEpoches(index),
+      workerIdParts(0),
+      workerIdParts(1).toInt,
+      workerIdParts(2).toInt,
+      workerIdParts(3).toInt,
+      workerIdParts(4).toInt,
+      mode,
+      null,
+      new StorageInfo(
+        StorageInfo.typesMap.get(pbPackedPartitionLocations.getTypes(index)),
+        pbPackedPartitionLocations.getMountPointsSet(
+          pbPackedPartitionLocations.getMountPoints(index)),
+        pbPackedPartitionLocations.getFinalResult(index),
+        filePath,
+        pbPackedPartitionLocations.getAvailableStorageTypes(index)),
+      Utils.byteStringToRoaringBitmap(pbPackedPartitionLocations.getMapIdBitMap(index)))
+  }
+
+  def fromPbPackedWorkerResource(pbWorkerResource: util.Map[String, PbPackedWorkerResource])
+      : WorkerResource = {
+    val slots = new WorkerResource()
+    pbWorkerResource.asScala.foreach { case (uniqueId, pbPackedWorkerResource) =>
+      val networkLocation = pbPackedWorkerResource.getNetworkLocation
+      val workerInfo = WorkerInfo.fromUniqueId(uniqueId)
+      workerInfo.networkLocation = networkLocation
+      val (primaryLocations, replicateLocations) =
+        fromPbPackedPartitionLocationsPair(pbPackedWorkerResource.getLocationPairs)
+      slots.put(workerInfo, (primaryLocations, replicateLocations))
+    }
+    slots
+  }
+
+  def toPbPackedWorkerResource(workerResource: WorkerResource)
+      : util.Map[String, PbPackedWorkerResource] = {
+    workerResource.asScala.map { case (workerInfo, (primaryLocations, replicaLocations)) =>
+      val pbWorkerResource = PbPackedWorkerResource.newBuilder()
+        .setLocationPairs(toPbPackedPartitionLocationsPair(
+          primaryLocations.asScala.toList ++ replicaLocations.asScala.toList))
+        .setNetworkLocation(workerInfo.networkLocation)
+        .build()
+      workerInfo.toUniqueId() -> pbWorkerResource
+    }.asJava
+  }
 }
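// ---------------------------------------------------------------------------
// Round-trip sketch for the helpers above (assumes `primary` and `replica` are
// PartitionLocations wired to each other via setPeer, as in the tests below).
// Packing pulls peers in implicitly; unpacking re-links them via peerIndexes
// and only surfaces the first inputLocationSize entries in the returned lists.
val pair = PbSerDeUtils.toPbPackedPartitionLocationsPair(List(primary))
// One input location yields two packed entries (primary + implicit peer).
val (primaries, replicas) = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pair)
assert(primaries.get(0) == primary)          // input-order primaries
assert(primaries.get(0).getPeer == replica)  // peer restored from its index
assert(replicas.isEmpty)                     // the peer was not an input location
// ---------------------------------------------------------------------------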
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
index 3f4c7b2f414..5166c796533 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
@@ -21,14 +21,17 @@ import java.io.File
 import java.util
 
 import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils
 
 import org.apache.celeborn.CelebornFunSuite
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.meta.{ApplicationMeta, DeviceInfo, DiskFileInfo, DiskInfo, FileInfo, ReduceFileMeta, WorkerEventInfo, WorkerInfo, WorkerStatus}
-import org.apache.celeborn.common.protocol.{PartitionLocation, StorageInfo}
-import org.apache.celeborn.common.protocol.PartitionLocation
+import org.apache.celeborn.common.protocol.{PartitionLocation, PbPackedWorkerResource, PbWorkerResource, StorageInfo}
 import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
 import org.apache.celeborn.common.quota.ResourceConsumption
+import org.apache.celeborn.common.util.PbSerDeUtils.{fromPbPackedPartitionLocationsPair, toPbPackedPartitionLocationsPair}
 
 class PbSerDeUtilsTest extends CelebornFunSuite {
 
@@ -290,4 +293,122 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
 
     assert(restoredApplicationMeta.equals(applicationMeta))
   }
+
+  test("testPackedPartitionLocationPairCase1") {
+    partitionLocation3.setPeer(partitionLocation2)
+    val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
+      List(partitionLocation3, partitionLocation2))
+    val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb)
+
+    val loc1 = rePb._1.get(0)
+    val loc2 = rePb._2.get(0)
+
+    assert(partitionLocation3 == loc1)
+    assert(partitionLocation2 == loc2)
+  }
+
+  test("testPackedPartitionLocationPairCase2") {
+    val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
+      List(partitionLocation3))
+    val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb)
+
+    val loc1 = rePb._1.get(0)
+
+    assert(partitionLocation3 == loc1)
+  }
+
+  private def testSerializationPerformance(scale: Int): Unit = {
+    val mountPoints = List(
+      "/mnt/disk1/celeborn/",
+      "/mnt/disk2/celeborn/",
+      "/mnt/disk3/celeborn/",
+      "/mnt/disk4/celeborn/",
+      "/mnt/disk5/celeborn/",
+      "/mnt/disk6/celeborn/",
+      "/mnt/disk7/celeborn/",
+      "/mnt/disk8/celeborn/")
+    val hosts = (0 to 50).map(f =>
+      (
+        s"host${f}",
+        Random.nextInt(65535),
+        Random.nextInt(65535),
+        Random.nextInt(65535),
+        Random.nextInt(65535))).toList
+    val (primaryLocations, replicaLocations) = (0 to scale).map(i => {
+      val host = hosts(Random.nextInt(50))
+      val mountPoint = mountPoints(Random.nextInt(8))
+      val primary = new PartitionLocation(
+        i,
+        0,
+        host._1,
+        host._2,
+        host._3,
+        host._4,
+        host._5,
+        PartitionLocation.Mode.PRIMARY,
+        null,
+        new StorageInfo(
+          StorageInfo.Type.HDD,
+          mountPoint,
+          false,
+          mountPoint + "/application/0/" + RandomStringUtils.randomNumeric(6),
+          StorageInfo.LOCAL_DISK_MASK),
+        null)
+
+      val rHost = hosts(Random.nextInt(50))
+      val rMountPoint = mountPoints(Random.nextInt(8))
+
+      val replicate = new PartitionLocation(
+        i,
+        0,
+        rHost._1,
+        rHost._2,
+        rHost._3,
+        rHost._4,
+        rHost._5,
+        PartitionLocation.Mode.REPLICA,
+        null,
+        new StorageInfo(
+          StorageInfo.Type.HDD,
+          rMountPoint,
+          false,
"/application-xxxsdsada-1/0/" + RandomStringUtils.randomNumeric(6), + StorageInfo.LOCAL_DISK_MASK), + null) + primary.setPeer(replicate) + replicate.setPeer(primary) + (primary, replicate) + }).toList.unzip + + val workerResourceSize = PbWorkerResource.newBuilder() + .addAllPrimaryPartitions(primaryLocations.map(PbSerDeUtils.toPbPartitionLocation).asJava) + .addAllReplicaPartitions(replicaLocations.map(PbSerDeUtils.toPbPartitionLocation).asJava) + .setNetworkLocation("location1") + .build().toByteArray.length + + val pbPackedWorkerResource = PbPackedWorkerResource.newBuilder() + .setLocationPairs(toPbPackedPartitionLocationsPair( + primaryLocations ++ replicaLocations)) + .setNetworkLocation("location1") + .build() + val packedWorkerResourceSize = pbPackedWorkerResource.toByteArray.length + + val (locs1, locs2) = fromPbPackedPartitionLocationsPair(pbPackedWorkerResource.getLocationPairs) + + assert(primaryLocations.size === locs1.size()) + assert(replicaLocations.size === locs2.size()) + + assert(primaryLocations.zip(locs1.asScala).count(x => x._1 != x._2) == 0) + assert(replicaLocations.zip(locs2.asScala).count(x => x._1 != x._2) == 0) + + assert(packedWorkerResourceSize < workerResourceSize) + log.info(s"Packed size : ${packedWorkerResourceSize} unpacked size :${workerResourceSize}") + log.info( + s"Reduced size : ${(workerResourceSize - packedWorkerResourceSize) / (workerResourceSize * 1.0f) * 100} %") + } + + test("serializationComparasion") { + testSerializationPerformance(100) + } + } diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala index 3cdc8c51250..6bc675ecbef 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala @@ -453,7 +453,7 @@ private[celeborn] class Master( // keep it for compatible reason context.reply(ReleaseSlotsResponse(StatusCode.SUCCESS)) - case requestSlots @ RequestSlots(applicationId, _, _, _, _, _, _, _, _, _, _) => + case requestSlots @ RequestSlots(applicationId, _, _, _, _, _, _, _, _, _, _, _) => logTrace(s"Received RequestSlots request $requestSlots.") checkAuth(context, applicationId) executeWithLeaderChecker(context, handleRequestSlots(context, requestSlots)) @@ -805,7 +805,8 @@ private[celeborn] class Master( if (numAvailableWorkers == 0) { logError(s"Offer slots for $shuffleKey failed due to all workers are excluded!") - context.reply(RequestSlotsResponse(StatusCode.WORKER_EXCLUDED, new WorkerResource())) + context.reply( + RequestSlotsResponse(StatusCode.WORKER_EXCLUDED, new WorkerResource(), requestSlots.packed)) } val numWorkers = Math.min( @@ -862,7 +863,10 @@ private[celeborn] class Master( // reply false if offer slots failed if (slots == null || slots.isEmpty) { logError(s"Offer slots for $numReducers reducers of $shuffleKey failed!") - context.reply(RequestSlotsResponse(StatusCode.SLOT_NOT_AVAILABLE, new WorkerResource())) + context.reply(RequestSlotsResponse( + StatusCode.SLOT_NOT_AVAILABLE, + new WorkerResource(), + requestSlots.packed)) return } @@ -893,7 +897,10 @@ private[celeborn] class Master( if (authEnabled) { pushApplicationMetaToWorkers(requestSlots, slots) } - context.reply(RequestSlotsResponse(StatusCode.SUCCESS, slots.asInstanceOf[WorkerResource])) + context.reply(RequestSlotsResponse( + StatusCode.SUCCESS, + slots.asInstanceOf[WorkerResource], + requestSlots.packed)) } def 
   def pushApplicationMetaToWorkers(
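// ---------------------------------------------------------------------------
// Back-of-envelope intuition for the size win asserted in the test above
// (hypothetical numbers, not measured output): the legacy encoding repeats the
// full worker-identity string in every location and again in every embedded
// peer, while the packed form stores each unique string once plus small index
// varints. Only the string-deduplication effect is modeled here.
object PackedSizeIntuition extends App {
  val locations = 1000   // locations in one RequestSlotsResponse
  val workers = 50       // distinct workers hosting them
  val idBytes = 40       // assumed "host:rpcPort:pushPort:fetchPort:replicatePort" length
  val unpacked = locations * 2 * idBytes             // every location + its embedded peer
  val packed = workers * idBytes + locations * 2 * 2 // dictionary + ~2-byte indexes
  println(s"worker-id bytes: unpacked ~ $unpacked, packed ~ $packed")
}
// ---------------------------------------------------------------------------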