diff --git a/CMakeLists.txt b/CMakeLists.txt index 2cb8ad90d..4f1b470b8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -450,6 +450,7 @@ if(NANOARROW_BUILD_TESTS) endforeach() target_link_libraries(nanoarrow_ipc_files_test nlohmann_json ZLIB::ZLIB) + target_link_libraries(nanoarrow_ipc_decoder_test gmock_main) endif() if(NANOARROW_DEVICE) diff --git a/ci/scripts/bundle.py b/ci/scripts/bundle.py index 7ee09437f..c80172939 100644 --- a/ci/scripts/bundle.py +++ b/ci/scripts/bundle.py @@ -205,6 +205,7 @@ def bundle_nanoarrow_ipc( src_dir / "ipc" / "decoder.c", src_dir / "ipc" / "encoder.c", src_dir / "ipc" / "reader.c", + src_dir / "ipc" / "encoder.c", ] ) nanoarrow_ipc_c = nanoarrow_ipc_c.replace( diff --git a/src/nanoarrow/ipc/decoder.c b/src/nanoarrow/ipc/decoder.c index 02d345412..723ec03de 100644 --- a/src/nanoarrow/ipc/decoder.c +++ b/src/nanoarrow/ipc/decoder.c @@ -1654,9 +1654,16 @@ static ArrowErrorCode ArrowIpcDecoderDecodeArrayViewInternal( return EINVAL; } + // RecordBatch messages don't count the root node but decoder->fields does + // (decoder->fields[0] is the root field) + if (field_i + 1 >= private_data->n_fields) { + ArrowErrorSet(error, "cannot decode column %" PRId64 "; there are only %" PRId64, + field_i, private_data->n_fields - 1); + return EINVAL; + } + ns(RecordBatch_table_t) batch = (ns(RecordBatch_table_t))private_data->last_message; - // RecordBatch messages don't count the root node but decoder->fields does struct ArrowIpcField* root = private_data->fields + field_i + 1; struct ArrowIpcArraySetter setter; diff --git a/src/nanoarrow/ipc/decoder_test.cc b/src/nanoarrow/ipc/decoder_test.cc index 45f4f53b4..73c15f017 100644 --- a/src/nanoarrow/ipc/decoder_test.cc +++ b/src/nanoarrow/ipc/decoder_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include // For bswap32() @@ -763,6 +764,132 @@ TEST_P(ArrowTypeParameterizedTestFixture, NanoarrowIpcArrowArrayRoundtrip) { ArrowIpcDecoderReset(&decoder); } +struct ArrowArrayViewEqualTo { + const struct ArrowArrayView* expected; + + using is_gtest_matcher = void; + + bool MatchAndExplain(const struct ArrowArrayView* actual, std::ostream* os) const { + return MatchAndExplain({}, actual, expected, os); + } + + static bool MatchAndExplain(std::vector field_path, + const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, std::ostream* os) { + auto prefixed = [&]() -> std::ostream& { + if (!field_path.empty()) { + for (int i : field_path) { + *os << "." << i; + } + *os << ":"; + } + return *os; + }; + + NANOARROW_DCHECK(actual->offset == 0); + NANOARROW_DCHECK(expected->offset == 0); + + if (actual->length != expected->length) { + prefixed() << "expected length=" << expected->length << "\n"; + prefixed() << " actual length=" << actual->length << "\n"; + return false; + } + + auto null_count = [](const struct ArrowArrayView* a) { + return a->null_count != -1 ? a->null_count : ArrowArrayViewComputeNullCount(a); + }; + if (null_count(actual) != null_count(expected)) { + prefixed() << "expected null_count=" << null_count(expected) << "\n"; + prefixed() << " actual null_count=" << null_count(actual) << "\n"; + return false; + } + + for (int64_t i = 0; actual->layout.buffer_type[i] != NANOARROW_BUFFER_TYPE_NONE && + i < NANOARROW_MAX_FIXED_BUFFERS; + ++i) { + auto a_buf = actual->buffer_views[i]; + auto e_buf = expected->buffer_views[i]; + if (a_buf.size_bytes != e_buf.size_bytes) { + prefixed() << "expected buffer[" << i << "].size=" << e_buf.size_bytes << "\n"; + prefixed() << " actual buffer[" << i << "].size=" << a_buf.size_bytes << "\n"; + return false; + } + if (memcmp(a_buf.data.data, e_buf.data.data, a_buf.size_bytes) != 0) { + prefixed() << "expected buffer[" << i << "]'s data to match\n"; + return false; + } + } + + field_path.push_back(0); + for (int64_t i = 0; i < actual->n_children; ++i) { + field_path.back() = i; + if (!MatchAndExplain(field_path, actual->children[i], expected->children[i], os)) { + return false; + } + } + return true; + } + + void DescribeTo(std::ostream* os) const { *os << "is equivalent to the array view"; } + void DescribeNegationTo(std::ostream* os) const { + *os << "is not equivalent to the array view"; + } +}; + +TEST_P(ArrowTypeParameterizedTestFixture, NanoarrowIpcNanoarrowArrayRoundtrip) { + struct ArrowError error; + nanoarrow::UniqueSchema schema; + ASSERT_TRUE( + arrow::ExportSchema(arrow::Schema({arrow::field("", GetParam())}), schema.get()) + .ok()); + + // now make one empty struct array with this schema and another with all zeroes + nanoarrow::UniqueArray empty_array, zero_array; + for (auto* array : {empty_array.get(), zero_array.get()}) { + ASSERT_EQ(ArrowArrayInitFromSchema(array, schema.get(), nullptr), NANOARROW_OK); + ASSERT_EQ(ArrowArrayStartAppending(array), NANOARROW_OK); + if (array == zero_array.get()) { + ASSERT_EQ(ArrowArrayAppendEmpty(array, 5), NANOARROW_OK); + } + ASSERT_EQ(ArrowArrayFinishBuildingDefault(array, nullptr), NANOARROW_OK); + + nanoarrow::UniqueArrayView array_view; + ASSERT_EQ(ArrowArrayViewInitFromSchema(array_view.get(), schema.get(), &error), + NANOARROW_OK); + ASSERT_EQ(ArrowArrayViewSetArray(array_view.get(), array, &error), NANOARROW_OK) + << error.message; + + nanoarrow::ipc::UniqueEncoder encoder; + EXPECT_EQ(ArrowIpcEncoderInit(encoder.get()), NANOARROW_OK); + + nanoarrow::UniqueBuffer buffer, body_buffer; + ArrowIpcEncoderBuildContiguousBodyBuffer(encoder.get(), body_buffer.get()); + EXPECT_EQ(ArrowIpcEncoderEncodeRecordBatch(encoder.get(), array_view.get(), &error), + NANOARROW_OK) + << error.message; + EXPECT_EQ( + ArrowIpcEncoderFinalizeBuffer(encoder.get(), /*encapsulate=*/true, buffer.get()), + NANOARROW_OK); + + nanoarrow::ipc::UniqueDecoder decoder; + ArrowIpcDecoderInit(decoder.get()); + EXPECT_EQ(ArrowIpcDecoderSetSchema(decoder.get(), schema.get(), &error), NANOARROW_OK) + << error.message; + EXPECT_EQ(ArrowIpcDecoderDecodeHeader(decoder.get(), + {buffer->data, buffer->size_bytes}, &error), + NANOARROW_OK) + << error.message; + + struct ArrowArrayView* roundtripped; + ASSERT_EQ(ArrowIpcDecoderDecodeArrayView(decoder.get(), + {body_buffer->data, body_buffer->size_bytes}, + -1, &roundtripped, nullptr), + NANOARROW_OK); + + EXPECT_THAT(roundtripped, ArrowArrayViewEqualTo{array_view.get()}); + } +} + INSTANTIATE_TEST_SUITE_P( NanoarrowIpcTest, ArrowTypeParameterizedTestFixture, ::testing::Values( diff --git a/src/nanoarrow/ipc/encoder.c b/src/nanoarrow/ipc/encoder.c index 6813f61fa..eeb303a15 100644 --- a/src/nanoarrow/ipc/encoder.c +++ b/src/nanoarrow/ipc/encoder.c @@ -52,6 +52,9 @@ ArrowErrorCode ArrowIpcEncoderInit(struct ArrowIpcEncoder* encoder) { encoder->private_data = ArrowMalloc(sizeof(struct ArrowIpcEncoderPrivate)); struct ArrowIpcEncoderPrivate* private = (struct ArrowIpcEncoderPrivate*)encoder->private_data; + if (private == NULL) { + return ENOMEM; + } if (flatcc_builder_init(&private->builder) == -1) { ArrowFree(private); return ESPIPE; @@ -65,10 +68,12 @@ void ArrowIpcEncoderReset(struct ArrowIpcEncoder* encoder) { NANOARROW_DCHECK(encoder != NULL && encoder->private_data != NULL); struct ArrowIpcEncoderPrivate* private = (struct ArrowIpcEncoderPrivate*)encoder->private_data; - flatcc_builder_clear(&private->builder); - ArrowBufferReset(&private->nodes); - ArrowBufferReset(&private->buffers); - ArrowFree(private); + if (private != NULL) { + flatcc_builder_clear(&private->builder); + ArrowBufferReset(&private->nodes); + ArrowBufferReset(&private->buffers); + ArrowFree(private); + } memset(encoder, 0, sizeof(struct ArrowIpcEncoder)); } @@ -422,3 +427,139 @@ ArrowErrorCode ArrowIpcEncoderEncodeSchema(struct ArrowIpcEncoder* encoder, FLATCC_RETURN_UNLESS_0(Message_bodyLength_add(builder, 0)); return ns(Message_end_as_root(builder)) ? NANOARROW_OK : ENOMEM; } + +static ArrowErrorCode ArrowIpcEncoderBuildContiguousBodyBufferCallback( + struct ArrowBufferView buffer_view, struct ArrowIpcEncoder* encoder, int64_t* offset, + int64_t* length, struct ArrowError* error) { + struct ArrowIpcEncoderPrivate* private = + (struct ArrowIpcEncoderPrivate*)encoder->private_data; + struct ArrowBuffer* body_buffer = (struct ArrowBuffer*)encoder->encode_buffer_state; + + int compressed_buffer_header = + encoder->codec != NANOARROW_IPC_COMPRESSION_TYPE_NONE ? sizeof(int64_t) : 0; + int64_t old_size = body_buffer->size_bytes; + int64_t buffer_begin = _ArrowRoundUpToMultipleOf8(old_size); + int64_t buffer_end = buffer_begin + compressed_buffer_header + buffer_view.size_bytes; + int64_t new_size = _ArrowRoundUpToMultipleOf8(buffer_end); + + // reserve all the memory we'll need now + NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(body_buffer, new_size - old_size)); + + // zero padding up to the start of the buffer + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFill(body_buffer, 0, buffer_begin - old_size)); + + // store offset and length of the buffer + *offset = buffer_begin; + *length = buffer_view.size_bytes; + + if (compressed_buffer_header) { + // Signal that the buffer is not compressed; eventually we will set this to the + // decompressed length of an actually compressed buffer. + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt64(body_buffer, -1)); + } + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(body_buffer, buffer_view.data.data, buffer_view.size_bytes)); + + // zero padding after writing the buffer + NANOARROW_DCHECK(body_buffer->size_bytes == buffer_end); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFill(body_buffer, 0, new_size - buffer_end)); + + encoder->body_length = body_buffer->size_bytes; + return NANOARROW_OK; +} + +void ArrowIpcEncoderBuildContiguousBodyBuffer(struct ArrowIpcEncoder* encoder, + struct ArrowBuffer* body_buffer) { + NANOARROW_DCHECK(encoder != NULL && encoder->private_data != NULL && + body_buffer != NULL); + struct ArrowIpcEncoderPrivate* private = + (struct ArrowIpcEncoderPrivate*)encoder->private_data; + encoder->encode_buffer = &ArrowIpcEncoderBuildContiguousBodyBufferCallback; + encoder->encode_buffer_state = body_buffer; +} + +static ArrowErrorCode ArrowIpcEncoderEncodeRecordBatchImpl( + struct ArrowIpcEncoder* encoder, const struct ArrowArrayView* array_view, + struct ArrowBuffer* buffers, struct ArrowBuffer* nodes, struct ArrowError* error) { + if (array_view->offset != 0) { + ArrowErrorSet(error, "Cannot encode arrays with nonzero offset"); + return ENOTSUP; + } + + for (int64_t c = 0; c < array_view->n_children; ++c) { + const struct ArrowArrayView* child = array_view->children[c]; + + struct ns(FieldNode) node = {child->length, child->null_count}; + NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(nodes, &node, sizeof(node))); + + for (int64_t b = 0; b < child->array->n_buffers; ++b) { + struct ns(Buffer) buffer; + NANOARROW_RETURN_NOT_OK(encoder->encode_buffer( + child->buffer_views[b], encoder, &buffer.offset, &buffer.length, error)); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(buffers, &buffer, sizeof(buffer))); + } + + NANOARROW_RETURN_NOT_OK( + ArrowIpcEncoderEncodeRecordBatchImpl(encoder, child, buffers, nodes, error)); + } + return NANOARROW_OK; +} + +ArrowErrorCode ArrowIpcEncoderEncodeRecordBatch(struct ArrowIpcEncoder* encoder, + const struct ArrowArrayView* array_view, + struct ArrowError* error) { + NANOARROW_DCHECK(encoder != NULL && encoder->private_data != NULL && schema != NULL); + + if (array_view->null_count != 0 && ArrowArrayViewComputeNullCount(array_view) != 0) { + ArrowErrorSet(error, + "RecordBatches cannot be constructed from arrays with top level nulls"); + return EINVAL; + } + + if (array_view->storage_type != NANOARROW_TYPE_STRUCT) { + ArrowErrorSet( + error, + "RecordBatches cannot be constructed from arrays of type other than struct"); + return EINVAL; + } + + if (!encoder->encode_buffer) { + ArrowErrorSet(error, "No encode_buffer behavior provided when encoding RecordBatch"); + return EINVAL; + } + + struct ArrowIpcEncoderPrivate* private = + (struct ArrowIpcEncoderPrivate*)encoder->private_data; + + flatcc_builder_t* builder = &private->builder; + + FLATCC_RETURN_UNLESS_0(Message_start_as_root(builder)); + FLATCC_RETURN_UNLESS_0(Message_version_add(builder, ns(MetadataVersion_V5))); + + encoder->body_length = 0; + + FLATCC_RETURN_UNLESS_0(Message_header_RecordBatch_start(builder)); + if (encoder->codec != NANOARROW_IPC_COMPRESSION_TYPE_NONE) { + FLATCC_RETURN_UNLESS_0(RecordBatch_compression_start(builder)); + FLATCC_RETURN_UNLESS_0(BodyCompression_codec_add(builder, encoder->codec)); + FLATCC_RETURN_UNLESS_0(RecordBatch_compression_end(builder)); + } + FLATCC_RETURN_UNLESS_0(RecordBatch_length_add(builder, array_view->length)); + + ArrowBufferResize(&private->buffers, 0, 0); + ArrowBufferResize(&private->nodes, 0, 0); + NANOARROW_RETURN_NOT_OK(ArrowIpcEncoderEncodeRecordBatchImpl( + encoder, array_view, &private->buffers, &private->nodes, error)); + + FLATCC_RETURN_UNLESS_0(RecordBatch_nodes_create( // + builder, (struct ns(FieldNode)*)private->nodes.data, + private->nodes.size_bytes / sizeof(struct ns(FieldNode)))); + FLATCC_RETURN_UNLESS_0(RecordBatch_buffers_create( // + builder, (struct ns(Buffer)*)private->buffers.data, + private->buffers.size_bytes / sizeof(struct ns(Buffer)))); + + FLATCC_RETURN_UNLESS_0(Message_header_RecordBatch_end(builder)); + + FLATCC_RETURN_UNLESS_0(Message_bodyLength_add(builder, encoder->body_length)); + return ns(Message_end_as_root(builder)) ? NANOARROW_OK : ENOMEM; +} diff --git a/src/nanoarrow/ipc/ipc_hpp_test.cc b/src/nanoarrow/ipc/ipc_hpp_test.cc index ec3af84ba..10a41b94c 100644 --- a/src/nanoarrow/ipc/ipc_hpp_test.cc +++ b/src/nanoarrow/ipc/ipc_hpp_test.cc @@ -31,6 +31,18 @@ TEST(NanoarrowIpcHppTest, NanoarrowIpcHppTestUniqueDecoder) { EXPECT_EQ(decoder->private_data, nullptr); } +TEST(NanoarrowIpcHppTest, NanoarrowIpcHppTestUniqueEncoder) { + nanoarrow::ipc::UniqueEncoder encoder; + + EXPECT_EQ(encoder->private_data, nullptr); + ASSERT_EQ(ArrowIpcEncoderInit(encoder.get()), NANOARROW_OK); + EXPECT_NE(encoder->private_data, nullptr); + + nanoarrow::ipc::UniqueEncoder encoder2 = std::move(encoder); + EXPECT_NE(encoder2->private_data, nullptr); + EXPECT_EQ(encoder->private_data, nullptr); +} + TEST(NanoarrowIpcHppTest, NanoarrowIpcHppTestUniqueInputStream) { nanoarrow::ipc::UniqueInputStream input; nanoarrow::UniqueBuffer buf; diff --git a/src/nanoarrow/nanoarrow_ipc.h b/src/nanoarrow/nanoarrow_ipc.h index c74e288a2..e0426c604 100644 --- a/src/nanoarrow/nanoarrow_ipc.h +++ b/src/nanoarrow/nanoarrow_ipc.h @@ -63,6 +63,10 @@ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowIpcEncoderFinalizeBuffer) #define ArrowIpcEncoderEncodeSchema \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowIpcEncoderEncodeSchema) +#define ArrowIpcEncoderBuildContiguousBodyBuffer \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowIpcEncoderBuildContiguousBodyBuffer) +#define ArrowIpcEncoderEncodeRecordBatch \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowIpcEncoderEncodeRecordBatch) #endif @@ -208,7 +212,7 @@ void ArrowIpcDecoderReset(struct ArrowIpcDecoder* decoder); /// \brief Peek at a message header /// -/// The first 8 bytes of an Arrow IPC message are 0xFFFFFF followed by the size +/// The first 8 bytes of an Arrow IPC message are 0xFFFFFFFF followed by the size /// of the header as a little-endian 32-bit integer. ArrowIpcDecoderPeekHeader() reads /// these bytes and returns ESPIPE if there are not enough remaining bytes in data to read /// the entire header message, EINVAL if the first 8 bytes are not valid, ENODATA if the @@ -453,7 +457,21 @@ ArrowErrorCode ArrowIpcEncoderFinalizeBuffer(struct ArrowIpcEncoder* encoder, ArrowErrorCode ArrowIpcEncoderEncodeSchema(struct ArrowIpcEncoder* encoder, const struct ArrowSchema* schema, struct ArrowError* error); - +/// \brief Set the encoder to concatenate encoded buffers into body_buffer +/// +/// encoder->encode_buffer_state will point to the provided ArrowBuffer. +/// The contiguous body buffer will be appended to this during +/// ArrowIpcEncoderEncodeRecordBatch. +void ArrowIpcEncoderBuildContiguousBodyBuffer(struct ArrowIpcEncoder* encoder, + struct ArrowBuffer* body_buffer); + +/// \brief Encode a struct typed ArrayView to a flatbuffer RecordBatch, embedded in a +/// Message. +/// +/// Returns ENOMEM if allocation fails, NANOARROW_OK otherwise. +ArrowErrorCode ArrowIpcEncoderEncodeRecordBatch(struct ArrowIpcEncoder* encoder, + const struct ArrowArrayView* array_view, + struct ArrowError* error); /// @} #ifdef __cplusplus