From 0f42aed96ca0dcc5031e01b04b297564fd97647c Mon Sep 17 00:00:00 2001
From: RahulSudarMCW
Date: Tue, 24 Dec 2024 10:58:01 +0530
Subject: [PATCH] Add EvenSplitN subgraph

- Added a single even-split test covering the new EvenSplitN API; the
  even-split2, even-split3, and even-split4 tests remain to cover the
  deprecated entry points.
- Marked the old functions as XNN_DEPRECATED.
- Added shims in deprecated.c that forward the deprecated functions to
  the new xnn_define_even_split API.
- Ensured the subgraph API remains stable with the new implementation.
---
 CMakeLists.txt               |   1 +
 include/xnnpack.h            |  31 +-
 src/runtime.c                |   1 +
 src/subgraph.c               |   1 +
 src/subgraph/deprecated.c    |  20 +
 src/subgraph/even-split.c    | 205 ++--------
 src/xnnpack/node-type-defs.h |   1 +
 test/BUILD.bazel             |   5 +-
 test/even-split.cc           | 719 +++++++++++++++++++++++++++++++++++
 test/fusion.cc               |   2 +-
 test/subgraph-tester.h       |  18 +-
 11 files changed, 814 insertions(+), 190 deletions(-)
 create mode 100644 test/even-split.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a9fad59a09..9e0a40bb616 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1349,6 +1349,7 @@ IF(XNNPACK_BUILD_TESTS)
+      even-split
       even-split2
       even-split3
       even-split4
       fully-connected
       global-average-pooling-1d
       global-average-pooling-2d
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 4e72df94072..d51e261a5f4 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -1638,6 +1638,32 @@ enum xnn_status xnn_define_copy(
   uint32_t output_id,
   uint32_t flags);
 
+/// Define an N-Output Split Node and add it to a Subgraph.
+///
+/// The N-Output Split Node evenly splits an input tensor into N output tensors along a specified axis.
+///
+/// @param subgraph - a Subgraph object that will own the created Node.
+/// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
+///                    dimensions is added to it.
+/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
+///                   subgraph.
+/// @param num_outputs - the number of output tensors to generate. The input tensor is split evenly into this number
+///                      of output tensors along the `split_dim`. Each output tensor has the same dimensions as the
+///                      input tensor, except for the `split_dim`, which is divided evenly between the outputs.
+/// @param outputs - an array of Value IDs for the output tensors. Each output tensor must be an N-dimensional tensor
+///                  defined in the @a subgraph with the same shape as the input tensor, except along the `split_dim`
+///                  dimension, which is split evenly among the output tensors. The number of output tensors
+///                  corresponds to the value of `num_outputs`.
+/// @param flags - binary features of the Split Node. No supported flags are currently defined.
+enum xnn_status xnn_define_even_split(
+  xnn_subgraph_t subgraph,
+  int32_t split_dim,
+  uint32_t input_id,
+  uint32_t num_outputs,
+  const uint32_t* outputs,
+  uint32_t flags);
+
 /// Define a 2-Output Split Node and add it to a Subgraph.
 ///
 /// The 2-Output Split Node splits an input tensor into two output tensors along a specified axis evenly.
 ///
 /// @param subgraph - a Subgraph object that will own the created Node.
 /// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
@@ -1654,7 +1679,7 @@ enum xnn_status xnn_define_copy(
 ///                     defined in the @a subgraph with each dimension, except the axis, equal to the corresponding
 ///                     dimension of the first output. The split_dim dimension is half of the input's split_dim.
 /// @param flags - binary features of the Split Node. No supported flags are currently defined.
-enum xnn_status xnn_define_even_split2( +XNN_DEPRECATED enum xnn_status xnn_define_even_split2( xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id, @@ -1683,7 +1708,7 @@ enum xnn_status xnn_define_even_split2( /// dimension of the second and third output. The split_dim dimension is one third of the input's /// split_dim. /// @param flags - binary features of the Split Node. No supported flags are currently defined. -enum xnn_status xnn_define_even_split3( +XNN_DEPRECATED enum xnn_status xnn_define_even_split3( xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id, @@ -1717,7 +1742,7 @@ enum xnn_status xnn_define_even_split3( /// dimension of the other output tensors. The split_dim dimension is one fourth of the input's /// split_dim. /// @param flags - binary features of the Split Node. No supported flags are currently defined. -enum xnn_status xnn_define_even_split4( +XNN_DEPRECATED enum xnn_status xnn_define_even_split4( xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id, diff --git a/src/runtime.c b/src/runtime.c index 82a7da704d3..3d03f7f5898 100644 --- a/src/runtime.c +++ b/src/runtime.c @@ -482,6 +482,7 @@ void propagate_rank( case xnn_node_type_even_split2: case xnn_node_type_even_split3: case xnn_node_type_even_split4: + case xnn_node_type_even_split: case xnn_node_type_unary_elementwise: case xnn_node_type_convert: case xnn_node_type_pack_lh: diff --git a/src/subgraph.c b/src/subgraph.c index be3ff741acc..82032c09b01 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -899,6 +899,7 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) case xnn_node_type_even_split2: case xnn_node_type_even_split3: case xnn_node_type_even_split4: + case xnn_node_type_even_split: case xnn_node_type_fully_connected: case xnn_node_type_global_average_pooling_2d: case xnn_node_type_global_sum_pooling_2d: diff --git a/src/subgraph/deprecated.c b/src/subgraph/deprecated.c index 601793b4655..3d5251c4700 100644 --- a/src/subgraph/deprecated.c +++ b/src/subgraph/deprecated.c @@ -335,3 +335,23 @@ enum xnn_status xnn_define_tanh(xnn_subgraph_t subgraph, uint32_t input_id, return xnn_define_unary(subgraph, xnn_unary_tanh, NULL, input_id, output_id, flags); } + +enum xnn_status xnn_define_even_split2(xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id, + uint32_t output1_id, uint32_t output2_id, uint32_t flags) { + const uint32_t outputs_id[2] = {output1_id, output2_id}; + return xnn_define_even_split(subgraph, split_dim, input_id, /*num_outputs=*/2, outputs_id, flags); +} + +enum xnn_status xnn_define_even_split3(xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id, + uint32_t output1_id, uint32_t output2_id, + uint32_t output3_id, uint32_t flags) { + const uint32_t outputs_id[3] = {output1_id, output2_id, output3_id}; + return xnn_define_even_split(subgraph, split_dim, input_id, /*num_outputs=*/3, outputs_id, flags); +} + +enum xnn_status xnn_define_even_split4(xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id, + uint32_t output1_id, uint32_t output2_id, + uint32_t output3_id, uint32_t output4_id, uint32_t flags) { + const uint32_t outputs_id[4] = {output1_id, output2_id, output3_id, output4_id}; + return xnn_define_even_split(subgraph, split_dim, input_id, /*num_outputs=*/4, outputs_id, flags); +} diff --git a/src/subgraph/even-split.c b/src/subgraph/even-split.c index 25c3bd44a03..6a578cadfd8 100644 --- a/src/subgraph/even-split.c +++ b/src/subgraph/even-split.c @@ -47,15 +47,15 @@ static enum xnn_status 
create_even_split_operator_helper( } } -static enum xnn_status create_even_split_n_operator( +static enum xnn_status create_even_split_operator( const struct xnn_node* node, const struct xnn_value* values, size_t num_values, struct xnn_operator_data* opdata, struct xnn_code_cache* code_cache, - size_t num_splits, xnn_weights_cache_t weights_cache) { + size_t num_splits = opdata->num_outputs; assert(node->num_inputs == 1); assert(node->num_outputs == num_splits); enum xnn_datatype datatype = values[opdata->inputs[0]].datatype; @@ -77,39 +77,6 @@ static enum xnn_status create_even_split_n_operator( return status; } -static enum xnn_status create_even_split2_operator( - const struct xnn_node* node, - const struct xnn_value* values, - size_t num_values, - struct xnn_operator_data* opdata, - struct xnn_code_cache* code_cache, - xnn_weights_cache_t weights_cache) -{ - return create_even_split_n_operator(node, values, num_values, opdata, code_cache, /*num_splits=*/2, weights_cache); -} - -static enum xnn_status create_even_split3_operator( - const struct xnn_node* node, - const struct xnn_value* values, - size_t num_values, - struct xnn_operator_data* opdata, - struct xnn_code_cache* code_cache, - xnn_weights_cache_t weights_cache) -{ - return create_even_split_n_operator(node, values, num_values, opdata, code_cache, /*num_splits=*/3, weights_cache); -} - -static enum xnn_status create_even_split4_operator( - const struct xnn_node* node, - const struct xnn_value* values, - size_t num_values, - struct xnn_operator_data* opdata, - struct xnn_code_cache* code_cache, - xnn_weights_cache_t weights_cache) -{ - return create_even_split_n_operator(node, values, num_values, opdata, code_cache, /*num_splits=*/4, weights_cache); -} - static enum xnn_status reshape_even_split_operator_helper( const struct xnn_value* values, const uint32_t num_values, @@ -150,21 +117,20 @@ static enum xnn_status reshape_even_split_operator_helper( } } -static enum xnn_status reshape_even_split_n_operator( +static enum xnn_status reshape_even_split_operator( struct xnn_operator_data* opdata, struct xnn_value* values, size_t num_values, - size_t num_splits, pthreadpool_t threadpool) { enum xnn_status status = xnn_status_success; - + assert(opdata->num_inputs == 1); const uint32_t input_id = opdata->inputs[0]; assert(input_id != XNN_INVALID_VALUE_ID); assert(input_id < num_values); const struct xnn_value* input_value = values + input_id; - + int32_t axis = opdata->axis; if (axis < 0) { axis += input_value->shape.num_dims; @@ -178,7 +144,8 @@ static enum xnn_status reshape_even_split_n_operator( return xnn_status_invalid_parameter; } opdata->batch_size = xnn_shape_multiply_leading_dims(&input_value->shape, axis); - + + size_t num_splits = opdata->num_outputs; const size_t axis_elements = input_value->shape.dim[axis] / num_splits; const size_t old_workspace_size = opdata->workspace_size; bool reallocation_required = false; @@ -214,33 +181,6 @@ static enum xnn_status reshape_even_split_n_operator( return status; } -static enum xnn_status reshape_even_split2_operator( - struct xnn_operator_data* opdata, - struct xnn_value* values, - size_t num_values, - pthreadpool_t threadpool) -{ - return reshape_even_split_n_operator(opdata, values, num_values, /*num_splits=*/2, threadpool); -} - -static enum xnn_status reshape_even_split3_operator( - struct xnn_operator_data* opdata, - struct xnn_value* values, - size_t num_values, - pthreadpool_t threadpool) -{ - return reshape_even_split_n_operator(opdata, values, num_values, /*num_splits=*/3, 
threadpool); -} - -static enum xnn_status reshape_even_split4_operator( - struct xnn_operator_data* opdata, - struct xnn_value* values, - size_t num_values, - pthreadpool_t threadpool) -{ - return reshape_even_split_n_operator(opdata, values, num_values, /*num_splits=*/4, threadpool); -} - static enum xnn_status setup_even_split_operator_helper( const struct xnn_value* values, const uint32_t num_values, @@ -283,11 +223,10 @@ static enum xnn_status setup_even_split_operator_helper( } } -static enum xnn_status setup_even_split_n_operator( +static enum xnn_status setup_even_split_operator( const struct xnn_operator_data* opdata, const struct xnn_value* values, size_t num_values, - size_t num_splits, pthreadpool_t threadpool) { const uint32_t input_id = opdata->inputs[0]; @@ -299,7 +238,8 @@ static enum xnn_status setup_even_split_n_operator( assert(input_data != NULL); enum xnn_status status = xnn_status_success; - + + size_t num_splits = opdata->num_outputs; int operator_index = 0; for (size_t i = 0; i < num_splits; ++i) { const uint32_t output_id = opdata->outputs[i]; @@ -314,33 +254,6 @@ static enum xnn_status setup_even_split_n_operator( return status; } -static enum xnn_status setup_even_split2_operator( - const struct xnn_operator_data* opdata, - const struct xnn_value* values, - size_t num_values, - pthreadpool_t threadpool) -{ - return setup_even_split_n_operator(opdata, values, num_values, /*num_splits=*/2, threadpool);; -} - -static enum xnn_status setup_even_split3_operator( - const struct xnn_operator_data* opdata, - const struct xnn_value* values, - size_t num_values, - pthreadpool_t threadpool) -{ - return setup_even_split_n_operator(opdata, values, num_values, /*num_splits=*/3, threadpool);; -} - -static enum xnn_status setup_even_split4_operator( - const struct xnn_operator_data* opdata, - const struct xnn_value* values, - size_t num_values, - pthreadpool_t threadpool) -{ - return setup_even_split_n_operator(opdata, values, num_values, /*num_splits=*/4, threadpool);; -} - enum xnn_status check_output_value( xnn_subgraph_t subgraph, int32_t split_dim, @@ -388,7 +301,7 @@ static enum xnn_status check_datatype_copyable( return xnn_subgraph_check_quantization_parameter_matches(node_type, input_id, input_value, output_id, output_value); } -enum xnn_status xnn_define_even_split_n( +enum xnn_status xnn_define_even_split_impl( enum xnn_node_type node_type, xnn_subgraph_t subgraph, int32_t split_dim, @@ -415,35 +328,15 @@ enum xnn_status xnn_define_even_split_n( return status; } - status = check_output_value(subgraph, split_dim, input_id, output_ids[0], "first", node_type); - if (status != xnn_status_success) { - return status; - } - status = check_output_value(subgraph, split_dim, input_id, output_ids[1], "second", node_type); - if (status != xnn_status_success) { - return status; - } - - if (num_outputs > 2) { - status = check_output_value(subgraph, split_dim, input_id, output_ids[2], "third", node_type); - if (status != xnn_status_success) { - return status; - } - } - if (num_outputs > 3) { - status = check_output_value(subgraph, split_dim, input_id, output_ids[3], "fourth", node_type); + for (int i = 0; i < num_outputs; ++i) { + status = check_output_value(subgraph, split_dim, input_id, output_ids[i], "Nth", node_type); if (status != xnn_status_success) { return status; } } - check_datatype_copyable(subgraph, input_id, output_ids[0], "first", node_type); - check_datatype_copyable(subgraph, input_id, output_ids[1], "second", node_type); - if (num_outputs > 2) { - 
check_datatype_copyable(subgraph, input_id, output_ids[2], "third", node_type);
-  }
-  if (num_outputs > 3) {
-    check_datatype_copyable(subgraph, input_id, output_ids[3], "fourth", node_type);
+  for (int i = 0; i < num_outputs; ++i) {
+    check_datatype_copyable(subgraph, input_id, output_ids[i], "Nth", node_type);
   }
 
   struct xnn_node* node = xnn_subgraph_new_node(subgraph);
@@ -456,73 +349,25 @@ enum xnn_status xnn_define_even_split_n(
   node->num_inputs = 1;
   node->inputs[0] = input_id;
   node->num_outputs = num_outputs;
-  node->outputs[0] = output_ids[0];
-  node->outputs[1] = output_ids[1];
-  switch (num_outputs) {
-    case 2:
-      node->create = create_even_split2_operator;
-      node->reshape = reshape_even_split2_operator;
-      node->setup = setup_even_split2_operator;
-      break;
-    case 3:
-      node->outputs[2] = output_ids[2];
-      node->create = create_even_split3_operator;
-      node->reshape = reshape_even_split3_operator;
-      node->setup = setup_even_split3_operator;
-      break;
-    case 4:
-      node->outputs[2] = output_ids[2];
-      node->outputs[3] = output_ids[3];
-      node->create = create_even_split4_operator;
-      node->reshape = reshape_even_split4_operator;
-      node->setup = setup_even_split4_operator;
-      break;
-    default:
-      XNN_UNREACHABLE;
+  for (int i = 0; i < num_outputs; ++i) {
+    node->outputs[i] = output_ids[i];
   }
+  node->create = create_even_split_operator;
+  node->reshape = reshape_even_split_operator;
+  node->setup = setup_even_split_operator;
   node->flags = flags;
 
   return xnn_status_success;
 };
 
-enum xnn_status xnn_define_even_split2(
-  xnn_subgraph_t subgraph,
-  int32_t split_dim,
-  uint32_t input_id,
-  uint32_t output1_id,
-  uint32_t output2_id,
-  uint32_t flags)
-{
-  const uint32_t output_ids[2] = { output1_id, output2_id };
-  return xnn_define_even_split_n(
-    xnn_node_type_even_split2, subgraph, split_dim, input_id, XNN_COUNT_OF(output_ids), output_ids, flags);
-}
-
-enum xnn_status xnn_define_even_split3(
+enum xnn_status xnn_define_even_split(
   xnn_subgraph_t subgraph,
   int32_t split_dim,
   uint32_t input_id,
-  uint32_t output1_id,
-  uint32_t output2_id,
-  uint32_t output3_id,
-  uint32_t flags)
-{
-  const uint32_t output_ids[3] = { output1_id, output2_id, output3_id };
-  return xnn_define_even_split_n(
-    xnn_node_type_even_split3, subgraph, split_dim, input_id, XNN_COUNT_OF(output_ids), output_ids, flags);
-}
-
-enum xnn_status xnn_define_even_split4(
-  xnn_subgraph_t subgraph,
-  int32_t split_dim,
-  uint32_t input_id,
-  uint32_t output1_id,
-  uint32_t output2_id,
-  uint32_t output3_id,
-  uint32_t output4_id,
+  uint32_t num_outputs,
+  const uint32_t* output_ids,
   uint32_t flags)
 {
-  const uint32_t output_ids[4] = { output1_id, output2_id, output3_id, output4_id };
-  return xnn_define_even_split_n(
-    xnn_node_type_even_split4, subgraph, split_dim, input_id, XNN_COUNT_OF(output_ids), output_ids, flags);
+  return xnn_define_even_split_impl(
+    xnn_node_type_even_split, subgraph, split_dim, input_id, num_outputs, output_ids, flags);
 }
diff --git a/src/xnnpack/node-type-defs.h b/src/xnnpack/node-type-defs.h
index 280a5bbe595..cb711702c75 100644
--- a/src/xnnpack/node-type-defs.h
+++ b/src/xnnpack/node-type-defs.h
@@ -26,6 +26,7 @@ XNN_ENUM_ITEM(xnn_node_type_depthwise_convolution_2d, "Depthwise Convolution 2D"
 XNN_ENUM_ITEM(xnn_node_type_even_split2, "Even Split2")
 XNN_ENUM_ITEM(xnn_node_type_even_split3, "Even Split3")
 XNN_ENUM_ITEM(xnn_node_type_even_split4, "Even Split4")
+XNN_ENUM_ITEM(xnn_node_type_even_split, "Even Split")
 XNN_ENUM_ITEM(xnn_node_type_fully_connected, "Fully Connected")
 XNN_ENUM_ITEM(xnn_node_type_fully_connected_sparse, "Fully Connected Sparse")
 XNN_ENUM_ITEM(xnn_node_type_global_average_pooling_1d, "Global Average Pooling 1D")
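
Not part of the patch — a minimal usage sketch of the new entry point, assuming the subgraph and the value IDs have already been defined (the helper name is hypothetical; one call now covers the former even_split2/3/4 cases):

    #include <stdint.h>
    #include "xnnpack.h"

    // Hypothetical helper: split `input_id` three ways along `split_dim`.
    // Assumes `subgraph` and all value IDs were defined earlier, e.g. with
    // xnn_define_tensor_value; error handling is left to the caller.
    static enum xnn_status define_three_way_split(
        xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id,
        uint32_t out0, uint32_t out1, uint32_t out2) {
      const uint32_t output_ids[3] = {out0, out1, out2};
      return xnn_define_even_split(subgraph, split_dim, input_id,
                                   /*num_outputs=*/3, output_ids, /*flags=*/0);
    }
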
Connected Sparse") XNN_ENUM_ITEM(xnn_node_type_global_average_pooling_1d, "Global Average Pooling 1D") diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 9e91f7f6e37..e0af25fde7f 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -1802,9 +1802,9 @@ xnnpack_unit_test( ) [xnnpack_unit_test( - name = "even_split%d_test" % n, + name = "even_split_test" if n == None else "even_split%d_test" % n, srcs = [ - "even-split%s.cc" % n, + "even_split" if n == None else "even-split%s.cc" % n, ], deps = [ ":replicable_random_device", @@ -1820,6 +1820,7 @@ xnnpack_unit_test( 2, 3, 4, + None, ]] xnnpack_unit_test( diff --git a/test/even-split.cc b/test/even-split.cc new file mode 100644 index 00000000000..1925746b44b --- /dev/null +++ b/test/even-split.cc @@ -0,0 +1,719 @@ +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "xnnpack.h" +#include "xnnpack/node-type.h" +#include "xnnpack/operator.h" +#include "xnnpack/subgraph.h" +#include "replicable_random_device.h" + +template class EvenSplitNTest : public ::testing::Test { +protected: + EvenSplitNTest() + { + shape_dist = std::uniform_int_distribution(1, XNN_MAX_TENSOR_DIMS); + dim_dist = std::uniform_int_distribution(1, 9); + f32dist = std::uniform_real_distribution(); + i8dist = + std::uniform_int_distribution(std::numeric_limits::min(), std::numeric_limits::max()); + u8dist = + std::uniform_int_distribution(std::numeric_limits::min(), std::numeric_limits::max()); + scale_dist = std::uniform_real_distribution(0.1f, 5.0f); + + num_outputs = RandomNumOutputs(); + output_dims.resize(num_outputs); + output_dims[0] = RandomShape(); + output_id.resize(num_outputs); + for (int i = 1; i < num_outputs; i++) { + output_dims[i] = output_dims[0]; + } + input_dims = output_dims[0]; + axis = RandomAxis(output_dims[0]); + for (int i = 1; i < num_outputs; i++) { + input_dims[axis] += output_dims[i][axis]; + } + + input = std::vector(NumElements(input_dims)); + operator_outputs.resize(num_outputs); + subgraph_outputs.resize(num_outputs); + for (int i = 0; i < num_outputs; i++) { + operator_outputs[i] = std::vector(NumElements(output_dims[i])); + subgraph_outputs[i] = std::vector(NumElements(output_dims[i])); + } + + signed_zero_point = i8dist(rng); + unsigned_zero_point = u8dist(rng); + scale = scale_dist(rng); + + batch_size = 1; + input_stride = 1; + for (size_t i = 0; i < axis; i++) { + batch_size *= input_dims[i]; + } + + for (size_t i = axis; i < input_dims.size(); i++) { + input_stride *= input_dims[i]; + } + channels = input_stride / num_outputs; + + } + + std::vector RandomShape() + { + std::vector dims(shape_dist(rng)); + std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); }); + return dims; + } + + size_t RandomAxis(const std::vector& dims) + { + return std::uniform_int_distribution(0, dims.size() - 1)(rng); + } + + size_t RandomNumOutputs() { return std::uniform_int_distribution(1, XNN_MAX_OUTPUTS)(rng); } + + size_t NumElements(const std::vector& dims) + { + return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies()); + } + + xnnpack::ReplicableRandomDevice rng; + std::uniform_int_distribution shape_dist; + std::uniform_int_distribution dim_dist; + std::uniform_real_distribution f32dist; + std::uniform_int_distribution i8dist; + std::uniform_int_distribution u8dist; + 
diff --git a/test/even-split.cc b/test/even-split.cc
new file mode 100644
index 00000000000..1925746b44b
--- /dev/null
+++ b/test/even-split.cc
@@ -0,0 +1,719 @@
+// Copyright 2022 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <algorithm>
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "xnnpack.h"
+#include "xnnpack/node-type.h"
+#include "xnnpack/operator.h"
+#include "xnnpack/subgraph.h"
+#include "replicable_random_device.h"
+
+template <class T> class EvenSplitNTest : public ::testing::Test {
+protected:
+  EvenSplitNTest()
+  {
+    shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
+    dim_dist = std::uniform_int_distribution<size_t>(1, 9);
+    f32dist = std::uniform_real_distribution<float>();
+    i8dist =
+      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
+    u8dist =
+      std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
+    scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);
+
+    num_outputs = RandomNumOutputs();
+    output_dims.resize(num_outputs);
+    output_dims[0] = RandomShape();
+    output_id.resize(num_outputs);
+    for (size_t i = 1; i < num_outputs; i++) {
+      output_dims[i] = output_dims[0];
+    }
+    input_dims = output_dims[0];
+    axis = RandomAxis(output_dims[0]);
+    for (size_t i = 1; i < num_outputs; i++) {
+      input_dims[axis] += output_dims[i][axis];
+    }
+
+    input = std::vector<T>(NumElements(input_dims));
+    operator_outputs.resize(num_outputs);
+    subgraph_outputs.resize(num_outputs);
+    for (size_t i = 0; i < num_outputs; i++) {
+      operator_outputs[i] = std::vector<T>(NumElements(output_dims[i]));
+      subgraph_outputs[i] = std::vector<T>(NumElements(output_dims[i]));
+    }
+
+    signed_zero_point = i8dist(rng);
+    unsigned_zero_point = u8dist(rng);
+    scale = scale_dist(rng);
+
+    batch_size = 1;
+    input_stride = 1;
+    for (size_t i = 0; i < axis; i++) {
+      batch_size *= input_dims[i];
+    }
+    for (size_t i = axis; i < input_dims.size(); i++) {
+      input_stride *= input_dims[i];
+    }
+    channels = input_stride / num_outputs;
+  }
+
+  std::vector<size_t> RandomShape()
+  {
+    std::vector<size_t> dims(shape_dist(rng));
+    std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
+    return dims;
+  }
+
+  size_t RandomAxis(const std::vector<size_t>& dims)
+  {
+    return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
+  }
+
+  size_t RandomNumOutputs() { return std::uniform_int_distribution<size_t>(1, XNN_MAX_OUTPUTS)(rng); }
+
+  size_t NumElements(const std::vector<size_t>& dims)
+  {
+    return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
+  }
+
+  xnnpack::ReplicableRandomDevice rng;
+  std::uniform_int_distribution<size_t> shape_dist;
+  std::uniform_int_distribution<size_t> dim_dist;
+  std::uniform_real_distribution<float> f32dist;
+  std::uniform_int_distribution<int32_t> i8dist;
+  std::uniform_int_distribution<int32_t> u8dist;
+  std::uniform_real_distribution<float> scale_dist;
+
+  std::vector<uint32_t> output_id;
+  uint32_t input_id;
+
+  std::vector<std::vector<size_t>> output_dims;
+  std::vector<size_t> input_dims;
+
+  size_t axis;
+  size_t num_outputs;
+  size_t batch_size;
+  size_t channels;
+  size_t input_stride;
+
+  int32_t signed_zero_point;
+  int32_t unsigned_zero_point;
+  float scale;
+
+  std::vector<std::vector<T>> operator_outputs;
+  std::vector<std::vector<T>> subgraph_outputs;
+  std::vector<T> input;
+};
+
+using EvenSplitNTestQS8 = EvenSplitNTest<int8_t>;
+using EvenSplitNTestQU8 = EvenSplitNTest<uint8_t>;
+using EvenSplitNTestF16 = EvenSplitNTest<xnn_float16>;
+using EvenSplitNTestF32 = EvenSplitNTest<float>;
+
+TEST_F(EvenSplitNTestQS8, define)
+{
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_quantized_tensor_value(
+      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
+      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_id[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_quantized_tensor_value(
+                            subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims[i].size(),
+                            output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id[i]));
+    ASSERT_NE(output_id[i], XNN_INVALID_NODE_ID);
+  }
+  const int32_t split_dim = axis;
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, split_dim, input_id, num_outputs, output_id.data(), /*flags=*/0));
+  ASSERT_EQ(subgraph->num_nodes, 1);
+  const struct xnn_node* node = &subgraph->nodes[0];
+  ASSERT_EQ(node->type, xnn_node_type_even_split);
+  ASSERT_EQ(node->params.even_split.axis, axis);
+  ASSERT_EQ(node->num_inputs, 1);
+  ASSERT_EQ(node->inputs[0], input_id);
+  ASSERT_EQ(node->num_outputs, num_outputs);
+  for (size_t i = 0; i < num_outputs; i++) {
+    ASSERT_EQ(node->outputs[i], output_id[i]);
+  }
+  ASSERT_EQ(node->flags, 0);
+}
+
+TEST_F(EvenSplitNTestQU8, define)
+{
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_quantized_tensor_value(
+      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
+      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_id[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_quantized_tensor_value(
+                            subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims[i].size(),
+                            output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id[i]));
+    ASSERT_NE(output_id[i], XNN_INVALID_NODE_ID);
+  }
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, axis, input_id, num_outputs, output_id.data(), /*flags=*/0));
+  ASSERT_EQ(subgraph->num_nodes, 1);
+  const struct xnn_node* node = &subgraph->nodes[0];
+  ASSERT_EQ(node->type, xnn_node_type_even_split);
+  ASSERT_EQ(node->params.even_split.axis, axis);
+  ASSERT_EQ(node->num_inputs, 1);
+  ASSERT_EQ(node->inputs[0], input_id);
+  ASSERT_EQ(node->num_outputs, num_outputs);
+  for (size_t i = 0; i < num_outputs; i++) {
+    ASSERT_EQ(node->outputs[i], output_id[i]);
+  }
+  ASSERT_EQ(node->flags, 0);
+}
+
+TEST_F(EvenSplitNTestF16, define)
+{
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success, xnn_define_tensor_value(
+                          subgraph, xnn_datatype_fp16, input_dims.size(), input_dims.data(), nullptr, 0,
+                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_id[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_tensor_value(
+                            subgraph, xnn_datatype_fp16, output_dims[i].size(), output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id[i]));
+    ASSERT_NE(output_id[i], XNN_INVALID_NODE_ID);
+  }
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, axis, input_id, num_outputs, output_id.data(), /*flags=*/0));
+
+  ASSERT_EQ(subgraph->num_nodes, 1);
+  const struct xnn_node* node = &subgraph->nodes[0];
+  ASSERT_EQ(node->type, xnn_node_type_even_split);
+  ASSERT_EQ(node->params.even_split.axis, axis);
+  ASSERT_EQ(node->num_inputs, 1);
+  ASSERT_EQ(node->inputs[0], input_id);
+  ASSERT_EQ(node->num_outputs, num_outputs);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(node->outputs[i], output_id[i]);
+  }
+  ASSERT_EQ(node->flags, 0);
+}
+
+TEST_F(EvenSplitNTestF32, define)
+{
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success, xnn_define_tensor_value(
+                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
+                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_id[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_tensor_value(
+                            subgraph, xnn_datatype_fp32, output_dims[i].size(), output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id[i]));
+    ASSERT_NE(output_id[i], XNN_INVALID_NODE_ID);
+  }
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, axis, input_id, num_outputs, output_id.data(), /*flags=*/0));
+
+  ASSERT_EQ(subgraph->num_nodes, 1);
+  const struct xnn_node* node = &subgraph->nodes[0];
+  ASSERT_EQ(node->type, xnn_node_type_even_split);
+  ASSERT_EQ(node->params.even_split.axis, axis);
+  ASSERT_EQ(node->num_inputs, 1);
+  ASSERT_EQ(node->inputs[0], input_id);
+  ASSERT_EQ(node->num_outputs, num_outputs);
+  for (size_t i = 0; i < num_outputs; i++) {
+    ASSERT_EQ(node->outputs[i], output_id[i]);
+  }
+  ASSERT_EQ(node->flags, 0);
+}
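+
+// The matches_operator_api tests below validate the subgraph node against the
+// operator API: each of the N outputs must equal a batched copy of `channels`
+// elements read at offset i * channels within every `input_stride`-element
+// row of the input.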
+
+TEST_F(EvenSplitNTestQS8, matches_operator_api)
+{
+  for (size_t i = 0; i < input.size(); ++i) {
+    input[i] = static_cast<int8_t>(i + 1);
+  }
+  for (size_t i = 0; i < num_outputs; i++) {
+    std::fill(operator_outputs[i].begin(), operator_outputs[i].end(), INT8_C(0xA5));
+    std::fill(subgraph_outputs[i].begin(), subgraph_outputs[i].end(), INT8_C(0xA5));
+  }
+
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+  std::vector<xnn_operator_t> ops(num_outputs, nullptr);
+
+  // Call operator API.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(/*flags=*/0, &ops[i]));
+    std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(ops[i], xnn_delete_operator);
+    ASSERT_EQ(
+      xnn_status_success,
+      xnn_reshape_copy_nc_x8(ops[i], batch_size, channels, input_stride, channels, /*threadpool=*/nullptr));
+    ASSERT_EQ(
+      xnn_status_success,
+      xnn_setup_copy_nc_x8(ops[i], input.data() + i * channels, operator_outputs[i].data()));
+    ASSERT_EQ(xnn_status_success, xnn_run_operator(ops[i], /*threadpool=*/nullptr));
+  }
+
+  // Call subgraph API.
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_quantized_tensor_value(
+      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
+      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+
+  std::vector<uint32_t> output_ids(num_outputs);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_ids[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_quantized_tensor_value(
+                            subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims[i].size(),
+                            output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_ids[i]));
+    ASSERT_NE(output_ids[i], XNN_INVALID_NODE_ID);
+  }
+
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, axis, input_id, num_outputs, output_ids.data(), /*flags=*/0));
+
+  xnn_runtime_t runtime = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
+  ASSERT_NE(nullptr, runtime);
+  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
+  std::vector<xnn_external_value> external;
+  external.reserve(1 + num_outputs);  // Space for the input and all outputs.
+
+  // Add the input value.
+  external.emplace_back(xnn_external_value{input_id, input.data()});
+
+  // Add the output values.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    external.emplace_back(xnn_external_value{output_ids[i], subgraph_outputs[i].data()});
+  }
+
+  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
+  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(subgraph_outputs[i], operator_outputs[i]);
+  }
+}
+
+TEST_F(EvenSplitNTestQU8, matches_operator_api)
+{
+  for (size_t i = 0; i < input.size(); ++i) {
+    input[i] = static_cast<uint8_t>(i + 1);
+  }
+  for (size_t i = 0; i < num_outputs; i++) {
+    std::fill(operator_outputs[i].begin(), operator_outputs[i].end(), UINT8_C(0xA5));
+    std::fill(subgraph_outputs[i].begin(), subgraph_outputs[i].end(), UINT8_C(0xA5));
+  }
+
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+  std::vector<xnn_operator_t> ops(num_outputs, nullptr);
+
+  // Call operator API.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(/*flags=*/0, &ops[i]));
+    std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(ops[i], xnn_delete_operator);
+    ASSERT_EQ(
+      xnn_status_success,
+      xnn_reshape_copy_nc_x8(ops[i], batch_size, channels, input_stride, channels, /*threadpool=*/nullptr));
+    ASSERT_EQ(
+      xnn_status_success,
+      xnn_setup_copy_nc_x8(ops[i], input.data() + i * channels, operator_outputs[i].data()));
+    ASSERT_EQ(xnn_status_success, xnn_run_operator(ops[i], /*threadpool=*/nullptr));
+  }
+
+  // Call subgraph API.
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_quantized_tensor_value(
+      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
+      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+
+  std::vector<uint32_t> output_ids(num_outputs);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_ids[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_quantized_tensor_value(
+                            subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims[i].size(),
+                            output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_ids[i]));
+    ASSERT_NE(output_ids[i], XNN_INVALID_NODE_ID);
+  }
+
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, axis, input_id, num_outputs, output_ids.data(), /*flags=*/0));
+
+  xnn_runtime_t runtime = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
+  ASSERT_NE(nullptr, runtime);
+  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
+  std::vector<xnn_external_value> external;
+  external.reserve(1 + num_outputs);  // Space for the input and all outputs.
+
+  // Add the input value.
+  external.emplace_back(xnn_external_value{input_id, input.data()});
+
+  // Add the output values.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    external.emplace_back(xnn_external_value{output_ids[i], subgraph_outputs[i].data()});
+  }
+
+  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
+  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(subgraph_outputs[i], operator_outputs[i]);
+  }
+}
+
+TEST_F(EvenSplitNTestF16, matches_operator_api)
+{
+  for (size_t i = 0; i < input.size(); ++i) {
+    input[i] = static_cast<xnn_float16>(i + 1);
+  }
+  for (size_t i = 0; i < num_outputs; i++) {
+    std::fill(operator_outputs[i].begin(), operator_outputs[i].end(), std::nanf(""));
+    std::fill(subgraph_outputs[i].begin(), subgraph_outputs[i].end(), std::nanf(""));
+  }
+
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+  std::vector<xnn_operator_t> ops(num_outputs, nullptr);
+
+  // Call operator API.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x16(/*flags=*/0, &ops[i]));
+    std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(ops[i], xnn_delete_operator);
+
+    ASSERT_EQ(
+      xnn_status_success,
+      xnn_reshape_copy_nc_x16(ops[i], batch_size, channels, input_stride, channels, /*threadpool=*/nullptr));
+    ASSERT_EQ(
+      xnn_status_success,
+      xnn_setup_copy_nc_x16(ops[i], input.data() + i * channels, operator_outputs[i].data()));
+    ASSERT_EQ(xnn_status_success, xnn_run_operator(ops[i], /*threadpool=*/nullptr));
+  }
+
+  // Call subgraph API.
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success, xnn_define_tensor_value(
+                          subgraph, xnn_datatype_fp16, input_dims.size(), input_dims.data(), nullptr, 0,
+                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+
+  std::vector<uint32_t> output_ids(num_outputs);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_ids[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_tensor_value(
+                            subgraph, xnn_datatype_fp16, output_dims[i].size(), output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_ids[i]));
+    ASSERT_NE(output_ids[i], XNN_INVALID_NODE_ID);
+  }
+
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, axis, input_id, num_outputs, output_ids.data(), /*flags=*/0));
+
+  xnn_runtime_t runtime = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
+  ASSERT_NE(nullptr, runtime);
+  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
+  std::vector<xnn_external_value> external;
+  external.reserve(1 + num_outputs);  // Space for the input and all outputs.
+
+  // Add the input value.
+  external.emplace_back(xnn_external_value{input_id, input.data()});
+
+  // Add the output values.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    external.emplace_back(xnn_external_value{output_ids[i], subgraph_outputs[i].data()});
+  }
+  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
+  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
+
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(subgraph_outputs[i], operator_outputs[i]);
+  }
+}
+
+TEST_F(EvenSplitNTestF32, matches_operator_api)
+{
+  for (size_t i = 0; i < input.size(); ++i) {
+    input[i] = static_cast<float>(i + 1);
+  }
+  for (size_t i = 0; i < num_outputs; i++) {
+    std::fill(operator_outputs[i].begin(), operator_outputs[i].end(), std::nanf(""));
+    std::fill(subgraph_outputs[i].begin(), subgraph_outputs[i].end(), std::nanf(""));
+  }
+
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+  std::vector<xnn_operator_t> ops(num_outputs, nullptr);
+
+  // Call operator API.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(/*flags=*/0, &ops[i]));
+    std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(ops[i], xnn_delete_operator);
+
+    ASSERT_EQ(
+      xnn_status_success,
+      xnn_reshape_copy_nc_x32(ops[i], batch_size, channels, input_stride, channels, /*threadpool=*/nullptr));
+    ASSERT_EQ(
+      xnn_status_success,
+      xnn_setup_copy_nc_x32(ops[i], input.data() + i * channels, operator_outputs[i].data()));
+    ASSERT_EQ(xnn_status_success, xnn_run_operator(ops[i], /*threadpool=*/nullptr));
+  }
+
+  // Call subgraph API.
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success, xnn_define_tensor_value(
+                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
+                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+
+  std::vector<uint32_t> output_ids(num_outputs);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_ids[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_tensor_value(
+                            subgraph, xnn_datatype_fp32, output_dims[i].size(), output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_ids[i]));
+    ASSERT_NE(output_ids[i], XNN_INVALID_NODE_ID);
+  }
+
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, axis, input_id, num_outputs, output_ids.data(), /*flags=*/0));
+
+  xnn_runtime_t runtime = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
+  ASSERT_NE(nullptr, runtime);
+  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
+  std::vector<xnn_external_value> external;
+  external.reserve(1 + num_outputs);  // Space for the input and all outputs.
+
+  // Add the input value.
+  external.emplace_back(xnn_external_value{input_id, input.data()});
+
+  // Add the output values.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    external.emplace_back(xnn_external_value{output_ids[i], subgraph_outputs[i].data()});
+  }
+  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
+  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
+
+  for (size_t i = 0; i < num_outputs; ++i) {
+    ASSERT_EQ(subgraph_outputs[i], operator_outputs[i]);
+  }
+}
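+
+// reshape_output: growing the input along the split axis must trigger
+// reallocation and propagate dim[axis] / num_outputs to every output shape;
+// shrinking it back must then succeed without reallocation.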
+
+TEST_F(EvenSplitNTestF32, reshape_output)
+{
+  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
+
+  // Call subgraph API.
+  xnn_subgraph_t subgraph = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/num_outputs + 1, /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+
+  input_id = XNN_INVALID_NODE_ID;
+  ASSERT_EQ(
+    xnn_status_success, xnn_define_tensor_value(
+                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
+                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
+
+  std::vector<uint32_t> output_ids(num_outputs);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    output_ids[i] = XNN_INVALID_NODE_ID;
+    ASSERT_EQ(
+      xnn_status_success, xnn_define_tensor_value(
+                            subgraph, xnn_datatype_fp32, output_dims[i].size(), output_dims[i].data(), nullptr, i + 1,
+                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_ids[i]));
+    ASSERT_NE(output_ids[i], XNN_INVALID_NODE_ID);
+  }
+
+  ASSERT_EQ(
+    xnn_status_success,
+    xnn_define_even_split(subgraph, axis, input_id, num_outputs, output_ids.data(), /*flags=*/0));
+
+  xnn_runtime_t runtime = nullptr;
+  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
+  ASSERT_NE(nullptr, runtime);
+  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
+  std::vector<xnn_external_value> external;
+  external.reserve(1 + num_outputs);  // Space for the input and all outputs.
+
+  // Add the input value.
+  external.emplace_back(xnn_external_value{input_id, input.data()});
+
+  // Add the output values.
+  for (size_t i = 0; i < num_outputs; ++i) {
+    external.emplace_back(xnn_external_value{output_ids[i], subgraph_outputs[i].data()});
+  }
+  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
+  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
+
+  input_dims[axis] += num_outputs;
+  ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_dims.size(), input_dims.data()));
+  const struct xnn_node* node = &subgraph->nodes[0];
+  ASSERT_EQ(
+    node->reshape(&runtime->opdata[0], runtime->values, runtime->num_values, /*threadpool=*/nullptr),
+    xnn_status_reallocation_required);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    const xnn_shape* output_n_shape = &runtime->values[node->outputs[i]].shape;
+    ASSERT_EQ(output_n_shape->dim[axis], input_dims[axis] / num_outputs);
+    for (size_t j = 0; j < input_dims.size(); ++j) {
+      if (j == axis) continue;
+      ASSERT_EQ(output_n_shape->dim[j], input_dims[j]);
+    }
+  }
+
+  input_dims[axis] -= 2 * num_outputs;
+  ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_dims.size(), input_dims.data()));
+  ASSERT_EQ(
+    node->reshape(&runtime->opdata[0], runtime->values, runtime->num_values, /*threadpool=*/nullptr),
+    xnn_status_success);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    const xnn_shape* output_n_shape = &runtime->values[node->outputs[i]].shape;
+    ASSERT_EQ(output_n_shape->dim[axis], input_dims[axis] / num_outputs);
+    for (size_t j = 0; j < input_dims.size(); ++j) {
+      if (j == axis) continue;
+      ASSERT_EQ(output_n_shape->dim[j], input_dims[j]);
+    }
+  }
+}
diff --git a/test/fusion.cc b/test/fusion.cc
index 354e1f871f7..dac55d6076a 100644
--- a/test/fusion.cc
+++ b/test/fusion.cc
@@ -651,7 +651,7 @@ TEST(COPY, fused_upstream_with_multiple_outputs) {
   EXPECT_EQ(unoptimized_output, optimized_output);
 
   const xnn_node* split_node = tester.Node(0);
-  ASSERT_EQ(split_node->type, xnn_node_type_even_split2);
+  ASSERT_EQ(split_node->type, xnn_node_type_even_split);
   EXPECT_EQ(split_node->inputs[0], input_id);
   ASSERT_EQ(split_node->num_outputs, 2);
   EXPECT_EQ(split_node->outputs[0], copy_out1);
diff --git a/test/subgraph-tester.h b/test/subgraph-tester.h
index 83ce8787090..d7ac94714a2 100644
--- a/test/subgraph-tester.h
+++ b/test/subgraph-tester.h
@@ -444,8 +444,17 @@ class SubgraphTester {
   }
 
   SubgraphTester& AddEvenSplit2(size_t split_dim, uint32_t input_id, uint32_t output1_id, uint32_t output2_id) {
-    const xnn_status status = xnn_define_even_split2(
-        subgraph_.get(), split_dim, input_id, output1_id, output2_id, 0 /* flags */);
+    const uint32_t output_ids[] = {output1_id, output2_id};
+    const xnn_status status =
+        xnn_define_even_split(subgraph_.get(), split_dim, input_id, 2, output_ids, 0 /* flags */);
+    EXPECT_EQ(status, xnn_status_success);
+    return *this;
+  }
+
+  SubgraphTester& AddEvenSplit3(size_t split_dim, uint32_t input_id, uint32_t output1_id, uint32_t output2_id, uint32_t output3_id) {
+    const uint32_t output_ids[] = {output1_id, output2_id, output3_id};
+    const xnn_status status =
+        xnn_define_even_split(subgraph_.get(), split_dim, input_id, 3, output_ids, 0 /* flags */);
     EXPECT_EQ(status, xnn_status_success);
     return *this;
   }
@@ -473,8 +482,9 @@ class SubgraphTester {
   }
 
   SubgraphTester& AddEvenSplit3(uint32_t input_id, uint32_t output_id0, uint32_t output_id1, uint32_t output_id2) {
-    const xnn_status status = xnn_define_even_split3(
-        subgraph_.get(), 0, input_id, output_id0, output_id1, output_id2, 0 /*flags */);
+    const uint32_t output_ids[3] = {output_id0, output_id1, output_id2};
+    const xnn_status status =
+        xnn_define_even_split(subgraph_.get(), 0, input_id, 3, output_ids, 0 /* flags */);
     EXPECT_EQ(status, xnn_status_success);
     return *this;
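
Callers of the deprecated entry points can migrate mechanically, mirroring the deprecated.c shims above; a sketch (the helper name is hypothetical, not part of the patch):

    #include <stdint.h>
    #include "xnnpack.h"

    // Before: xnn_define_even_split4(subgraph, split_dim, input_id,
    //                                out1, out2, out3, out4, 0);
    // After: gather the IDs into an array and make one
    // xnn_define_even_split call.
    static enum xnn_status define_split4(
        xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id,
        uint32_t out1, uint32_t out2, uint32_t out3, uint32_t out4) {
      const uint32_t output_ids[4] = {out1, out2, out3, out4};
      return xnn_define_even_split(subgraph, split_dim, input_id,
                                   /*num_outputs=*/4, output_ids, /*flags=*/0);
    }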