Skip to content

Commit

Permalink
fix TSUnsqueezeBackward Reshape does nothing (#27467)
Browse files Browse the repository at this point in the history
### Details:
 - fix TSUnsqueezeBackward

### Tickets:
 - CVS-111560
  • Loading branch information
evkotov authored Nov 15, 2024
1 parent 250f001 commit 6489755
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"

Expand Down Expand Up @@ -99,6 +98,22 @@ bool unsqueeze_axes_to_shape(const Output<Node>& input_node,
}
return true;
}

bool AreInputOutputShapesEqual(const std::shared_ptr<ov::op::v1::Reshape>& reshape) {
const auto input_shape = reshape->get_input_partial_shape(0);
const auto output_shape = reshape->get_output_partial_shape(0);

if (input_shape.is_dynamic() || output_shape.is_dynamic()) {
return false;
}
return input_shape == output_shape;
}

bool HasSpecialOne(const std::shared_ptr<ov::op::v0::Constant>& reshape_const) {
auto const_value = reshape_const->cast_vector<int64_t>();
return std::find(const_value.begin(), const_value.end(), -1) != const_value.end();
}

} // namespace

TSUnsqueezeForward::TSUnsqueezeForward() {
Expand All @@ -112,6 +127,28 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
if (!unsqueeze_axes) {
return false;
}
auto ts_order_values = transpose_info.transpose_const->cast_vector<size_t>();

// if main_node does nothing, just swap them
auto reshape = as_type_ptr<ov::op::v1::Reshape>(main_node);
if (reshape && AreInputOutputShapesEqual(reshape) && !HasSpecialOne(unsqueeze_axes)) {
TransposeInputsInfo transpose_input_info = {transpose_info.transpose, transpose_info.transpose_const, 0};
// remove input Transpose
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
if (!success) {
return false;
}

const auto reshape_order = ov::pass::transpose_sinking::utils::ReverseTransposeOrder(ts_order_values);
// transpose reshape const with Gather operation
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto gather =
ov::pass::transpose_sinking::utils::ChangeValuesOrder(reshape->input_value(1), reshape_order, axis);
main_node->input(1).replace_source_output(gather);

default_outputs_update(main_node, transpose_input_info);
return true;
}

std::vector<size_t> non_negative_axes;
if (as_type_ptr<ov::op::v1::Reshape>(main_node)) {
Expand All @@ -124,7 +161,6 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
non_negative_axes =
ov::util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node);
}
auto ts_order_values = transpose_info.transpose_const->cast_vector<size_t>();

ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
auto new_transpose_order = ov::op::v0::Constant::create(transpose_info.transpose_const->get_element_type(),
Expand Down Expand Up @@ -183,6 +219,27 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
if (!transpose_order || !unsqueeze_axes)
return false;

auto transpose_order_values = transpose_order->cast_vector<size_t>();

// if main_node does nothing, just swap them
auto reshape = as_type_ptr<ov::op::v1::Reshape>(main_node);
if (reshape && AreInputOutputShapesEqual(reshape) && !HasSpecialOne(unsqueeze_axes)) {
// insert Transpose before main_node on #0 input
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_order, {0})) {
register_new_node(new_node);
}
// transpose reshape const with Gather operation
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto gather = ov::pass::transpose_sinking::utils::ChangeValuesOrder(reshape->input_value(1),
transpose_order_values,
axis);
main_node->input(1).replace_source_output(gather);

main_node->validate_and_infer_types();
RemoveTransposeConsumers(main_node);
return true;
}

std::vector<size_t> non_negative_axes;
if (as_type_ptr<ov::op::v1::Reshape>(main_node)) {
auto success = shape_to_unsqueeze_axes(main_node, unsqueeze_axes, non_negative_axes);
Expand All @@ -205,7 +262,6 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
}
}

auto transpose_order_values = transpose_order->cast_vector<size_t>();
auto old_transpose_order_values = transpose_order_values;
std::vector<size_t> new_values;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,103 @@ auto test_backward_unsqueeze_dyn_rank = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackwardDynRank,
TSTestFixture,
test_backward_unsqueeze_dyn_rank());

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeBackwardSameShape) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{1, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

const Shape input_shape = {4, 5, 6};
{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{3}, Shape{4, 5, 6});
auto reshape = std::make_shared<Reshape>(X, reshape_const, false);
auto transpose = create_transpose(reshape);
model = std::make_shared<Model>(ov::OutputVector{transpose}, ov::ParameterVector{X});
}

{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto transpose = create_transpose(X);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{3}, Shape{4, 5, 6});
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto indices = std::make_shared<Constant>(element::i32, Shape{3}, Shape{1, 0, 2});
auto gather = std::make_shared<ov::op::v8::Gather>(reshape_const, indices, axis);
auto reshape = std::make_shared<Reshape>(transpose, gather, false);
model_ref = std::make_shared<Model>(ov::OutputVector{reshape}, ov::ParameterVector{X});
}

manager.register_pass<TSUnsqueezeBackward>();
}

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeBackwardSameShapeSpecialOne) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{1, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

{
auto X = std::make_shared<Parameter>(element::f32, Shape{4, 5, 6});
auto reshape_const = std::make_shared<Constant>(element::i64, Shape{3}, std::vector<int>{4, 5, -1});
auto reshape = std::make_shared<Reshape>(X, reshape_const, false);
auto transpose = create_transpose(reshape);
model = std::make_shared<Model>(ov::OutputVector{transpose}, ov::ParameterVector{X});
}

model_ref = model->clone();

manager.register_pass<TSUnsqueezeBackward>();
}

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeForwardSameShape) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{4}, Shape{1, 3, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

const Shape input_shape = {4, 5, 6, 7};
{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto transpose = create_transpose(X);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{4}, Shape{5, 7, 4, 6});
auto reshape = std::make_shared<Reshape>(transpose, reshape_const, false);
model = std::make_shared<Model>(ov::OutputVector{reshape}, ov::ParameterVector{X});
}

{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{4}, Shape{5, 7, 4, 6});
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto indices = std::make_shared<Constant>(element::i32, Shape{4}, Shape{2, 0, 3, 1});
auto gather = std::make_shared<ov::op::v8::Gather>(reshape_const, indices, axis);
auto reshape = std::make_shared<Reshape>(X, gather, false);
auto transpose = create_transpose(reshape);
model_ref = std::make_shared<Model>(ov::OutputVector{transpose}, ov::ParameterVector{X});
}

manager.register_pass<TSUnsqueezeForward>();
}

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeForwardSameShapeSpecialOne) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{1, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

{
auto X = std::make_shared<Parameter>(element::f32, Shape{4, 5, 6});
auto transpose = create_transpose(X);
auto reshape_const = std::make_shared<Constant>(element::i64, Shape{3}, std::vector<int>{4, 5, -1});
auto reshape = std::make_shared<Reshape>(transpose, reshape_const, false);
model = std::make_shared<Model>(ov::OutputVector{reshape}, ov::ParameterVector{X});
}

model_ref = model->clone();

manager.register_pass<TSUnsqueezeForward>();
}

} // namespace common
} // namespace testing
} // namespace transpose_sinking
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(node0, ng_order0);

auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{2, 20, 55, 96});
auto reshape = std::make_shared<Reshape>(transpose0, reshape_const, false);

auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
Expand All @@ -399,7 +399,7 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(node0, ng_order0);

auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{2, 20, 55, 96});
auto reshape = std::make_shared<Reshape>(transpose0, reshape_const, false);

auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);
Expand Down

0 comments on commit 6489755

Please sign in to comment.