[CELEBORN-1490][CIP-6] Support processing large buffers in Flink hybrid shuffle

### What changes were proposed in this pull request?

This is the last PR in the CIP-6 series.

Fix a bug in hybrid shuffle when handling a buffer larger than 32 KB.
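
The fix frames an oversized body as a fixed-length celeborn header plus a separate data buffer, recombined on the receive path as a `CompositeByteBuf`. A minimal standalone sketch of that framing (plain Netty, hypothetical sizes; the header field order mirrors `getBufferHeaderFromByteBuf` in this diff):

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;

public class LargeBufferFramingSketch {
  // Assumed to match BufferUtils.HEADER_LENGTH: four ints + dataType + isCompressed + size.
  static final int HEADER_LENGTH = 4 * 4 + 1 + 1 + 4; // 22 bytes

  public static void main(String[] args) {
    // The decoder first fills a dedicated header buffer ...
    ByteBuf header = Unpooled.buffer(HEADER_LENGTH, HEADER_LENGTH);
    header.writeInt(0);         // subPartitionId
    header.writeInt(0);         // attemptId
    header.writeInt(0);         // nextBatchId
    header.writeInt(6 + 64);    // compressedSize: flink header (6) + payload
    header.writeByte(0);        // dataType ordinal
    header.writeBoolean(false); // isCompressed
    header.writeInt(64);        // payload size

    // ... then copies the remaining body bytes into the Flink-supplied data buffer.
    ByteBuf data = Unpooled.buffer(64).writeZero(64);

    // Both parts travel onward as one CompositeByteBuf, which BufferPacker.unpack()
    // now detects and splits back into a header and a read-only data slice.
    CompositeByteBuf frame = Unpooled.compositeBuffer();
    frame.addComponent(true, header);
    frame.addComponent(true, data);
    System.out.println(frame.readableBytes()); // 22 + 64 = 86
  }
}
```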

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
UT

Closes apache#2873 from reswqa/11-large-buffer-10month.

Lead-authored-by: Yuxin Tan <[email protected]>
Co-authored-by: Weijie Guo <[email protected]>
Signed-off-by: Shuang <[email protected]>
2 people authored and RexXiong committed Nov 4, 2024
1 parent d4044c5 commit 7ebd168
Showing 12 changed files with 350 additions and 19 deletions.
@@ -26,6 +26,7 @@
import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -157,6 +158,28 @@ public void close() {
public static Queue<Buffer> unpack(ByteBuf byteBuf) {
Queue<Buffer> buffers = new ArrayDeque<>();
try {
if (byteBuf instanceof CompositeByteBuf) {
// If the received byteBuf is a CompositeByteBuf, it indicates that the byteBuf originates
// from the Flink hybrid shuffle integration strategy. This byteBuf consists of two parts: a
// celeborn header and a data buffer.
CompositeByteBuf compositeByteBuf = (CompositeByteBuf) byteBuf;
ByteBuf headerBuffer = compositeByteBuf.component(0).unwrap();
ByteBuf dataBuffer = compositeByteBuf.component(1).unwrap();
dataBuffer.retain();
Utils.checkState(
dataBuffer instanceof Buffer, "Illegal data buffer type for CompositeByteBuf.");
BufferHeader bufferHeader = BufferUtils.getBufferHeaderFromByteBuf(headerBuffer, 0);
Buffer slice = ((Buffer) dataBuffer).readOnlySlice(0, bufferHeader.getSize());
buffers.add(
new UnpackSlicedBuffer(
slice,
bufferHeader.getDataType(),
bufferHeader.isCompressed(),
bufferHeader.getSize()));

return buffers;
}

Utils.checkState(byteBuf instanceof Buffer, "Illegal buffer type.");

Buffer buffer = (Buffer) byteBuf;
@@ -39,11 +39,14 @@ public class FlinkTransportClientFactory extends TransportClientFactory {

private ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;

private int bufferSizeBytes;

public FlinkTransportClientFactory(
TransportContext context, List<TransportClientBootstrap> bootstraps) {
TransportContext context, List<TransportClientBootstrap> bootstraps, int bufferSizeBytes) {
super(context, bootstraps);
bufferSuppliers = JavaUtils.newConcurrentHashMap();
this.pooledAllocator = new UnpooledByteBufAllocator(true);
this.bufferSizeBytes = bufferSizeBytes;
}

public TransportClient createClientWithRetry(String remoteHost, int remotePort)
@@ -52,7 +55,7 @@ public TransportClient createClientWithRetry(String remoteHost, int remotePort)
remoteHost,
remotePort,
-1,
() -> new TransportFrameDecoderWithBufferSupplier(bufferSuppliers));
() -> new TransportFrameDecoderWithBufferSupplier(bufferSuppliers, bufferSizeBytes));
}

public void registerSupplier(long streamId, Supplier<ByteBuf> supplier) {
@@ -23,13 +23,15 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.util.FrameDecoder;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;

public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandlerAdapter
implements FrameDecoder {
@@ -44,17 +46,37 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandlerAdapter
private final ByteBuf msgBuf = Unpooled.buffer(8);
private Message curMsg = null;
private int remainingSize = -1;
private int totalReadBytes = 0;
private int largeBufferHeaderRemainingBytes = -1;
private boolean isReadingLargeBuffer = false;
private ByteBuf largeBufferHeaderBuffer;
public static final int DISABLE_LARGE_BUFFER_SPLIT_SIZE = -1;

/**
* The Flink buffer size in bytes. If a received buffer is larger than this value, it must be
* divided into multiple smaller buffers, each no larger than {@link #bufferSizeBytes}. When
* this value is set to {@link #DISABLE_LARGE_BUFFER_SPLIT_SIZE}, large-buffer splitting is
* not checked.
*/
private final int bufferSizeBytes;

private final ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;

public TransportFrameDecoderWithBufferSupplier(
ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers) {
this(bufferSuppliers, DISABLE_LARGE_BUFFER_SPLIT_SIZE);
}

public TransportFrameDecoderWithBufferSupplier(
ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers, int bufferSizeBytes) {
this.bufferSuppliers = bufferSuppliers;
this.bufferSizeBytes = bufferSizeBytes;
}

private void copyByteBuf(io.netty.buffer.ByteBuf source, ByteBuf target, int targetSize) {
private int copyByteBuf(io.netty.buffer.ByteBuf source, ByteBuf target, int targetSize) {
int bytes = Math.min(source.readableBytes(), targetSize - target.readableBytes());
target.writeBytes(source.readSlice(bytes).nioBuffer());
return bytes;
}

private void decodeHeader(io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
@@ -69,6 +91,15 @@ private void decodeHeader(io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
// type byte is read
headerBuf.readByte();
bodySize = headerBuf.readInt();
if (bufferSizeBytes != DISABLE_LARGE_BUFFER_SPLIT_SIZE && bodySize > bufferSizeBytes) {
// If the message body is larger than bufferSizeBytes, split it into two parts:
// the celeborn header and the data buffer.
isReadingLargeBuffer = true;
// Create a temporary buffer to hold the celeborn header.
largeBufferHeaderBuffer =
Unpooled.buffer(BufferUtils.HEADER_LENGTH, BufferUtils.HEADER_LENGTH);
largeBufferHeaderRemainingBytes = BufferUtils.HEADER_LENGTH;
}
decodeMsg(buf, ctx);
}
}
@@ -138,9 +169,31 @@ private io.netty.buffer.ByteBuf decodeBodyCopyOut(
}
}

copyByteBuf(buf, externalBuf, bodySize);
if (externalBuf.readableBytes() == bodySize) {
((ReadData) curMsg).setFlinkBuffer(externalBuf);
if (largeBufferHeaderRemainingBytes > 0) {
// If largeBufferHeaderRemainingBytes is greater than zero, we are still reading the
// celeborn header.
int headerReadBytes = copyByteBuf(buf, largeBufferHeaderBuffer, BufferUtils.HEADER_LENGTH);
largeBufferHeaderRemainingBytes -= headerReadBytes;
totalReadBytes += headerReadBytes;
} else {
// Otherwise the header is complete (or splitting is disabled), so we are reading the
// data buffer.
totalReadBytes += copyByteBuf(buf, externalBuf, getTargetDataBufferReadSize());
}

if (totalReadBytes == bodySize) {
ByteBuf resultByteBuf;
if (largeBufferHeaderBuffer == null) {
resultByteBuf = externalBuf;
} else {
// Combine the celeborn header and the data buffer into a single CompositeByteBuf.
CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
compositeByteBuf.addComponent(true, largeBufferHeaderBuffer);
compositeByteBuf.addComponent(true, externalBuf);
resultByteBuf = compositeByteBuf;
}

((ReadData) curMsg).setFlinkBuffer(resultByteBuf);
ctx.fireChannelRead(curMsg);
clear();
}
@@ -192,6 +245,13 @@ public void channelRead(ChannelHandlerContext ctx, Object data) {
}
}

private int getTargetDataBufferReadSize() {
if (isReadingLargeBuffer) {
return bodySize - BufferUtils.HEADER_LENGTH;
}
return bodySize;
}

private void clear() {
externalBuf = null;
curMsg = null;
@@ -200,6 +260,10 @@ private void clear() {
bodyBuf = null;
bodySize = -1;
remainingSize = -1;
totalReadBytes = 0;
largeBufferHeaderRemainingBytes = -1;
largeBufferHeaderBuffer = null;
isReadingLargeBuffer = false;
}

@Override
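To make the new decoder accounting concrete: `HEADER_LENGTH` bytes are drained into a temporary header buffer first, the rest goes into the Flink-supplied buffer, and the message fires only once `totalReadBytes` reaches `bodySize`. A simplified, hedged sketch of that loop (not the production class; buffer types collapsed to plain Netty `ByteBuf`):

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

public class SplitDecodeSketch {
  static final int HEADER_LENGTH = 22; // assumed, matching BufferUtils.HEADER_LENGTH

  /** Drains source into the header first, then the data, until bodySize bytes arrived. */
  static boolean decodeChunk(ByteBuf source, ByteBuf header, ByteBuf data, int bodySize, int[] state) {
    // state[0] = headerRemaining, state[1] = totalRead
    while (source.isReadable() && state[1] < bodySize) {
      if (state[0] > 0) {
        int n = Math.min(source.readableBytes(), state[0]);
        header.writeBytes(source, n);
        state[0] -= n;
        state[1] += n;
      } else {
        int n = Math.min(source.readableBytes(), bodySize - state[1]);
        data.writeBytes(source, n);
        state[1] += n;
      }
    }
    return state[1] == bodySize; // true once the whole body arrived
  }

  public static void main(String[] args) {
    int bodySize = HEADER_LENGTH + 100;
    ByteBuf header = Unpooled.buffer(HEADER_LENGTH, HEADER_LENGTH);
    ByteBuf data = Unpooled.buffer(100);
    int[] state = {HEADER_LENGTH, 0};
    // Feed the frame in two network chunks; the second call completes the body.
    System.out.println(decodeChunk(Unpooled.buffer().writeZero(30), header, data, bodySize, state)); // false
    System.out.println(decodeChunk(Unpooled.buffer().writeZero(92), header, data, bodySize, state)); // true
  }
}
```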
@@ -68,6 +68,7 @@
import org.apache.celeborn.common.write.PushState;
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;
import org.apache.celeborn.plugin.flink.network.ReadClientHandler;
import org.apache.celeborn.plugin.flink.network.TransportFrameDecoderWithBufferSupplier;

public class FlinkShuffleClientImpl extends ShuffleClientImpl {
public static final Logger logger = LoggerFactory.getLogger(FlinkShuffleClientImpl.class);
@@ -81,6 +82,9 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {

private final TransportContext context;

/** The Flink buffer size in bytes; the default value is 32 KB. */
private final int bufferSizeBytes;

public static FlinkShuffleClientImpl get(
String appUniqueId,
String driverHost,
@@ -89,18 +93,49 @@
CelebornConf conf,
UserIdentifier userIdentifier)
throws DriverChangedException {
return get(
appUniqueId,
driverHost,
port,
driverTimestamp,
conf,
userIdentifier,
TransportFrameDecoderWithBufferSupplier.DISABLE_LARGE_BUFFER_SPLIT_SIZE);
}

public static FlinkShuffleClientImpl get(
String appUniqueId,
String driverHost,
int port,
long driverTimestamp,
CelebornConf conf,
UserIdentifier userIdentifier,
int bufferSizeBytes)
throws DriverChangedException {
if (null == _instance || !initialized || _instance.driverTimestamp < driverTimestamp) {
synchronized (FlinkShuffleClientImpl.class) {
if (null == _instance) {
_instance =
new FlinkShuffleClientImpl(
appUniqueId, driverHost, port, driverTimestamp, conf, userIdentifier);
appUniqueId,
driverHost,
port,
driverTimestamp,
conf,
userIdentifier,
bufferSizeBytes);
initialized = true;
} else if (!initialized || _instance.driverTimestamp < driverTimestamp) {
_instance.shutdown();
_instance =
new FlinkShuffleClientImpl(
appUniqueId, driverHost, port, driverTimestamp, conf, userIdentifier);
appUniqueId,
driverHost,
port,
driverTimestamp,
conf,
userIdentifier,
bufferSizeBytes);
initialized = true;
}
}
@@ -133,8 +168,10 @@ public FlinkShuffleClientImpl(
int port,
long driverTimestamp,
CelebornConf conf,
UserIdentifier userIdentifier) {
UserIdentifier userIdentifier,
int bufferSizeBytes) {
super(appUniqueId, conf, userIdentifier);
this.bufferSizeBytes = bufferSizeBytes;
String module = TransportModuleConstants.DATA_MODULE;
TransportConf dataTransportConf =
Utils.fromCelebornConf(conf, module, conf.getInt("celeborn." + module + ".io.threads", 8));
@@ -147,7 +184,8 @@ public FlinkShuffleClientImpl(

private void initializeTransportClientFactory() {
if (null == flinkTransportClientFactory) {
flinkTransportClientFactory = new FlinkTransportClientFactory(context, createBootstraps());
flinkTransportClientFactory =
new FlinkTransportClientFactory(context, createBootstraps(), bufferSizeBytes);
}
}

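For context, a caller obtaining the client would now thread the Flink buffer size through the new `get` overload. A hypothetical call site (names, values, and the `throws` clause are placeholders, not the plugin's actual wiring):

```java
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.identity.UserIdentifier;

public class ClientWiringSketch {
  static FlinkShuffleClientImpl newClient(CelebornConf conf, UserIdentifier user) throws Exception {
    // 32 KB matches Flink's default memory segment size; any body above it is split.
    int flinkBufferSizeBytes = 32 * 1024;
    return FlinkShuffleClientImpl.get(
        "app-1", "driver-host", 9097, System.currentTimeMillis(), conf, user, flinkBufferSizeBytes);
  }
}
```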
@@ -113,6 +113,18 @@ public static BufferHeader getBufferHeader(Buffer buffer, int position, boolean
}
}

public static BufferHeader getBufferHeaderFromByteBuf(ByteBuf byteBuf, int position) {
byteBuf.readerIndex(position);
return new BufferHeader(
byteBuf.readInt(),
byteBuf.readInt(),
byteBuf.readInt(),
byteBuf.readInt(),
Buffer.DataType.values()[byteBuf.readByte()],
byteBuf.readBoolean(),
byteBuf.readInt());
}

public static void reserveNumRequiredBuffers(BufferPool bufferPool, int numRequiredBuffers)
throws IOException {
long startTime = System.nanoTime();
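The byte layout that `getBufferHeaderFromByteBuf` consumes works out to 22 bytes; a quick sketch of the assumed offsets (derived from the read sequence above, with `HEADER_LENGTH_PREFIX` taken to cover the four leading ints):

```java
public class HeaderLayoutSketch {
  public static void main(String[] args) {
    // [0,16)  subPartitionId, attemptId, nextBatchId, compressedSize (4 ints = prefix)
    // [16,17) dataType (Buffer.DataType ordinal)
    // [17,18) isCompressed
    // [18,22) size
    int prefix = 4 * 4;        // assumed BufferUtils.HEADER_LENGTH_PREFIX
    int flinkPart = 1 + 1 + 4; // dataType + isCompressed + size
    System.out.println(prefix + flinkPart); // 22, assumed BufferUtils.HEADER_LENGTH
  }
}
```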
@@ -26,6 +26,8 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Queue;
import java.util.Random;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.flink.core.memory.MemorySegment;
@@ -38,6 +40,7 @@
import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -220,6 +223,27 @@ public void testUnpackedBuffers() throws Exception {
unpacked.forEach(Buffer::recycleBuffer);
}

@Test
public void testUnpackCompositeBuffer() throws Exception {
Buffer dataBuffer = bufferPool.requestBuffer();
fillBufferWithRandomByte(dataBuffer);
ByteBuf bufferHeaderByteBuf = createBufferHeaderByteBuf(BUFFER_SIZE);
bufferHeaderByteBuf.retain();
CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
compositeByteBuf.addComponent(true, bufferHeaderByteBuf);
compositeByteBuf.addComponent(true, dataBuffer.asByteBuf());

Queue<Buffer> unpackedBuffers = BufferPacker.unpack(compositeByteBuf);
Assert.assertEquals(1, unpackedBuffers.size());
Assert.assertEquals(dataBuffer.readableBytes(), unpackedBuffers.peek().readableBytes());
Assert.assertEquals(BUFFER_SIZE, unpackedBuffers.peek().readableBytes());
for (int i = 0; i < BUFFER_SIZE; ++i) {
Assert.assertEquals(
dataBuffer.getMemorySegment().get(i), unpackedBuffers.peek().getMemorySegment().get(i));
}
dataBuffer.recycleBuffer();
}

@Test
public void testPackMultipleBuffers() throws Exception {
int numBuffers = 7;
@@ -404,4 +428,28 @@ public BufferPacker createBufferPakcer(
return new ReceivedNoHeaderBufferPacker(ripeBufferHandler);
}
}

public ByteBuf createBufferHeaderByteBuf(int dataBufferSize) {
ByteBuf headerBuf = Unpooled.directBuffer(BufferUtils.HEADER_LENGTH, BufferUtils.HEADER_LENGTH);
// Write the celeborn buffer header: subPartitionId(4) + attemptId(4) + nextBatchId(4) +
// compressedSize(4).
headerBuf.writeInt(0);
headerBuf.writeInt(0);
headerBuf.writeInt(0);
headerBuf.writeInt(
dataBufferSize + (BufferUtils.HEADER_LENGTH - BufferUtils.HEADER_LENGTH_PREFIX));

// Write the flink buffer header: dataType(1) + isCompressed(1) + size(4).
headerBuf.writeByte(DATA_BUFFER.ordinal());
headerBuf.writeBoolean(false);
headerBuf.writeInt(dataBufferSize);
return headerBuf;
}

public void fillBufferWithRandomByte(Buffer buffer) {
Random random = new Random();
for (int i = 0; i < buffer.getMaxCapacity(); i++) {
buffer.asByteBuf().writeByte(random.nextInt(255));
}
}
}
@@ -55,7 +55,7 @@ public void setup() throws IOException, InterruptedException {
conf = new CelebornConf();
shuffleClient =
new FlinkShuffleClientImpl(
"APP", "localhost", 1232, System.currentTimeMillis(), conf, null) {
"APP", "localhost", 1232, System.currentTimeMillis(), conf, null, -1) {
@Override
public void setupLifecycleManagerRef(String host, int port) {}
};