Skip to content

Commit

Permalink
[CELEBORN-980] Asynchronously delete original files to fix `ReusedExc…
Browse files Browse the repository at this point in the history
…hange` bug

### What changes were proposed in this pull request?

The `ReusedExchange` operator has the potential to generate different types of fetch requests, including both non-range and range requests. Currently, an issue arises due to the synchronous deletion of the original file by the Celeborn worker upon completion of sorting. This issue leads to the failure of non-range requests following a range request for the same partition.

the snippets to reproduce this bug
```scala
  val sparkConf = new SparkConf().setAppName("celeborn-test").setMaster("local[2]")
    .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
    .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", masterInfo._1.rpcEnv.address.toString)
    .set("spark.sql.autoBroadcastJoinThreshold", "-1")
    .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "100")
    .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "100")
  val spark = SparkSession.builder()
    .config(sparkConf)
    .getOrCreate()
  spark.range(0, 1000, 1, 10)
    .selectExpr("id as k1", "id as v1")
    .createOrReplaceTempView("ta")
  spark.range(0, 1000, 1, 10)
    .selectExpr("id % 1 as k21", "id % 1 as k22", "id as v2")
    .createOrReplaceTempView("tb")
  spark.range(140)
    .select(
      col("id").cast("long").as("k3"),
      concat(col("id").cast("string"), lit("a")).as("v3"))
    .createOrReplaceTempView("tc")

  spark.sql(
    """
      |SELECT *
      |FROM ta
      |LEFT JOIN tb ON ta.k1 = tb.k21
      |LEFT JOIN tc ON tb.k22 = tc.k3
      |""".stripMargin)
    .createOrReplaceTempView("v1")

  spark.sql(
    """
      |SELECT * FROM v1 WHERE v3 IS NOT NULL
      |UNION
      |SELECT * FROM v1
      |""".stripMargin)
    .collect()
```

This PR proposes a solution to address this problem. It introduces an asynchronous thread for the removal of the original file. Once the sorted file is generated for a given partition, this modification ensures that both non-range and range fetch requests will be able to and only fetch the sorted file once it is generated for a given partition.

this activity diagram of `openStream`

