[PT FE] Add aten::rot90 #28224
base: master
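For context on the operation this PR converts: torch.rot90(input, k, dims) rotates a tensor by 90 degrees k times in the plane given by dims (default [0, 1]). Positive k rotates from the first axis towards the second. To our knowledge, PyTorch's ATen implementation reduces rot90 to flips plus one axis swap:

- k % 4 == 1: flip dims[1], then transpose.
- k % 4 == 2: flip both dims.
- k % 4 == 3: flip dims[0], then transpose.

The standalone sketch below mirrors that decomposition for a plain 2-D matrix with dims = {0, 1}. The Matrix type and the helper names are hypothetical and exist only for illustration. Rows are assumed to be rectangular.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major 2-D matrix used only for this illustration.
using Matrix = std::vector<std::vector<int>>;

// Reverse the order of rows (flip along dim 0).
Matrix flip_rows(const Matrix& m) { return Matrix(m.rbegin(), m.rend()); }

// Reverse the order of columns inside each row (flip along dim 1).
Matrix flip_cols(const Matrix& m) {
    Matrix out = m;
    for (auto& row : out) {
        std::reverse(row.begin(), row.end());
    }
    return out;
}

// Swap dims 0 and 1; assumes every row has the same length.
Matrix transpose(const Matrix& m) {
    if (m.empty()) return {};
    Matrix out(m[0].size(), std::vector<int>(m.size()));
    for (std::size_t i = 0; i < m.size(); ++i)
        for (std::size_t j = 0; j < m[i].size(); ++j)
            out[j][i] = m[i][j];
    return out;
}

// rot90 with dims = {0, 1}; negative k rotates in the opposite direction.
Matrix rot90(const Matrix& m, int k) {
    k = ((k % 4) + 4) % 4;  // normalize k into [0, 3]
    switch (k) {
    case 1: return transpose(flip_cols(m));  // flip dims[1], then swap the two dims
    case 2: return flip_cols(flip_rows(m));  // flip both dims
    case 3: return transpose(flip_rows(m));  // flip dims[0], then swap the two dims
    default: return m;                       // k % 4 == 0: identity
    }
}
```

For example, rot90({{1, 2}, {3, 4}}, 1) yields {{2, 4}, {1, 3}}, which is a counterclockwise quarter turn. An n-D converter would apply the same flips and swap to the two axes named by dims.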
Thank you for the PR; please address the comments.
Please also fix the code style.
? context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0,1}))
: get_input_as_i32(context, 2);
const auto& partial_shape = input.get_partial_shape();
const auto ndims = partial_shape.rank().get_length();
Please check that the rank is static before calling get_length().
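In OpenVINO, PartialShape::rank() returns an ov::Rank, which is an alias of ov::Dimension. That type provides is_static(), and get_length() is meant only for static values. The guard the reviewer asks for can be sketched without OpenVINO by using a hypothetical Rank stand-in, where an empty std::optional means "dynamic". The class, the helper static_rank_or_throw, and the error text are illustrative assumptions, not the frontend's real API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>

// Hypothetical stand-in for ov::Rank / ov::Dimension: no value means a dynamic rank.
class Rank {
public:
    Rank() = default;                                        // dynamic rank
    explicit Rank(std::int64_t length) : length_(length) {}  // static rank
    bool is_static() const { return length_.has_value(); }
    // Mimics the intent that get_length() is only valid for a static rank.
    std::int64_t get_length() const {
        if (!length_) throw std::runtime_error("get_length() called on a dynamic rank");
        return *length_;
    }
private:
    std::optional<std::int64_t> length_;
};

// Guarded access: check is_static() first and fail with a clear message otherwise.
std::int64_t static_rank_or_throw(const Rank& rank) {
    if (!rank.is_static()) {
        throw std::invalid_argument("aten::rot90: input rank must be static");  // hypothetical message
    }
    return rank.get_length();
}
```

In the converter itself, the equivalent check would test partial_shape.rank().is_static() before computing ndims.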
if (const auto scatter_const = ov::util::get_constant_from_source(scatter)) {
    scatter = context.mark_node(scatter_const);
} else {
    context.mark_nodes(
        {start, step, range, axis_0, dim0_node, dim1_node, indices, updates, scatter.get_node_shared_ptr()});
}
rot90 is not a commonly used operation, so there is no need for this optimization.
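The node names in the scatter snippet (start, step, range, indices, updates) suggest that the subgraph builds the Transpose order that swaps the two rotation axes. We read it as follows:

1. Create Range(0, ndims, 1).
2. Write {dim1, dim0} at positions {dim0, dim1}.

That reading is our assumption, not something stated in the thread. If it holds, the graph simply computes the permutation below at run time, which also works when it cannot be folded to a constant. The function swap_axes_permutation is hypothetical, and dims are assumed to be already normalized to non-negative values.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Transpose order that swaps axes dim0 and dim1 of an ndims-D tensor.
// Assumes 0 <= dim0, dim1 < ndims (i.e. dims were normalized beforehand).
std::vector<std::int64_t> swap_axes_permutation(std::int64_t ndims, std::int64_t dim0, std::int64_t dim1) {
    std::vector<std::int64_t> perm(static_cast<std::size_t>(ndims));
    std::iota(perm.begin(), perm.end(), std::int64_t{0});  // the "range": 0, 1, ..., ndims - 1
    // The "scatter": positions {dim0, dim1} receive values {dim1, dim0}.
    perm[static_cast<std::size_t>(dim0)] = dim1;
    perm[static_cast<std::size_t>(dim1)] = dim0;
    return perm;
}
```

For a 4-D input with dims {1, 3}, this gives the order {0, 3, 2, 1}.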
auto dims_norm = normalize_axis(context, dims, rank);
auto dims_const = std::dynamic_pointer_cast<v0::Constant>(dims_norm.get_node_shared_ptr());
auto dims_values = dims_const->cast_vector<int32_t>();
Now you only use dims_values to validate that the inputs are as expected. That is nice to have, but not required, and it is much better to allow non-constant inputs than to make these checks. So, please remove them. The only check you can do is to verify that the shape of dims is [2] or dynamic.
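We believe that in OpenVINO the requested check could be expressed with PartialShape::compatible, for example dims.get_partial_shape().compatible(PartialShape{2}). That call accepts three cases:

- a dynamic rank
- [?]
- [2]

The sketch below spells out the same rule without OpenVINO. It uses hypothetical Dim/Shape stand-ins, where std::nullopt marks a dynamic rank or dimension. The type names and dims_shape_is_valid are illustrative assumptions.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-ins for ov::Dimension / ov::PartialShape:
// std::nullopt as a Dim means a dynamic dimension, as a Shape a dynamic rank.
using Dim = std::optional<std::int64_t>;
using Shape = std::optional<std::vector<Dim>>;

// Accept the shape of `dims` if it could be [2]: dynamic rank, [?] or [2].
bool dims_shape_is_valid(const Shape& shape) {
    if (!shape) return true;               // dynamic rank: nothing to check yet
    if (shape->size() != 1) return false;  // dims must be a 1-D tensor
    const Dim& length = (*shape)[0];
    return !length || *length == 2;        // dynamic length or exactly two axes
}
```

Because nothing here reads the values of dims, non-constant dims inputs are still allowed, which is what the comment asks for.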
I have removed dims_values, but I think there are some build errors in the PR. Do I need to add anything if it's dynamic?
Tests fail with
Hi @mvafin, I have added an extra dimension to range to match the dimensions of indices and updates, to try to resolve the issue. Please review and let me know if further changes are needed.
Details:
Tickets: