[CELEBORN-1270] Introduce PbPackedPartitionLocations to (de-)serialize PartitionLocations more efficiently

### What changes were proposed in this pull request?
1. Introduces a packed representation (`PbPackedPartitionLocations`) to (de-)serialize partition locations more compactly; a rough sketch of the packing idea follows this list.
2. The Celeborn server remains compatible with old clients that still use the existing per-location messages.
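
The new `PbPackedPartitionLocations` message added later in this diff stores one array per field plus small string dictionaries (for example `workerIds` indexing into `workerIdsSet`), instead of one nested `PbPartitionLocation` per location. Below is a minimal, self-contained Scala sketch of that dictionary-packing idea; the `Location`/`PackedLocations` types and field names are illustrative stand-ins, not the actual Celeborn classes or the `PbSerDeUtils` implementation.

```scala
// Illustrative sketch only: a simplified model of the dictionary-packing idea,
// not the actual PbSerDeUtils logic; types and field names are stand-ins.
object PackedLocationsSketch {
  case class Location(id: Int, epoch: Int, workerAddress: String, mountPoint: String)

  case class PackedLocations(
      ids: Seq[Int],
      epochs: Seq[Int],
      workerIndexes: Seq[Int],   // index into workerDict, one entry per location
      workerDict: Seq[String],   // each distinct worker address is stored only once
      mountPointIndexes: Seq[Int],
      mountPointDict: Seq[String])

  def pack(locations: Seq[Location]): PackedLocations = {
    val workerDict = locations.map(_.workerAddress).distinct
    val mountDict = locations.map(_.mountPoint).distinct
    PackedLocations(
      ids = locations.map(_.id),
      epochs = locations.map(_.epoch),
      workerIndexes = locations.map(l => workerDict.indexOf(l.workerAddress)),
      workerDict = workerDict,
      mountPointIndexes = locations.map(l => mountDict.indexOf(l.mountPoint)),
      mountPointDict = mountDict)
  }

  def unpack(packed: PackedLocations): Seq[Location] =
    packed.ids.indices.map { i =>
      Location(
        packed.ids(i),
        packed.epochs(i),
        packed.workerDict(packed.workerIndexes(i)),
        packed.mountPointDict(packed.mountPointIndexes(i)))
    }
}
```

Because many locations in a shuffle share the same worker address and mount point, dictionary-encoding those repeated strings is presumably where most of the reported RPC-size reduction comes from.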

### Why are the changes needed?
1. Improve memory efficiency for partition locations.

### Does this PR introduce _any_ user-facing change?
NO.

### How was this patch tested?
1. Passed GA.
2. Ran the following job on a cluster:
```
val start = System.currentTimeMillis
spark.sparkContext.parallelize(1 to 10000, 10000).flatMap( _ => (1 to 950000).iterator.map(num => num)).repartition(10000).count
val after = System.currentTimeMillis
println((after-start)/1000)
```
Run time with packed RPC: 70, 65, 64, 64, 64, 64
Run time with baseline RPC: 69, 66, 66, 66, 67, 66

These results suggest the PR does not introduce measurable performance overhead.

3. RPC size test: this PR can reduce RPC size by up to 60%.

Closes apache#2456 from FMX/CELEBORN-1270.

Authored-by: mingji <[email protected]>
Signed-off-by: Shuang <[email protected]>
FMX authored and RexXiong committed May 11, 2024
1 parent db163bd commit 8dd33ce
Showing 8 changed files with 400 additions and 48 deletions.
@@ -605,14 +605,15 @@ private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
if (StatusCode.SUCCESS.equals(respStatus)) {
ConcurrentHashMap<Integer, PartitionLocation> result = JavaUtils.newConcurrentHashMap();
for (int i = 0; i < response.getPartitionLocationsList().size(); i++) {
PartitionLocation partitionLoc =
PbSerDeUtils.fromPbPartitionLocation(response.getPartitionLocationsList().get(i));
pushExcludedWorkers.remove(partitionLoc.hostAndPushPort());
if (partitionLoc.hasPeer()) {
pushExcludedWorkers.remove(partitionLoc.getPeer().hostAndPushPort());
Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations =
PbSerDeUtils.fromPbPackedPartitionLocationsPair(
response.getPackedPartitionLocationsPair());
for (PartitionLocation location : locations._1) {
pushExcludedWorkers.remove(location.hostAndPushPort());
if (location.hasPeer()) {
pushExcludedWorkers.remove(location.getPeer().hostAndPushPort());
}
result.put(partitionLoc.getId(), partitionLoc);
result.put(location.getId(), location);
}
return result;
} else if (StatusCode.SLOT_NOT_AVAILABLE.equals(respStatus)) {
@@ -585,16 +585,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
case _ => Option.empty
}

val locations = PbSerDeUtils.fromPbPackedPartitionLocationsPair(
response.getPackedPartitionLocationsPair)._1.asScala

registeringShuffleRequest.asScala
.get(shuffleId)
.foreach(_.asScala.foreach(context => {
partitionType match {
case PartitionType.MAP =>
if (response.getStatus == StatusCode.SUCCESS.getValue) {
val partitionLocations =
response.getPartitionLocationsList.asScala.filter(
_.getId == context.partitionId).map(r =>
PbSerDeUtils.fromPbPartitionLocation(r)).toArray
val partitionLocations = locations.filter(_.getId == context.partitionId).toArray
processMapTaskReply(
shuffleId,
context.context,
@@ -1540,7 +1540,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
userIdentifier,
slotsAssignMaxWorkers,
availableStorageTypes,
excludedWorkerSet)
excludedWorkerSet,
true)
val res = requestMasterRequestSlots(req)
if (res.status != StatusCode.SUCCESS) {
requestMasterRequestSlots(req)
@@ -1556,7 +1557,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
} catch {
case e: Exception =>
logError(s"AskSync RegisterShuffle for $shuffleKey failed.", e)
RequestSlotsResponse(StatusCode.REQUEST_FAILED, new WorkerResource())
RequestSlotsResponse(StatusCode.REQUEST_FAILED, new WorkerResource(), message.packed)
}
}

@@ -275,7 +275,8 @@ class ReducePartitionCommitHandler(
GetReducerFileGroupResponse(
StatusCode.SHUFFLE_DATA_LOST,
JavaUtils.newConcurrentHashMap(),
Array.empty))
Array.empty,
new util.HashSet[Integer]()))
} else {
// LocalNettyRpcCallContext is for the UTs
if (context.isInstanceOf[LocalNettyRpcCallContext]) {
31 changes: 31 additions & 0 deletions common/src/main/proto/TransportMessages.proto
@@ -173,6 +173,7 @@ message PbWorkerInfo {

message PbFileGroup {
repeated PbPartitionLocation locations = 1;
PbPackedPartitionLocationsPair partitionLocationsPair = 2;
}

message PbRegisterWorker {
@@ -257,6 +258,7 @@ message PbRegisterMapPartitionTask {
message PbRegisterShuffleResponse {
int32 status = 1;
repeated PbPartitionLocation partitionLocations = 2;
PbPackedPartitionLocationsPair packedPartitionLocationsPair = 3;
}

message PbRequestSlots {
@@ -272,6 +274,7 @@
int32 maxWorkers = 10;
int32 availableStorageTypes = 11;
repeated PbWorkerInfo excludedWorkerSet = 12;
bool packed = 13;
}

message PbSlotInfo {
@@ -296,6 +299,7 @@
message PbRequestSlotsResponse {
int32 status = 1;
map<string, PbWorkerResource> workerResource = 2;
map<string, PbPackedWorkerResource> packedWorkerResource = 3;
}

message PbRevivePartitionInfo {
@@ -451,6 +455,7 @@ message PbReserveSlots {
int64 pushDataTimeout = 10;
bool partitionSplitEnabled = 11;
int32 availableStorageTypes = 12;
PbPackedPartitionLocationsPair partitionLocationsPair = 13;
}

message PbReserveSlotsResponse {
@@ -758,3 +763,29 @@ message PbApplicationMeta {
message PbApplicationMetaRequest {
string appId = 1;
}

message PbPackedPartitionLocations {
repeated int32 ids = 1;
repeated int32 epoches = 2;
repeated int32 workerIds = 3;
repeated string workerIdsSet = 4;
repeated bytes mapIdBitMap = 5;
repeated int32 types = 6;
repeated int32 mountPoints = 7;
repeated string mountPointsSet = 8;
repeated bool finalResult = 9 ;
repeated string filePaths = 10;
repeated int32 availableStorageTypes = 11;
repeated int32 modes = 12;
}

message PbPackedPartitionLocationsPair {
PbPackedPartitionLocations locations = 1;
repeated int32 peerIndexes = 2;
int32 inputLocationSize = 3;
}

message PbPackedWorkerResource {
PbPackedPartitionLocationsPair locationPairs = 1;
string networkLocation = 2;
}
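
The packed pair carries a single `PbPackedPartitionLocations` plus `peerIndexes` and `inputLocationSize`; the call sites in this diff decode it back into a primary list and a replica list, and `peerIndexes` appears to reconnect each location with its peer by array position rather than by nesting a full copy of the peer. A hedged, stand-alone Scala sketch of such index-based peer linking follows (the `Loc` type and helper names are hypothetical, not the real Celeborn classes or `PbSerDeUtils` logic).

```scala
// Hypothetical sketch of what peerIndexes could encode: linking a location to
// its peer by array position instead of embedding a nested copy of the peer.
object PeerIndexSketch {
  final class Loc(val id: Int) {
    var peer: Option[Loc] = None
  }

  // For each location, record the array position of its peer, or -1 if it has none.
  def encodePeerIndexes(locs: IndexedSeq[Loc]): IndexedSeq[Int] =
    locs.map(l => l.peer.map(p => locs.indexWhere(_ eq p)).getOrElse(-1))

  // Rebuild the peer references from the index array after unpacking.
  def decodePeerIndexes(locs: IndexedSeq[Loc], peerIndexes: IndexedSeq[Int]): Unit =
    locs.indices.foreach { i =>
      val p = peerIndexes(i)
      if (p >= 0) locs(i).peer = Some(locs(p))
    }
}
```
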
@@ -18,7 +18,7 @@
package org.apache.celeborn.common.protocol.message

import java.util
import java.util.UUID
import java.util.{Collections, UUID}

import scala.collection.JavaConverters._

@@ -160,12 +160,13 @@ object ControlMessages extends Logging {
object RegisterShuffleResponse {
def apply(
status: StatusCode,
partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse =
PbRegisterShuffleResponse.newBuilder()
partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse = {
val builder = PbRegisterShuffleResponse.newBuilder()
.setStatus(status.getValue)
.addAllPartitionLocations(
partitionLocations.map(PbSerDeUtils.toPbPartitionLocation).toSeq.asJava)
.build()
builder.setPackedPartitionLocationsPair(
PbSerDeUtils.toPbPackedPartitionLocationsPair(partitionLocations.toList))
builder.build()
}
}

case class RequestSlots(
@@ -179,6 +180,7 @@
maxWorkers: Int,
availableStorageTypes: Int,
excludedWorkerSet: Set[WorkerInfo] = Set.empty,
packed: Boolean = false,
override var requestId: String = ZERO_UUID)
extends MasterRequestMessage

@@ -199,7 +201,8 @@

case class RequestSlotsResponse(
status: StatusCode,
workerResource: WorkerResource)
workerResource: WorkerResource,
packed: Boolean = false)
extends MasterMessage

object Revive {
@@ -279,7 +282,7 @@
status: StatusCode,
fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
attempts: Array[Int],
partitionIds: util.Set[Integer] = new util.HashSet[Integer]())
partitionIds: util.Set[Integer] = Collections.emptySet[Integer]())
extends MasterMessage

object WorkerExclude {
@@ -583,6 +586,7 @@
maxWorkers,
availableStorageTypes,
excludedWorkerSet,
packed,
requestId) =>
val payload = PbRequestSlots.newBuilder()
.setApplicationId(applicationId)
@@ -596,6 +600,7 @@
.setAvailableStorageTypes(availableStorageTypes)
.setUserIdentifier(PbSerDeUtils.toPbUserIdentifier(userIdentifier))
.addAllExcludedWorkerSet(excludedWorkerSet.map(PbSerDeUtils.toPbWorkerInfo(_, true)).asJava)
.setPacked(packed)
.build().toByteArray
new TransportMessage(MessageType.REQUEST_SLOTS, payload)

@@ -616,12 +621,16 @@
.setStatus(status.getValue).build().toByteArray
new TransportMessage(MessageType.RELEASE_SLOTS_RESPONSE, payload)

case RequestSlotsResponse(status, workerResource) =>
case RequestSlotsResponse(status, workerResource, packed) =>
val builder = PbRequestSlotsResponse.newBuilder()
.setStatus(status.getValue)
if (!workerResource.isEmpty) {
builder.putAllWorkerResource(
PbSerDeUtils.toPbWorkerResource(workerResource))
if (packed) {
builder.putAllPackedWorkerResource(PbSerDeUtils.toPbPackedWorkerResource(workerResource))
} else {
builder.putAllWorkerResource(
PbSerDeUtils.toPbWorkerResource(workerResource))
}
}
val payload = builder.build().toByteArray
new TransportMessage(MessageType.REQUEST_SLOTS_RESPONSE, payload)
@@ -660,10 +669,10 @@
.setStatus(status.getValue)
builder.putAllFileGroups(
fileGroup.asScala.map { case (partitionId, fileGroup) =>
(
partitionId,
PbFileGroup.newBuilder().addAllLocations(fileGroup.asScala.map(PbSerDeUtils
.toPbPartitionLocation).toList.asJava).build())
val pbFileGroupBuilder = PbFileGroup.newBuilder()
pbFileGroupBuilder.setPartitionLocationsPair(
PbSerDeUtils.toPbPackedPartitionLocationsPair(fileGroup.asScala.toList))
(partitionId, pbFileGroupBuilder.build())
}.asJava)
builder.addAllAttempts(attempts.map(Integer.valueOf).toIterable.asJava)
builder.addAllPartitionIds(partitionIds)
@@ -795,10 +804,8 @@
val payload = PbReserveSlots.newBuilder()
.setApplicationId(applicationId)
.setShuffleId(shuffleId)
.addAllPrimaryLocations(primaryLocations.asScala
.map(PbSerDeUtils.toPbPartitionLocation).toList.asJava)
.addAllReplicaLocations(replicaLocations.asScala
.map(PbSerDeUtils.toPbPartitionLocation).toList.asJava)
.setPartitionLocationsPair(PbSerDeUtils.toPbPackedPartitionLocationsPair(
primaryLocations.asScala.toList ++ replicaLocations.asScala.toList))
.setSplitThreshold(splitThreshold)
.setSplitMode(splitMode.getValue)
.setPartitionType(partType.getValue)
@@ -1001,14 +1008,22 @@
pbRequestSlots.getMaxWorkers,
pbRequestSlots.getAvailableStorageTypes,
excludedWorkerInfoSet,
pbRequestSlots.getPacked,
pbRequestSlots.getRequestId)

case REQUEST_SLOTS_RESPONSE_VALUE =>
val pbRequestSlotsResponse = PbRequestSlotsResponse.parseFrom(message.getPayload)
val workerResource =
if (pbRequestSlotsResponse.getWorkerResourceMap.isEmpty) {
PbSerDeUtils.fromPbPackedWorkerResource(
pbRequestSlotsResponse.getPackedWorkerResourceMap)
} else {
PbSerDeUtils.fromPbWorkerResource(
pbRequestSlotsResponse.getWorkerResourceMap)
}
RequestSlotsResponse(
Utils.toStatusCode(pbRequestSlotsResponse.getStatus),
PbSerDeUtils.fromPbWorkerResource(
pbRequestSlotsResponse.getWorkerResourceMap))
workerResource)

case CHANGE_LOCATION_VALUE =>
PbRevive.parseFrom(message.getPayload)
@@ -1041,8 +1056,8 @@
case (partitionId, fileGroup) =>
(
partitionId,
fileGroup.getLocationsList.asScala.map(
PbSerDeUtils.fromPbPartitionLocation).toSet.asJava)
PbSerDeUtils.fromPbPackedPartitionLocationsPair(
fileGroup.getPartitionLocationsPair)._1.asScala.toSet.asJava)
}.asJava

val attempts = pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray
@@ -1136,13 +1151,22 @@
case RESERVE_SLOTS_VALUE =>
val pbReserveSlots = PbReserveSlots.parseFrom(message.getPayload)
val userIdentifier = PbSerDeUtils.fromPbUserIdentifier(pbReserveSlots.getUserIdentifier)
val (primaryLocations, replicateLocations) =
if (pbReserveSlots.getPrimaryLocationsList.isEmpty) {
PbSerDeUtils.fromPbPackedPartitionLocationsPair(
pbReserveSlots.getPartitionLocationsPair)
} else {
(
new util.ArrayList[PartitionLocation](pbReserveSlots.getPrimaryLocationsList.asScala
.map(PbSerDeUtils.fromPbPartitionLocation).toList.asJava),
new util.ArrayList[PartitionLocation](pbReserveSlots.getReplicaLocationsList.asScala
.map(PbSerDeUtils.fromPbPartitionLocation).toList.asJava))
}
ReserveSlots(
pbReserveSlots.getApplicationId,
pbReserveSlots.getShuffleId,
new util.ArrayList[PartitionLocation](pbReserveSlots.getPrimaryLocationsList.asScala
.map(PbSerDeUtils.fromPbPartitionLocation).toList.asJava),
new util.ArrayList[PartitionLocation](pbReserveSlots.getReplicaLocationsList.asScala
.map(PbSerDeUtils.fromPbPartitionLocation).toList.asJava),
primaryLocations,
replicateLocations,
pbReserveSlots.getSplitThreshold,
Utils.toShuffleSplitMode(pbReserveSlots.getSplitMode),
Utils.toPartitionType(pbReserveSlots.getPartitionType),