-
Notifications
You must be signed in to change notification settings - Fork 382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ConcatenateN Subgraph #7642
base: master
Are you sure you want to change the base?
Conversation
RahulSundarMCW
commented
Jan 6, 2025
•
edited
Loading
edited
- Replaced existing concatenate2, concatenate3, concatenate4, concatenate5, tests with a single test covering the Concatenate API.
- Marked old functions as XNN_DEPRECATED.
- Added shims in deprecated.c to call the new Concatenate API functions.
- Ensured the subgraph API remains stable with the new implementation.
- Replaced existing concatenate2, concatenate3, concatenate4, concatenate5, tests with a single test covering the ConcatenateN API. - Marked old functions as XNN_DEPRECATED. - Added shims in deprecated.c to call the new EvenSplitN API functions. - Ensured the subgraph API remains stable with the new implementation.
src/subgraph/concatenate.c
Outdated
size_t num_values, | ||
pthreadpool_t threadpool) | ||
{ | ||
return reshape_concatenate_n_operator(opdata, values, num_values, opdata->num_inputs, threadpool); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the num_inputs
parameter now that we don't use it It looks like you just don't need these wrappers at all any more.
src/subgraph/concatenate.c
Outdated
{ | ||
return setup_concatenate_n_operator(opdata, values, num_values, opdata->num_inputs, threadpool); | ||
} | ||
|
||
enum xnn_status xnn_define_concatenate_n( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't need this any more. I think you should just rename this xnn_define_concatenate and remove the new function below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We pass node_type while calling xnn_define_concatenate_n() from xnn_define_concatenate() and renamed xnn_define_concatenate_n to xnn_define_concatenate_impl.
src/xnnpack/node-type-defs.h
Outdated
@@ -17,6 +17,7 @@ XNN_ENUM_ITEM(xnn_node_type_concatenate2, "Concatenate2") | |||
XNN_ENUM_ITEM(xnn_node_type_concatenate3, "Concatenate3") | |||
XNN_ENUM_ITEM(xnn_node_type_concatenate4, "Concatenate4") | |||
XNN_ENUM_ITEM(xnn_node_type_concatenate5, "Concatenate5") | |||
XNN_ENUM_ITEM(xnn_node_type_concatenate_n, "ConcatenateN") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just call it concatenate
. Also, remove the xnn_node_type_concatenateX
if they are unused now.
test/subgraph-tester.h
Outdated
SubgraphTester& AddConcatenate2(size_t axis, uint32_t input1_id, uint32_t input2_id, uint32_t output_id) { | ||
const xnn_status status = xnn_define_concatenate2( | ||
subgraph_.get(), axis, input1_id, input2_id, output_id, 0 /* flags */); | ||
SubgraphTester& AddConcatenate2(size_t axis, uint32_t input1_id, uint32_t input2_id, uint32_t output_id) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix indentation
src/subgraph/concatenate.c
Outdated
@@ -149,7 +116,7 @@ static enum xnn_status reshape_concatenate_n_operator( | |||
{ | |||
enum xnn_status status; | |||
|
|||
assert(opdata->num_inputs == num_inputs); | |||
num_inputs = opdata->num_inputs; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There needs to be an error check here, that num_inputs <= XNN_MAX_OPERATOR_OJBECTS
test/concatenateN.cc
Outdated
@@ -0,0 +1,708 @@ | |||
// Copyright 2022 Google LLC |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename to concatenate.cc
test/concatenateN.cc
Outdated
|
||
size_t RandomNumInputs() | ||
{ | ||
return std::uniform_int_distribution<size_t>(2, 5)(rng); // You can adjust the range |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use XNN_MAX_OPERATOR_OJBECTS
as the upper bound
688a1c8
to
3d11d59
Compare
3d11d59
to
65d4924
Compare
1649169
to
8352c56
Compare
} | ||
if (num_inputs > 2) { | ||
status = check_datatype_copyable(subgraph, input_ids[2], output_id, "third", node_type); | ||
for (size_t i = 0; i < num_inputs; i++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need another check like this that the number of inputs is supported. Asserts are good, but we shouldn't crash if we hit an unsupported number of inputs, it should result in status != xnn_status_success
and an error message like the one produced by check_datatype_copyable
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added the required check for the number of inputs exceeding the maximum limit. Whether it is fine or should it be implemented within a loop?
I don't see how that commit is going to fix all of the check failures, e.g.:
Can you please make sure the checks are passing locally to the extent that you can? |