![openStream](https://github.com/apache/incubator-celeborn/assets/8537877/633cc5b8-e673-45a0-860e-e1f7e50c8965)

### Does this PR introduce _any_ user-facing change?

No, only bug fix

### How was this patch tested?

UT

Closes apache#1932 from cfmcgrady/fix-partition-sort-bug-v4.

Authored-by: Fu Chen <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
  • Loading branch information
cfmcgrady authored and waitinfuture committed Oct 9, 2023
1 parent a734b8c commit c4135dc
Show file tree
Hide file tree
Showing 21 changed files with 909 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.StreamType;
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;

public class CelebornBufferStream {
Expand Down Expand Up @@ -138,7 +139,11 @@ private void closeStream(long streamId) {
client.sendRpc(
new TransportMessage(
MessageType.BUFFER_STREAM_END,
PbBufferStreamEnd.newBuilder().setStreamId(streamId).build().toByteArray())
PbBufferStreamEnd.newBuilder()
.setStreamType(StreamType.CreditStream)
.setStreamId(streamId)
.build()
.toByteArray())
.toByteBuffer());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.StreamType;
import org.apache.celeborn.common.util.ShuffleBlockInfoUtils;
import org.apache.celeborn.common.util.Utils;

Expand All @@ -58,6 +61,8 @@ public class DfsPartitionReader implements PartitionReader {
private int numChunks = 0;
private int returnedChunks = 0;
private int currentChunkIndex = 0;
private TransportClient client;
private PbStreamHandler streamHandler;

public DfsPartitionReader(
CelebornConf conf,
Expand All @@ -77,8 +82,7 @@ public DfsPartitionReader(
if (endMapIndex != Integer.MAX_VALUE) {
long fetchTimeoutMs = conf.clientFetchTimeoutMs();
try {
TransportClient client =
clientFactory.createClient(location.getHost(), location.getFetchPort());
client = clientFactory.createClient(location.getHost(), location.getFetchPort());
TransportMessage openStream =
new TransportMessage(
MessageType.OPEN_STREAM,
Expand All @@ -90,7 +94,7 @@ public DfsPartitionReader(
.build()
.toByteArray());
ByteBuffer response = client.sendRpcSync(openStream.toByteBuffer(), fetchTimeoutMs);
TransportMessage.fromByteBuffer(response).getParsedPayload();
streamHandler = TransportMessage.fromByteBuffer(response).getParsedPayload();
// Parse this message to ensure sort is done.
} catch (IOException | InterruptedException e) {
throw new IOException(
Expand Down Expand Up @@ -258,6 +262,21 @@ public void close() {
results.forEach(ReferenceCounted::release);
}
results.clear();
closeStream();
}

private void closeStream() {
if (client != null && client.isActive()) {
TransportMessage bufferStreamEnd =
new TransportMessage(
MessageType.BUFFER_STREAM_END,
PbBufferStreamEnd.newBuilder()
.setStreamType(StreamType.ChunkStream)
.setStreamId(streamHandler.getStreamId())
.build()
.toByteArray());
client.sendRpc(bufferStreamEnd.toByteBuffer());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.StreamType;
import org.apache.celeborn.common.util.FileChannelUtils;
import org.apache.celeborn.common.util.ThreadUtils;

Expand All @@ -63,6 +65,8 @@ public class LocalPartitionReader implements PartitionReader {
private FileChannel shuffleChannel;
private List<Long> chunkOffsets;
private AtomicBoolean pendingFetchTask = new AtomicBoolean(false);
private PbStreamHandler streamHandler;
private TransportClient client;

public LocalPartitionReader(
CelebornConf conf,
Expand All @@ -84,11 +88,9 @@ public LocalPartitionReader(
fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight();
results = new LinkedBlockingQueue<>();
this.location = location;
PbStreamHandler streamHandle;
long fetchTimeoutMs = conf.clientFetchTimeoutMs();
try {
TransportClient client =
clientFactory.createClient(location.getHost(), location.getFetchPort(), 0);
client = clientFactory.createClient(location.getHost(), location.getFetchPort(), 0);
TransportMessage openStreamMsg =
new TransportMessage(
MessageType.OPEN_STREAM,
Expand All @@ -101,7 +103,7 @@ public LocalPartitionReader(
.build()
.toByteArray());
ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs);
streamHandle = TransportMessage.fromByteBuffer(response).getParsedPayload();
streamHandler = TransportMessage.fromByteBuffer(response).getParsedPayload();
} catch (IOException | InterruptedException e) {
throw new IOException(
"Read shuffle file from local file failed, partition location: "
Expand All @@ -111,9 +113,9 @@ public LocalPartitionReader(
e);
}

chunkOffsets = new ArrayList<>(streamHandle.getChunkOffsetsList());
numChunks = streamHandle.getNumChunks();
fullPath = streamHandle.getFullPath();
chunkOffsets = new ArrayList<>(streamHandler.getChunkOffsetsList());
numChunks = streamHandler.getNumChunks();
fullPath = streamHandler.getFullPath();
mapRangeRead = endMapIndex != Integer.MAX_VALUE;

logger.debug(
Expand Down Expand Up @@ -231,6 +233,21 @@ public void close() {
} catch (IOException e) {
logger.warn("Close local shuffle file failed.", e);
}
closeStream();
}

private void closeStream() {
if (client != null && client.isActive()) {
TransportMessage bufferStreamEnd =
new TransportMessage(
MessageType.BUFFER_STREAM_END,
PbBufferStreamEnd.newBuilder()
.setStreamType(StreamType.ChunkStream)
.setStreamId(streamHandler.getStreamId())
.build()
.toByteArray());
client.sendRpc(bufferStreamEnd.toByteBuffer());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.StreamType;
import org.apache.celeborn.common.util.ExceptionUtils;

public class WorkerPartitionReader implements PartitionReader {
private final Logger logger = LoggerFactory.getLogger(WorkerPartitionReader.class);
private PartitionLocation location;
private final TransportClientFactory clientFactory;
private PbStreamHandler streamHandle;
private PbStreamHandler streamHandler;
private TransportClient client;

private int returnedChunks;
private int chunkIndex;
Expand Down Expand Up @@ -100,7 +103,6 @@ public void onFailure(int chunkIndex, Throwable e) {
exception.set(new CelebornIOException(errorMsg, e));
}
};
TransportClient client = null;
try {
client = clientFactory.createClient(location.getHost(), location.getFetchPort());
} catch (InterruptedException ie) {
Expand All @@ -119,7 +121,7 @@ public void onFailure(int chunkIndex, Throwable e) {
.build()
.toByteArray());
ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs);
streamHandle = TransportMessage.fromByteBuffer(response).getParsedPayload();
streamHandler = TransportMessage.fromByteBuffer(response).getParsedPayload();

this.location = location;
this.clientFactory = clientFactory;
Expand All @@ -130,12 +132,12 @@ public void onFailure(int chunkIndex, Throwable e) {
}

public boolean hasNext() {
return returnedChunks < streamHandle.getNumChunks();
return returnedChunks < streamHandler.getNumChunks();
}

public ByteBuf next() throws IOException, InterruptedException {
checkException();
if (chunkIndex < streamHandle.getNumChunks()) {
if (chunkIndex < streamHandler.getNumChunks()) {
fetchChunks();
}
ByteBuf chunk = null;
Expand All @@ -160,6 +162,21 @@ public void close() {
results.forEach(ReferenceCounted::release);
}
results.clear();
closeStream();
}

private void closeStream() {
if (client != null && client.isActive()) {
TransportMessage bufferStreamEnd =
new TransportMessage(
MessageType.BUFFER_STREAM_END,
PbBufferStreamEnd.newBuilder()
.setStreamType(StreamType.ChunkStream)
.setStreamId(streamHandler.getStreamId())
.build()
.toByteArray());
client.sendRpc(bufferStreamEnd.toByteBuffer());
}
}

@Override
Expand All @@ -171,27 +188,28 @@ private void fetchChunks() throws IOException, InterruptedException {
final int inFlight = chunkIndex - returnedChunks;
if (inFlight < fetchMaxReqsInFlight) {
final int toFetch =
Math.min(fetchMaxReqsInFlight - inFlight + 1, streamHandle.getNumChunks() - chunkIndex);
Math.min(fetchMaxReqsInFlight - inFlight + 1, streamHandler.getNumChunks() - chunkIndex);
for (int i = 0; i < toFetch; i++) {
if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) {
callback.onFailure(chunkIndex, new CelebornIOException("Test fetch chunk failure"));
} else {
try {
TransportClient client =
clientFactory.createClient(location.getHost(), location.getFetchPort());
client.fetchChunk(streamHandle.getStreamId(), chunkIndex, fetchTimeoutMs, callback);
chunkIndex++;
} catch (IOException e) {
logger.error(
"fetchChunk for streamId: {}, chunkIndex: {} failed.",
streamHandle.getStreamId(),
chunkIndex,
e);
ExceptionUtils.wrapAndThrowIOException(e);
} catch (InterruptedException e) {
logger.error("PartitionReader thread interrupted while fetching chunks.");
throw e;
if (!client.isActive()) {
try {
client = clientFactory.createClient(location.getHost(), location.getFetchPort());
} catch (IOException e) {
logger.error(
"fetchChunk for streamId: {}, chunkIndex: {} failed.",
streamHandler.getStreamId(),
chunkIndex,
e);
ExceptionUtils.wrapAndThrowIOException(e);
} catch (InterruptedException e) {
logger.error("PartitionReader thread interrupted while fetching chunks.");
throw e;
}
}
client.fetchChunk(streamHandler.getStreamId(), chunkIndex, fetchTimeoutMs, callback);
chunkIndex++;
}
}
}
Expand Down
41 changes: 41 additions & 0 deletions common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.lang3.StringUtils;
Expand All @@ -40,6 +43,15 @@ public class FileInfo {
private final PartitionType partitionType;
private final UserIdentifier userIdentifier;

/**
* A flag used to indicate whether this FileInfo is sorted or not. Currently, it is only set for
* unsorted FileInfo instances.
*/
private final AtomicBoolean sorted = new AtomicBoolean(false);

/** The set of stream IDs that are fetching this FileInfo. */
private final Set<Long> streams = ConcurrentHashMap.newKeySet();

// members for ReducePartition
private final List<Long> chunkOffsets;

Expand Down Expand Up @@ -265,4 +277,33 @@ public boolean isPartitionSplitEnabled() {
public void setPartitionSplitEnabled(boolean partitionSplitEnabled) {
this.partitionSplitEnabled = partitionSplitEnabled;
}

public void setSorted() {
synchronized (sorted) {
sorted.set(true);
}
}

public boolean addStream(long streamId) {
synchronized (sorted) {
if (sorted.get()) {
return false;
} else {
streams.add(streamId);
return true;
}
}
}

public void closeStream(long streamId) {
synchronized (sorted) {
streams.remove(streamId);
}
}

public boolean isStreamsEmpty() {
synchronized (sorted) {
return streams.isEmpty();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,18 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
}

/**
* Sends an opaque message to the RpcHandler on the server-side.
* Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the
* message, and no delivery guarantees are made.
*
* @param message The message to send.
* @return The RPC's id.
*/
public long sendRpc(ByteBuffer message) {
public void sendRpc(ByteBuffer message) {
if (logger.isTraceEnabled()) {
logger.trace("Sending RPC to {}", NettyUtils.getRemoteAddress(channel));
}

long requestId = requestId();
channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)));
return requestId;
}

public ChannelFuture pushData(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.apache.celeborn.common.protocol.PbBufferStreamEnd;

@Deprecated
public class BufferStreamEnd extends RequestMessage {
private long streamId;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,7 @@

package org.apache.celeborn.common.network.protocol;

import static org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.PUSH_DATA_HAND_SHAKE_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.READ_ADD_CREDIT_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.REGION_FINISH_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.REGION_START_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.*;

import java.io.Serializable;
import java.nio.ByteBuffer;
Expand Down
Loading

0 comments on commit c4135dc

Please sign in to comment.