Skip to content

Commit

Permalink
Add EvenSplitN subgraph
Browse files Browse the repository at this point in the history
  - Replaced existing even-split2, even-split3, even-split4, tests with a single test covering the EvenSplitN API.
  - Marked old functions as XNN_DEPRECATED.
  - Added shims in deprecated.c to call the new EvenSplitN API functions.
  - Ensured the subgraph API remains stable with the new implementation.
  • Loading branch information
RahulSundarMCW committed Jan 20, 2025
1 parent 854b343 commit 0f42aed
Show file tree
Hide file tree
Showing 11 changed files with 814 additions and 190 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,7 @@ IF(XNNPACK_BUILD_TESTS)
even-split2
even-split3
even-split4
even-split
fully-connected
global-average-pooling-1d
global-average-pooling-2d
Expand Down
31 changes: 28 additions & 3 deletions include/xnnpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,31 @@ enum xnn_status xnn_define_copy(
uint32_t output_id,
uint32_t flags);

/// Define a n-Output Split Node and add it to a Subgraph.
///
/// The n-Output Split Node splits an input tensor into n output tensors along a specified axis evenly.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
/// dimensions is added to it.
/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
/// subgraph.
/// @param num_outputs - The number of output tensors to generate. The input tensor will be evenly split into
/// this number of output tensors along the `split_dim`. Each output tensor will have
/// the same dimensions as the input tensor, except for the `split_dim`, which will be
/// divided evenly between the outputs.
/// @param outputs - An array of Value IDs for the output tensors. Each output tensor must be an N-dimensional
/// tensor defined in the @a subgraph with the same shape as the input tensor, except along the
/// `split_dim` dimension, which will be split evenly among the output tensors. The number of
/// output tensors corresponds to the value of `num_outputs`.
enum xnn_status xnn_define_even_split(
xnn_subgraph_t subgraph,
int32_t split_dim,
uint32_t input_id,
uint32_t num_outputs,
const uint32_t* outputs,
uint32_t flags);

/// Define a 2-Output Split Node and add it to a Subgraph.
///
/// The 2-Output Split Node splits an input tensor into two output tensors along a specified axis evenly.
Expand All @@ -1654,7 +1679,7 @@ enum xnn_status xnn_define_copy(
/// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding
/// dimension of the first output. The split_dim dimension is half of the input's split_dim.
/// @param flags - binary features of the Split Node. No supported flags are currently defined.
enum xnn_status xnn_define_even_split2(
XNN_DEPRECATED enum xnn_status xnn_define_even_split2(
xnn_subgraph_t subgraph,
int32_t split_dim,
uint32_t input_id,
Expand Down Expand Up @@ -1683,7 +1708,7 @@ enum xnn_status xnn_define_even_split2(
/// dimension of the second and third output. The split_dim dimension is one third of the input's
/// split_dim.
/// @param flags - binary features of the Split Node. No supported flags are currently defined.
enum xnn_status xnn_define_even_split3(
XNN_DEPRECATED enum xnn_status xnn_define_even_split3(
xnn_subgraph_t subgraph,
int32_t split_dim,
uint32_t input_id,
Expand Down Expand Up @@ -1717,7 +1742,7 @@ enum xnn_status xnn_define_even_split3(
/// dimension of the other output tensors. The split_dim dimension is one fourth of the input's
/// split_dim.
/// @param flags - binary features of the Split Node. No supported flags are currently defined.
enum xnn_status xnn_define_even_split4(
XNN_DEPRECATED enum xnn_status xnn_define_even_split4(
xnn_subgraph_t subgraph,
int32_t split_dim,
uint32_t input_id,
Expand Down
1 change: 1 addition & 0 deletions src/runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ void propagate_rank(
case xnn_node_type_even_split2:
case xnn_node_type_even_split3:
case xnn_node_type_even_split4:
case xnn_node_type_even_split:
case xnn_node_type_unary_elementwise:
case xnn_node_type_convert:
case xnn_node_type_pack_lh:
Expand Down
1 change: 1 addition & 0 deletions src/subgraph.c
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,7 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
case xnn_node_type_even_split2:
case xnn_node_type_even_split3:
case xnn_node_type_even_split4:
case xnn_node_type_even_split:
case xnn_node_type_fully_connected:
case xnn_node_type_global_average_pooling_2d:
case xnn_node_type_global_sum_pooling_2d:
Expand Down
20 changes: 20 additions & 0 deletions src/subgraph/deprecated.c
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,23 @@ enum xnn_status xnn_define_tanh(xnn_subgraph_t subgraph, uint32_t input_id,
return xnn_define_unary(subgraph, xnn_unary_tanh, NULL, input_id, output_id,
flags);
}

enum xnn_status xnn_define_even_split2(xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id,
uint32_t output1_id, uint32_t output2_id, uint32_t flags) {
const uint32_t outputs_id[2] = {output1_id, output2_id};
return xnn_define_even_split(subgraph, split_dim, input_id, /*num_outputs=*/2, outputs_id, flags);
}

enum xnn_status xnn_define_even_split3(xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id,
uint32_t output1_id, uint32_t output2_id,
uint32_t output3_id, uint32_t flags) {
const uint32_t outputs_id[3] = {output1_id, output2_id, output3_id};
return xnn_define_even_split(subgraph, split_dim, input_id, /*num_outputs=*/3, outputs_id, flags);
}

enum xnn_status xnn_define_even_split4(xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id,
uint32_t output1_id, uint32_t output2_id,
uint32_t output3_id, uint32_t output4_id, uint32_t flags) {
const uint32_t outputs_id[4] = {output1_id, output2_id, output3_id, output4_id};
return xnn_define_even_split(subgraph, split_dim, input_id, /*num_outputs=*/4, outputs_id, flags);
}
Loading

0 comments on commit 0f42aed

Please sign in to comment.