diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp index cdeb9226ed236c..ce47caa10c4c0f 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp @@ -14,7 +14,6 @@ #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" -#include "transformations/rt_info/transpose_sinking_attr.hpp" #include "transformations/transpose_sinking/ts_utils.hpp" #include "transformations/utils/utils.hpp" @@ -99,6 +98,22 @@ bool unsqueeze_axes_to_shape(const Output& input_node, } return true; } + +bool AreInputOutputShapesEqual(const std::shared_ptr& reshape) { + const auto input_shape = reshape->get_input_partial_shape(0); + const auto output_shape = reshape->get_output_partial_shape(0); + + if (input_shape.is_dynamic() || output_shape.is_dynamic()) { + return false; + } + return input_shape == output_shape; +} + +bool HasSpecialOne(const std::shared_ptr& reshape_const) { + auto const_value = reshape_const->cast_vector(); + return std::find(const_value.begin(), const_value.end(), -1) != const_value.end(); +} + } // namespace TSUnsqueezeForward::TSUnsqueezeForward() { @@ -112,6 +127,28 @@ TSUnsqueezeForward::TSUnsqueezeForward() { if (!unsqueeze_axes) { return false; } + auto ts_order_values = transpose_info.transpose_const->cast_vector(); + + // if main_node does nothing, just swap them + auto reshape = as_type_ptr(main_node); + if (reshape && AreInputOutputShapesEqual(reshape) && !HasSpecialOne(unsqueeze_axes)) { + TransposeInputsInfo transpose_input_info = {transpose_info.transpose, transpose_info.transpose_const, 0}; + // remove input Transpose + auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0}); + if (!success) { + return false; + } + + const auto reshape_order = ov::pass::transpose_sinking::utils::ReverseTransposeOrder(ts_order_values); + // transpose reshape const with Gather operation + auto axis = std::make_shared(element::i32, Shape{}, 0); + auto gather = + ov::pass::transpose_sinking::utils::ChangeValuesOrder(reshape->input_value(1), reshape_order, axis); + main_node->input(1).replace_source_output(gather); + + default_outputs_update(main_node, transpose_input_info); + return true; + } std::vector non_negative_axes; if (as_type_ptr(main_node)) { @@ -124,7 +161,6 @@ TSUnsqueezeForward::TSUnsqueezeForward() { non_negative_axes = ov::util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node); } - auto ts_order_values = transpose_info.transpose_const->cast_vector(); ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values); auto new_transpose_order = ov::op::v0::Constant::create(transpose_info.transpose_const->get_element_type(), @@ -183,6 +219,27 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() { if (!transpose_order || !unsqueeze_axes) return false; + auto transpose_order_values = transpose_order->cast_vector(); + + // if main_node does nothing, just swap them + auto reshape = as_type_ptr(main_node); + if (reshape && AreInputOutputShapesEqual(reshape) && !HasSpecialOne(unsqueeze_axes)) { + // insert Transpose before main_node on #0 input + for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_order, {0})) { + register_new_node(new_node); + } + // transpose reshape const with Gather operation + auto axis = std::make_shared(element::i32, Shape{}, 0); + auto gather = ov::pass::transpose_sinking::utils::ChangeValuesOrder(reshape->input_value(1), + transpose_order_values, + axis); + main_node->input(1).replace_source_output(gather); + + main_node->validate_and_infer_types(); + RemoveTransposeConsumers(main_node); + return true; + } + std::vector non_negative_axes; if (as_type_ptr(main_node)) { auto success = shape_to_unsqueeze_axes(main_node, unsqueeze_axes, non_negative_axes); @@ -205,7 +262,6 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() { } } - auto transpose_order_values = transpose_order->cast_vector(); auto old_transpose_order_values = transpose_order_values; std::vector new_values; diff --git a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp index d71c9006edd38a..fc5c315312cfaa 100644 --- a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp @@ -1677,6 +1677,103 @@ auto test_backward_unsqueeze_dyn_rank = []() { INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackwardDynRank, TSTestFixture, test_backward_unsqueeze_dyn_rank()); + +TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeBackwardSameShape) { + auto create_transpose = [](const std::shared_ptr& parent) { + auto ts_order = std::make_shared(element::u64, Shape{3}, Shape{1, 0, 2}); + return std::make_shared(parent, ts_order); + }; + + const Shape input_shape = {4, 5, 6}; + { + auto X = std::make_shared(element::f32, input_shape); + auto reshape_const = std::make_shared(element::u64, Shape{3}, Shape{4, 5, 6}); + auto reshape = std::make_shared(X, reshape_const, false); + auto transpose = create_transpose(reshape); + model = std::make_shared(ov::OutputVector{transpose}, ov::ParameterVector{X}); + } + + { + auto X = std::make_shared(element::f32, input_shape); + auto transpose = create_transpose(X); + auto reshape_const = std::make_shared(element::u64, Shape{3}, Shape{4, 5, 6}); + auto axis = std::make_shared(element::i32, Shape{}, 0); + auto indices = std::make_shared(element::i32, Shape{3}, Shape{1, 0, 2}); + auto gather = std::make_shared(reshape_const, indices, axis); + auto reshape = std::make_shared(transpose, gather, false); + model_ref = std::make_shared(ov::OutputVector{reshape}, ov::ParameterVector{X}); + } + + manager.register_pass(); +} + +TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeBackwardSameShapeSpecialOne) { + auto create_transpose = [](const std::shared_ptr& parent) { + auto ts_order = std::make_shared(element::u64, Shape{3}, Shape{1, 0, 2}); + return std::make_shared(parent, ts_order); + }; + + { + auto X = std::make_shared(element::f32, Shape{4, 5, 6}); + auto reshape_const = std::make_shared(element::i64, Shape{3}, std::vector{4, 5, -1}); + auto reshape = std::make_shared(X, reshape_const, false); + auto transpose = create_transpose(reshape); + model = std::make_shared(ov::OutputVector{transpose}, ov::ParameterVector{X}); + } + + model_ref = model->clone(); + + manager.register_pass(); +} + +TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeForwardSameShape) { + auto create_transpose = [](const std::shared_ptr& parent) { + auto ts_order = std::make_shared(element::u64, Shape{4}, Shape{1, 3, 0, 2}); + return std::make_shared(parent, ts_order); + }; + + const Shape input_shape = {4, 5, 6, 7}; + { + auto X = std::make_shared(element::f32, input_shape); + auto transpose = create_transpose(X); + auto reshape_const = std::make_shared(element::u64, Shape{4}, Shape{5, 7, 4, 6}); + auto reshape = std::make_shared(transpose, reshape_const, false); + model = std::make_shared(ov::OutputVector{reshape}, ov::ParameterVector{X}); + } + + { + auto X = std::make_shared(element::f32, input_shape); + auto reshape_const = std::make_shared(element::u64, Shape{4}, Shape{5, 7, 4, 6}); + auto axis = std::make_shared(element::i32, Shape{}, 0); + auto indices = std::make_shared(element::i32, Shape{4}, Shape{2, 0, 3, 1}); + auto gather = std::make_shared(reshape_const, indices, axis); + auto reshape = std::make_shared(X, gather, false); + auto transpose = create_transpose(reshape); + model_ref = std::make_shared(ov::OutputVector{transpose}, ov::ParameterVector{X}); + } + + manager.register_pass(); +} + +TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeForwardSameShapeSpecialOne) { + auto create_transpose = [](const std::shared_ptr& parent) { + auto ts_order = std::make_shared(element::u64, Shape{3}, Shape{1, 0, 2}); + return std::make_shared(parent, ts_order); + }; + + { + auto X = std::make_shared(element::f32, Shape{4, 5, 6}); + auto transpose = create_transpose(X); + auto reshape_const = std::make_shared(element::i64, Shape{3}, std::vector{4, 5, -1}); + auto reshape = std::make_shared(transpose, reshape_const, false); + model = std::make_shared(ov::OutputVector{reshape}, ov::ParameterVector{X}); + } + + model_ref = model->clone(); + + manager.register_pass(); +} + } // namespace common } // namespace testing } // namespace transpose_sinking diff --git a/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp index f00c69d2a8d734..7dc3a2b54c7bea 100644 --- a/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp @@ -380,7 +380,7 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) { auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto transpose0 = std::make_shared(node0, ng_order0); - auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96}); + auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{2, 20, 55, 96}); auto reshape = std::make_shared(transpose0, reshape_const, false); auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); @@ -399,7 +399,7 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) { auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto transpose0 = std::make_shared(node0, ng_order0); - auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96}); + auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{2, 20, 55, 96}); auto reshape = std::make_shared(transpose0, reshape_const, false); auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);