Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix TSUnsqueezeBackward Reshape does nothing #27467

Merged
merged 6 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"

Expand Down Expand Up @@ -99,6 +98,22 @@ bool unsqueeze_axes_to_shape(const Output<Node>& input_node,
}
return true;
}

bool AreInputOutputShapesEqual(const std::shared_ptr<ov::op::v1::Reshape>& reshape) {
const auto input_shape = reshape->get_input_partial_shape(0);
const auto output_shape = reshape->get_output_partial_shape(0);

if (input_shape.is_dynamic() || output_shape.is_dynamic()) {
return false;
}
return input_shape == output_shape;
}

bool HasSpecialOne(const std::shared_ptr<ov::op::v0::Constant>& reshape_const) {
auto const_value = reshape_const->cast_vector<int64_t>();
return std::find(const_value.begin(), const_value.end(), -1) != const_value.end();
}

} // namespace

TSUnsqueezeForward::TSUnsqueezeForward() {
Expand All @@ -112,6 +127,28 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
if (!unsqueeze_axes) {
return false;
}
auto ts_order_values = transpose_info.transpose_const->cast_vector<size_t>();

// if main_node does nothing, just swap them
auto reshape = as_type_ptr<ov::op::v1::Reshape>(main_node);
if (reshape && AreInputOutputShapesEqual(reshape) && !HasSpecialOne(unsqueeze_axes)) {
TransposeInputsInfo transpose_input_info = {transpose_info.transpose, transpose_info.transpose_const, 0};
// remove input Transpose
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
if (!success) {
return false;
}

const auto reshape_order = ov::pass::transpose_sinking::utils::ReverseTransposeOrder(ts_order_values);
// transpose reshape const with Gather operation
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto gather =
ov::pass::transpose_sinking::utils::ChangeValuesOrder(reshape->input_value(1), reshape_order, axis);
main_node->input(1).replace_source_output(gather);

default_outputs_update(main_node, transpose_input_info);
return true;
}

std::vector<size_t> non_negative_axes;
if (as_type_ptr<ov::op::v1::Reshape>(main_node)) {
Expand All @@ -124,7 +161,6 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
non_negative_axes =
ov::util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node);
}
auto ts_order_values = transpose_info.transpose_const->cast_vector<size_t>();

ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
auto new_transpose_order = ov::op::v0::Constant::create(transpose_info.transpose_const->get_element_type(),
Expand Down Expand Up @@ -183,6 +219,27 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
if (!transpose_order || !unsqueeze_axes)
return false;

auto transpose_order_values = transpose_order->cast_vector<size_t>();

// if main_node does nothing, just swap them
auto reshape = as_type_ptr<ov::op::v1::Reshape>(main_node);
if (reshape && AreInputOutputShapesEqual(reshape) && !HasSpecialOne(unsqueeze_axes)) {
// insert Transpose before main_node on #0 input
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_order, {0})) {
register_new_node(new_node);
}
// transpose reshape const with Gather operation
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto gather = ov::pass::transpose_sinking::utils::ChangeValuesOrder(reshape->input_value(1),
transpose_order_values,
axis);
main_node->input(1).replace_source_output(gather);

main_node->validate_and_infer_types();
RemoveTransposeConsumers(main_node);
return true;
}

std::vector<size_t> non_negative_axes;
if (as_type_ptr<ov::op::v1::Reshape>(main_node)) {
auto success = shape_to_unsqueeze_axes(main_node, unsqueeze_axes, non_negative_axes);
Expand All @@ -205,7 +262,6 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
}
}

auto transpose_order_values = transpose_order->cast_vector<size_t>();
auto old_transpose_order_values = transpose_order_values;
std::vector<size_t> new_values;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,103 @@ auto test_backward_unsqueeze_dyn_rank = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackwardDynRank,
TSTestFixture,
test_backward_unsqueeze_dyn_rank());

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeBackwardSameShape) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{1, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

const Shape input_shape = {4, 5, 6};
{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{3}, Shape{4, 5, 6});
auto reshape = std::make_shared<Reshape>(X, reshape_const, false);
auto transpose = create_transpose(reshape);
model = std::make_shared<Model>(ov::OutputVector{transpose}, ov::ParameterVector{X});
}

{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto transpose = create_transpose(X);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{3}, Shape{4, 5, 6});
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto indices = std::make_shared<Constant>(element::i32, Shape{3}, Shape{1, 0, 2});
auto gather = std::make_shared<ov::op::v8::Gather>(reshape_const, indices, axis);
auto reshape = std::make_shared<Reshape>(transpose, gather, false);
model_ref = std::make_shared<Model>(ov::OutputVector{reshape}, ov::ParameterVector{X});
}

manager.register_pass<TSUnsqueezeBackward>();
}

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeBackwardSameShapeSpecialOne) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{1, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

{
auto X = std::make_shared<Parameter>(element::f32, Shape{4, 5, 6});
auto reshape_const = std::make_shared<Constant>(element::i64, Shape{3}, std::vector<int>{4, 5, -1});
auto reshape = std::make_shared<Reshape>(X, reshape_const, false);
auto transpose = create_transpose(reshape);
model = std::make_shared<Model>(ov::OutputVector{transpose}, ov::ParameterVector{X});
}

model_ref = model->clone();

manager.register_pass<TSUnsqueezeBackward>();
}

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeForwardSameShape) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{4}, Shape{1, 3, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

const Shape input_shape = {4, 5, 6, 7};
{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto transpose = create_transpose(X);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{4}, Shape{5, 7, 4, 6});
auto reshape = std::make_shared<Reshape>(transpose, reshape_const, false);
model = std::make_shared<Model>(ov::OutputVector{reshape}, ov::ParameterVector{X});
}

{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{4}, Shape{5, 7, 4, 6});
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto indices = std::make_shared<Constant>(element::i32, Shape{4}, Shape{2, 0, 3, 1});
auto gather = std::make_shared<ov::op::v8::Gather>(reshape_const, indices, axis);
auto reshape = std::make_shared<Reshape>(X, gather, false);
auto transpose = create_transpose(reshape);
model_ref = std::make_shared<Model>(ov::OutputVector{transpose}, ov::ParameterVector{X});
}

manager.register_pass<TSUnsqueezeForward>();
}

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeForwardSameShapeSpecialOne) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{1, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

{
auto X = std::make_shared<Parameter>(element::f32, Shape{4, 5, 6});
auto transpose = create_transpose(X);
auto reshape_const = std::make_shared<Constant>(element::i64, Shape{3}, std::vector<int>{4, 5, -1});
auto reshape = std::make_shared<Reshape>(transpose, reshape_const, false);
model = std::make_shared<Model>(ov::OutputVector{reshape}, ov::ParameterVector{X});
}

model_ref = model->clone();

manager.register_pass<TSUnsqueezeForward>();
}

} // namespace common
} // namespace testing
} // namespace transpose_sinking
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(node0, ng_order0);

auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{2, 20, 55, 96});
dorloff marked this conversation as resolved.
Show resolved Hide resolved
auto reshape = std::make_shared<Reshape>(transpose0, reshape_const, false);

auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
Expand All @@ -399,7 +399,7 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(node0, ng_order0);

auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{2, 20, 55, 96});
auto reshape = std::make_shared<Reshape>(transpose0, reshape_const, false);

auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);
Expand Down
Loading