Skip to content

Commit

Permalink
add to forward transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Nov 13, 2024
1 parent a0c2cf3 commit 39264b2
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,22 @@ bool unsqueeze_axes_to_shape(const Output<Node>& input_node,
}
return true;
}

bool AreInputOutputShapesEqual(const std::shared_ptr<ov::op::v1::Reshape>& reshape) {
const auto input_shape = reshape->get_input_partial_shape(0);
const auto output_shape = reshape->get_output_partial_shape(0);

if (input_shape.is_dynamic() || output_shape.is_dynamic()) {
return false;
}
return input_shape == output_shape;
}

bool HasSpecialOne(const std::shared_ptr<ov::op::v0::Constant>& reshape_const) {
auto const_value = reshape_const->cast_vector<int64_t>();
return std::find(const_value.begin(), const_value.end(), -1) != const_value.end();
}

} // namespace

TSUnsqueezeForward::TSUnsqueezeForward() {
Expand All @@ -111,6 +127,28 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
if (!unsqueeze_axes) {
return false;
}
auto ts_order_values = transpose_info.transpose_const->cast_vector<size_t>();

// if main_node does nothing, just swap them
auto reshape = as_type_ptr<ov::op::v1::Reshape>(main_node);
if (reshape && AreInputOutputShapesEqual(reshape) && !HasSpecialOne(unsqueeze_axes)) {
TransposeInputsInfo transpose_input_info = {transpose_info.transpose, transpose_info.transpose_const, 0};
// remove input Transpose
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
if (!success) {
return false;
}

const auto reshape_order = ov::pass::transpose_sinking::utils::ReverseTransposeOrder(ts_order_values);
// transpose reshape const with Gather operation
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto gather =
ov::pass::transpose_sinking::utils::ChangeValuesOrder(reshape->input_value(1), reshape_order, axis);
main_node->input(1).replace_source_output(gather);

default_outputs_update(main_node, transpose_input_info);
return true;
}

std::vector<size_t> non_negative_axes;
if (as_type_ptr<ov::op::v1::Reshape>(main_node)) {
Expand All @@ -123,7 +161,6 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
non_negative_axes =
ov::util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node);
}
auto ts_order_values = transpose_info.transpose_const->cast_vector<size_t>();

ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
auto new_transpose_order = ov::op::v0::Constant::create(transpose_info.transpose_const->get_element_type(),
Expand Down Expand Up @@ -157,23 +194,6 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
transpose_sinking(matcher_name, sinking_transformation);
}

namespace {
bool AreInputOutputShapesEqual(const std::shared_ptr<ov::op::v1::Reshape>& reshape) {
const auto input_shape = reshape->get_input_partial_shape(0);
const auto output_shape = reshape->get_output_partial_shape(0);

if (input_shape.is_dynamic() || output_shape.is_dynamic()) {
return false;
}
return input_shape == output_shape;
}

bool HasSpecialOne(const std::shared_ptr<ov::op::v0::Constant>& reshape_const) {
auto const_value = reshape_const->cast_vector<int>();
return std::find(const_value.begin(), const_value.end(), -1) != const_value.end();
}
} // namespace

TSUnsqueezeBackward::TSUnsqueezeBackward() {
MATCHER_SCOPE(TSUnsqueezeBackward);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,54 @@ TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeBackwardSameS
manager.register_pass<TSUnsqueezeBackward>();
}

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeForwardSameShape) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{4}, Shape{1, 3, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

const Shape input_shape = {4, 5, 6, 7};
{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto transpose = create_transpose(X);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{4}, Shape{5, 7, 4, 6});
auto reshape = std::make_shared<Reshape>(transpose, reshape_const, false);
model = std::make_shared<Model>(ov::OutputVector{reshape}, ov::ParameterVector{X});
}

{
auto X = std::make_shared<Parameter>(element::f32, input_shape);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{4}, Shape{5, 7, 4, 6});
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto indices = std::make_shared<Constant>(element::i32, Shape{4}, Shape{2, 0, 3, 1});
auto gather = std::make_shared<ov::op::v8::Gather>(reshape_const, indices, axis);
auto reshape = std::make_shared<Reshape>(X, gather, false);
auto transpose = create_transpose(reshape);
model_ref = std::make_shared<Model>(ov::OutputVector{transpose}, ov::ParameterVector{X});
}

manager.register_pass<TSUnsqueezeForward>();
}

TEST_F(TransformationTestsF, TransposeSinkingCommonReshapeUnsqueezeForwardSameShapeSpecialOne) {
auto create_transpose = [](const std::shared_ptr<ov::Node>& parent) {
auto ts_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{1, 0, 2});
return std::make_shared<Transpose>(parent, ts_order);
};

{
auto X = std::make_shared<Parameter>(element::f32, Shape{4, 5, 6});
auto transpose = create_transpose(X);
auto reshape_const = std::make_shared<Constant>(element::i64, Shape{3}, std::vector<int>{4, 5, -1});
auto reshape = std::make_shared<Reshape>(transpose, reshape_const, false);
model = std::make_shared<Model>(ov::OutputVector{reshape}, ov::ParameterVector{X});
}

model_ref = model->clone();

manager.register_pass<TSUnsqueezeForward>();
}

} // namespace common
} // namespace testing
} // namespace transpose_sinking

0 comments on commit 39264b2

Please sign in to comment.