[SPARK-44657][CONNECT] Fix incorrect limit handling in ArrowBatchWithSchemaIterator and config parsing of CONNECT_GRPC_ARROW_MAX_BATCH_SIZE

### What changes were proposed in this pull request?

Fixes the limit checking of `maxEstimatedBatchSize` and `maxRecordsPerBatch` so that the more restrictive of the two limits is respected, and fixes the config parsing of `CONNECT_GRPC_ARROW_MAX_BATCH_SIZE` by converting the value to bytes.

### Why are the changes needed?

Bugfix.
In the arrow writer [code](https://github.com/apache/spark/blob/6161bf44f40f8146ea4c115c788fd4eaeb128769/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala#L154-L163), the loop condition does not match what the in-code comment claims ("For maxBatchSize and maxRecordsPerBatch, respect whatever smaller"): because the checks are chained with the `||` operator, rows keep being appended until both limits are exceeded, so the conf that is "larger" (i.e. less restrictive) is the one that is effectively respected. See the sketch below.
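
For illustration, here is a minimal, self-contained sketch of the pre-fix condition with hypothetical limits (3 records vs. 100 bytes, 10 bytes per row; none of these numbers come from the patch):

```scala
// Pre-fix condition: the || chain only becomes false once BOTH limits are exceeded,
// so the less restrictive limit (10 rows' worth of bytes) effectively wins.
object OldConditionSketch extends App {
  val maxEstimatedBatchSize = 100L // bytes
  val maxRecordsPerBatch = 3       // rows
  val rowSize = 10L                // assume every row is estimated at 10 bytes

  var estimatedBatchSize = 0L
  var rowCountInLastBatch = 0

  while (
    rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
      estimatedBatchSize <= 0 ||
      estimatedBatchSize < maxEstimatedBatchSize || // still under the size limit, OR ...
      maxRecordsPerBatch <= 0 ||
      rowCountInLastBatch < maxRecordsPerBatch      // ... still under the record limit
  ) {
    estimatedBatchSize += rowSize
    rowCountInLastBatch += 1
  }

  println(rowCountInLastBatch) // prints 10, not 3: the record limit is ignored
}
```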

Furthermore, when the `CONNECT_GRPC_ARROW_MAX_BATCH_SIZE` conf is read, the value is not converted from MiB to bytes ([example](https://github.com/apache/spark/blob/3e5203c64c06cc8a8560dfa0fb6f52e74589b583/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala#L103)).
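
A rough sketch of the unit mismatch (hypothetical, simplified; it assumes the usual `ConfigBuilder` semantics where `bytesConf(unit)` returns the value expressed in `unit`):

```scala
// With .bytesConf(ByteUnit.MiB).createWithDefaultString("4m"), reading the conf yields 4 (MiB).
// The reader side then used that number directly as a byte budget for the Arrow batches.
val maxBatchSizeFromConf = 4L                                               // intended: 4 MiB per batch
val effectiveLimitBytes = (maxBatchSizeFromConf * 0.7).toLong               // ~2 bytes: far too small
val intendedLimitBytes = (maxBatchSizeFromConf * 1024 * 1024 * 0.7).toLong  // ~2.9 MB
```

Declaring the conf in `ByteUnit.BYTE` (as done in the diff below) makes the value returned to callers already be a byte count.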

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes apache#42321 from vicennial/SPARK-44657.

Authored-by: vicennial <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
vicennial authored and HyukjinKwon committed Aug 8, 2023
1 parent f7879b4 commit f9d417f
Showing 3 changed files with 79 additions and 13 deletions.
@@ -49,12 +49,12 @@ object Connect {
   val CONNECT_GRPC_ARROW_MAX_BATCH_SIZE =
     ConfigBuilder("spark.connect.grpc.arrow.maxBatchSize")
       .doc(
-        "When using Apache Arrow, limit the maximum size of one arrow batch that " +
-          "can be sent from server side to client side. Currently, we conservatively use 70% " +
-          "of it because the size is not accurate but estimated.")
+        "When using Apache Arrow, limit the maximum size of one arrow batch, in bytes unless " +
+          "otherwise specified, that can be sent from server side to client side. Currently, we " +
+          "conservatively use 70% of it because the size is not accurate but estimated.")
       .version("3.4.0")
-      .bytesConf(ByteUnit.MiB)
-      .createWithDefaultString("4m")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(4 * 1024 * 1024)

   val CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE =
     ConfigBuilder("spark.connect.grpc.maxInboundMessageSize")
@@ -238,6 +238,67 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
     }
   }

+  test("SPARK-44657: Arrow batches respect max batch size limit") {
+    // Set 10 KiB as the batch size limit
+    val batchSize = 10 * 1024
+    withSparkConf("spark.connect.grpc.arrow.maxBatchSize" -> batchSize.toString) {
+      // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
+      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
+      val instance = new SparkConnectService(false)
+      val connect = new MockRemoteSession()
+      val context = proto.UserContext
+        .newBuilder()
+        .setUserId("c1")
+        .build()
+      val plan = proto.Plan
+        .newBuilder()
+        .setRoot(connect.sql("select * from range(0, 15000, 1, 1)"))
+        .build()
+      val request = proto.ExecutePlanRequest
+        .newBuilder()
+        .setPlan(plan)
+        .setUserContext(context)
+        .setSessionId(UUID.randomUUID.toString())
+        .build()
+
+      // Execute plan.
+      @volatile var done = false
+      val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+      instance.executePlan(
+        request,
+        new StreamObserver[proto.ExecutePlanResponse] {
+          override def onNext(v: proto.ExecutePlanResponse): Unit = {
+            responses += v
+          }
+
+          override def onError(throwable: Throwable): Unit = {
+            throw throwable
+          }
+
+          override def onCompleted(): Unit = {
+            done = true
+          }
+        })
+      // The current implementation is expected to be blocking. This is here to make sure it is.
+      assert(done)
+
+      // 1 schema + 1 metric + at least 2 data batches
+      assert(responses.size > 3)
+
+      val allocator = new RootAllocator()
+
+      // Check the 'data' batches
+      responses.tail.dropRight(1).foreach { response =>
+        assert(response.hasArrowBatch)
+        val batch = response.getArrowBatch
+        assert(batch.getData != null)
+        // Batch size must be <= 70% since we intentionally use this multiplier for the size
+        // estimator.
+        assert(batch.getData.size() <= batchSize * 0.7)
+      }
+    }
+  }
+
   gridTest("SPARK-43923: commands send events")(
     Seq(
       proto.Command
@@ -150,17 +150,22 @@ private[sql] object ArrowConverters extends Logging {
         // Always write the schema.
         MessageSerializer.serialize(writeChannel, arrowSchema)

+        def isBatchSizeLimitExceeded: Boolean = {
+          // If `maxEstimatedBatchSize` is zero or negative, it implies unlimited.
+          maxEstimatedBatchSize > 0 && estimatedBatchSize >= maxEstimatedBatchSize
+        }
+        def isRecordLimitExceeded: Boolean = {
+          // If `maxRecordsPerBatch` is zero or negative, it implies unlimited.
+          maxRecordsPerBatch > 0 && rowCountInLastBatch >= maxRecordsPerBatch
+        }
         // Always write the first row.
         while (rowIter.hasNext && (
-            // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller.
-            // If the size in bytes is positive (set properly), always write the first row.
-            rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
-              // If the size in bytes of rows are 0 or negative, unlimit it.
-              estimatedBatchSize <= 0 ||
-              estimatedBatchSize < maxEstimatedBatchSize ||
-              // If the size of rows are 0 or negative, unlimit it.
-              maxRecordsPerBatch <= 0 ||
-              rowCountInLastBatch < maxRecordsPerBatch)) {
+            (rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0) ||
+              // If either limit is hit, create a batch. This implies that the limit that is hit first
+              // triggers the creation of a batch even if the other limit is not yet hit, hence
+              // preferring the more restrictive limit.
+              (!isBatchSizeLimitExceeded && !isRecordLimitExceeded))) {
           val row = rowIter.next()
           arrowWriter.write(row)
           estimatedBatchSize += (row match {
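To make the new loop condition above concrete, here is a small standalone sketch mirroring the earlier example (same hypothetical limits: 3 records vs. 100 bytes, 10 bytes per row; `rowIter` omitted):

```scala
// Fixed condition: the loop stops as soon as EITHER limit is exceeded,
// so the more restrictive limit wins.
object NewConditionSketch extends App {
  val maxEstimatedBatchSize = 100L // bytes; <= 0 would mean unlimited
  val maxRecordsPerBatch = 3       // rows;  <= 0 would mean unlimited
  val rowSize = 10L                // assume every row is estimated at 10 bytes

  var estimatedBatchSize = 0L
  var rowCountInLastBatch = 0

  def isBatchSizeLimitExceeded: Boolean =
    maxEstimatedBatchSize > 0 && estimatedBatchSize >= maxEstimatedBatchSize
  def isRecordLimitExceeded: Boolean =
    maxRecordsPerBatch > 0 && rowCountInLastBatch >= maxRecordsPerBatch

  while ((rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0) ||
      (!isBatchSizeLimitExceeded && !isRecordLimitExceeded)) {
    estimatedBatchSize += rowSize
    rowCountInLastBatch += 1
  }

  println(rowCountInLastBatch) // prints 3: the record limit, the more restrictive one, wins
}
```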
