[PT FE] Add aten::rot90 #28224

Open: wants to merge 10 commits into master

@Po-V commented Dec 29, 2024

Details:

  • Add aten::rot90 operator
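For reference on the semantics being converted: PyTorch implements rot90 by first reducing k modulo 4 and then using flips and a transpose. For k = 1 it flips dims[1] and swaps dims[0] and dims[1]; for k = 2 it flips both dims; for k = 3 it flips dims[0] and swaps the dims; for k = 0 it returns a copy. The sketch below models that decomposition in plain C++ on a 2-D matrix with dims = {0, 1}. It is not the converter from this PR, and the names (`Matrix`, `flip_rows`, `flip_cols`, `transpose`, `rot90`) are hypothetical. An OpenVINO conversion would express the same steps with graph operations instead of loops.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Flip along dim 0: reverse the order of the rows.
Matrix flip_rows(const Matrix& m) {
    return Matrix(m.rbegin(), m.rend());
}

// Flip along dim 1: reverse the elements within each row.
Matrix flip_cols(Matrix m) {
    for (auto& row : m) {
        std::reverse(row.begin(), row.end());
    }
    return m;
}

// Swap dims 0 and 1. Assumes a rectangular matrix.
Matrix transpose(const Matrix& m) {
    if (m.empty()) {
        return {};
    }
    const std::size_t rows = m.size();
    const std::size_t cols = m[0].size();
    Matrix t(cols, std::vector<int>(rows));
    for (std::size_t i = 0; i < rows; ++i) {
        assert(m[i].size() == cols && "rows must have equal length");
        for (std::size_t j = 0; j < cols; ++j) {
            t[j][i] = m[i][j];
        }
    }
    return t;
}

// rot90 over dims = {0, 1}: for positive k the rotation goes from dim 0
// toward dim 1, using the flip + transpose decomposition described above.
Matrix rot90(const Matrix& m, int k) {
    k = ((k % 4) + 4) % 4;  // reduce k to [0, 3]; negative k rotates the other way
    switch (k) {
        case 1:
            return transpose(flip_cols(m));  // flip dims[1], then swap dims
        case 2:
            return flip_cols(flip_rows(m));  // flip both dims
        case 3:
            return transpose(flip_rows(m));  // flip dims[0], then swap dims
        default:
            return m;  // k == 0: unchanged copy
    }
}
```

For example, `rot90({{0, 1}, {2, 3}}, 1)` yields `{{1, 3}, {0, 2}}`, the same result as `torch.rot90(x, 1, [0, 1])` for `x = [[0, 1], [2, 3]]`.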

Tickets:

@Po-V requested a review from a team as a code owner December 29, 2024 03:49
@Po-V requested review from mvafin and PiotrKrzem December 29, 2024 03:49
@github-actions bot added the category: PyTorch FE label (OpenVINO PyTorch Frontend) Dec 29, 2024
@sys-openvino-ci added the ExternalPR label (External contributor) Dec 29, 2024
@mvafin (Contributor) left a comment

Thank you for the PR. Please address the comments.

Three review threads on src/frontends/pytorch/src/op/rot90.cpp (outdated, resolved)
@mvafin (Contributor) left a comment

Please also fix the code style.

? context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0,1}))
: get_input_as_i32(context, 2);
const auto& partial_shape = input.get_partial_shape();
const auto ndims = partial_shape.rank().get_length();
Contributor:

Please check that the rank is static before calling get_length().
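Why the check matters: in OpenVINO, `PartialShape::rank()` returns an `ov::Dimension`, and calling `get_length()` on a dynamic dimension fails with an exception, so the converter should test `is_static()` first. The plain C++ sketch below models only that guard. `Rank`, `require_static_rank`, and the `min_rank` parameter are hypothetical names, not OpenVINO or PR code; the minimum of 2 used in the example mirrors torch.rot90 requiring an input with at least two dimensions.

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for an ov::Dimension holding a rank:
// std::nullopt models a dynamic (unknown) rank.
using Rank = std::optional<int64_t>;

// Check that the rank is static before reading its length, and that it is
// large enough for the operation.
int64_t require_static_rank(const Rank& rank, const std::string& op_name, int64_t min_rank) {
    if (!rank.has_value()) {
        throw std::runtime_error(op_name + ": input rank must be static");
    }
    if (*rank < min_rank) {
        throw std::runtime_error(op_name + ": input rank must be at least " +
                                 std::to_string(min_rank));
    }
    return *rank;
}
```

Usage: `require_static_rank(Rank{4}, "aten::rot90", 2)` returns 4, while passing `std::nullopt` reports a clear conversion error instead of failing inside `get_length()`.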

Comment on lines 55 to 60
if (const auto scatter_const = ov::util::get_constant_from_source(scatter)) {
scatter = context.mark_node(scatter_const);
} else {
context.mark_nodes(
{start, step, range, axis_0, dim0_node, dim1_node, indices, updates, scatter.get_node_shared_ptr()});
}
Contributor:
Suggested change: delete these lines.

This is not a very popular operation, so there is no need for this optimization.

src/frontends/pytorch/src/op/rot90.cpp (outdated, resolved)

auto dims_norm = normalize_axis(context, dims, rank);
auto dims_const = std::dynamic_pointer_cast<v0::Constant>(dims_norm.get_node_shared_ptr());
auto dims_values = dims_const->cast_vector<int32_t>();
Contributor:

Now you use dims_values only to validate that the inputs are as expected. That is nice to have but not required, and it is much better to allow non-constant inputs than to make these checks. So please remove them. The only check you can do is to verify that the shape of dims is [2] or dynamic.
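A sketch of what remains once the value checks are dropped, assuming dims may be a non-constant input: only its shape is checked at conversion time, and it passes when compatible with [2] (static length 2, unknown length, or unknown rank). The scalar function shows the arithmetic that axis normalization applies to each entry of dims (negative axes count from the end). Both functions, `MaybeShape`, and `kDynamicDim` are hypothetical plain C++ models, not the frontend's `normalize_axis` helper or any OpenVINO API.

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical model of a partial shape: std::nullopt is a dynamic rank,
// and kDynamicDim marks a dimension whose size is unknown.
constexpr int64_t kDynamicDim = -1;
using MaybeShape = std::optional<std::vector<int64_t>>;

// The one conversion-time check suggested in review: dims must be
// compatible with shape [2]. Values are never inspected.
bool dims_shape_ok(const MaybeShape& shape) {
    if (!shape) {
        return true;  // dynamic rank: nothing to check yet
    }
    if (shape->size() != 1) {
        return false;  // dims must be a 1-D list of axes
    }
    const int64_t len = (*shape)[0];
    return len == 2 || len == kDynamicDim;
}

// Scalar model of axis normalization: negative axes count from the end.
int64_t normalize_axis_value(int64_t axis, int64_t rank) {
    if (axis < -rank || axis >= rank) {
        throw std::out_of_range("axis out of range");
    }
    return axis < 0 ? axis + rank : axis;
}
```

In the converter itself, normalization would have to be built from graph operations so that it also works when dims is not a constant.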

Author:

I have removed dims_values, but I think there are some build errors in the PR. Do I need to add anything if it's dynamic?

@mvafin (Contributor) commented Jan 14, 2025

Tests fail with:

Model wasn't fully converted. Failed operations detailed log:
-- aten::rot90 with a message:
Exception happened during conversion of operation aten::rot90 with schema aten::rot90(Tensor self, int k=1, int[] dims=[0, 1]) -> Tensor
Check 'indices_rank.compatible(data_rank)' failed at src/core/shape_inference/include/scatter_elements_update_shape_inference.hpp:34:
While validating node 'util::ScatterElementsUpdateBase ScatterElementsUpdateBase_15762611 (opset4::Range Range_15762604[0]:i32[3], opset1::Concat Concat_15762609[0]:i32[2,1], opset1::Concat Concat_15762610[0]:i32[2,1], opset1::Constant Constant_15762605[0]:i32[]) -> (dynamic[...])' with friendly_name 'ScatterElementsUpdateBase_15762611':
Indices rank and data rank are required to be equal. Got: 2 and: 1

Summary:
-- Conversion is failed for: aten::rot90
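Reading the node signature in the log: the data input of ScatterElementsUpdate is a Range output `i32[3]` (rank 1), while indices and updates are Concat outputs `i32[2,1]` (rank 2), and the scalar Constant is the axis. Assuming the subgraph is meant to build a transpose order that swaps the two rotated axes (an inference from these inputs, not confirmed by the PR), one rank-consistent form keeps all three inputs 1-D, with indices and updates of shape [2]. The plain C++ sketch below models that 1-D scatter; `scatter_elements_update_1d` and `swap_axes_order` are hypothetical names, not OpenVINO APIs.

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <vector>

// 1-D model of ScatterElementsUpdate along axis 0: copy data, then write
// updates[i] at position indices[i]. With rank-1 data, indices and updates
// must be rank 1 as well.
std::vector<int32_t> scatter_elements_update_1d(std::vector<int32_t> data,
                                                const std::vector<int32_t>& indices,
                                                const std::vector<int32_t>& updates) {
    if (indices.size() != updates.size()) {
        throw std::invalid_argument("indices and updates must have the same shape");
    }
    for (std::size_t i = 0; i < indices.size(); ++i) {
        // .at() throws for out-of-range indices; a negative index converts to
        // a huge size_t and is rejected too.
        data.at(static_cast<std::size_t>(indices[i])) = updates[i];
    }
    return data;
}

// Transpose order that swaps dim0 and dim1: range(rank) with those two
// entries exchanged. dim0 and dim1 are assumed to be already normalized.
std::vector<int32_t> swap_axes_order(int32_t rank, int32_t dim0, int32_t dim1) {
    std::vector<int32_t> order(static_cast<std::size_t>(rank));
    std::iota(order.begin(), order.end(), 0);  // 0, 1, ..., rank - 1
    return scatter_elements_update_1d(order, {dim0, dim1}, {dim1, dim0});
}
```

For a 3-D input rotated over dims (1, 2), `swap_axes_order(3, 1, 2)` gives the order {0, 2, 1}. Giving the range an extra dimension, as in the follow-up below, is another way to make the ranks agree, at the cost of squeezing the result back to 1-D before using it as a transpose order.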

@Po-V (Author) commented Jan 20, 2025

Hi @mvafin, I have added an extra dimension to the range so that its rank matches that of indices and updates, to try to resolve the issue. Please review and let me know if further changes are needed.

@Po-V Po-V closed this Jan 20, 2025
@Po-V Po-V reopened this Jan 20, 2025
Labels: category: PyTorch FE (OpenVINO PyTorch Frontend), ExternalPR (External contributor)

Linked issue (may be closed when this PR is merged): [Good First Issue]: Support aten::rot90 for pytorch models

3 participants