From e04dbdec382cb07994be52b2e5c1e4aa875443a0 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Thu, 16 Jan 2025 11:32:13 +0100 Subject: [PATCH] [PT FE] Improve support for complex data type Signed-off-by: Maxim Vafin --- src/frontends/pytorch/src/frontend.cpp | 9 +- src/frontends/pytorch/src/op/complex.cpp | 84 +++++++ src/frontends/pytorch/src/op/fft.cpp | 208 ++++++++++++++++++ src/frontends/pytorch/src/op/permute.cpp | 35 ++- src/frontends/pytorch/src/op/reshape.cpp | 26 ++- src/frontends/pytorch/src/op/size.cpp | 23 +- src/frontends/pytorch/src/op/stft.cpp | 9 +- src/frontends/pytorch/src/op_table.cpp | 21 +- .../transforms/irfftn_complex_replacer.cpp | 164 -------------- .../transforms/irfftn_complex_replacer.hpp | 24 -- .../src/transforms/rfftn_complex_replacer.cpp | 163 -------------- .../src/transforms/rfftn_complex_replacer.hpp | 24 -- src/frontends/pytorch/src/utils.cpp | 23 +- src/frontends/pytorch/src/utils.hpp | 4 +- .../layer_tests/pytorch_tests/test_permute.py | 43 ++-- .../layer_tests/pytorch_tests/test_reshape.py | 44 ++-- tests/layer_tests/pytorch_tests/test_size.py | 30 ++- tests/layer_tests/pytorch_tests/test_stft.py | 12 +- 18 files changed, 496 insertions(+), 450 deletions(-) create mode 100644 src/frontends/pytorch/src/op/complex.cpp create mode 100644 src/frontends/pytorch/src/op/fft.cpp delete mode 100644 src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp delete mode 100644 src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp delete mode 100644 src/frontends/pytorch/src/transforms/rfftn_complex_replacer.cpp delete mode 100644 src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index 04ba9a9c92c281..bb69e8fa313130 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -30,7 +30,6 @@ #include "transforms/dict_resolver.hpp" #include "transforms/einsum_list_construct.hpp" #include "transforms/index_loop_getitem_replacer.hpp" -#include "transforms/irfftn_complex_replacer.hpp" #include "transforms/listconstruct_replacer.hpp" #include "transforms/min_max_prim_list_construct_replacer.hpp" #include "transforms/prim_list_construct_pad.hpp" @@ -40,7 +39,6 @@ #include "transforms/quantized_node_remover.hpp" #include "transforms/remove_packing_ops.hpp" #include "transforms/reverseprop_resolver.hpp" -#include "transforms/rfftn_complex_replacer.hpp" #include "transforms/softmax_reshape_elimination.hpp" #include "transforms/string_equality_replacer.hpp" #include "transforms/torchfx_gptq_pattern_replacer.hpp" @@ -69,6 +67,11 @@ std::map get_unconverted_types_from_model(const std::s if (!unconverted_ops_types.count(op_type_it->second)) { unconverted_ops_types.emplace(op_type_it->second, std::move(exception_msg)); } + } else if (const auto& fw_node = ov::as_type_ptr(node)) { + auto op_type = std::string(fw_node->get_type_name()); + if (!unconverted_ops_types.count(op_type)) { + unconverted_ops_types.emplace(op_type, "This is OpenVINO internal type."); + } } if (const auto& fw_node = ov::as_type_ptr(node)) { for (size_t i = 0; i < fw_node->get_internal_subgraphs_size(); ++i) { @@ -283,8 +286,6 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); - manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/pytorch/src/op/complex.cpp b/src/frontends/pytorch/src/op/complex.cpp new file mode 100644 index 00000000000000..8ec0f5435e358b --- /dev/null +++ b/src/frontends/pytorch/src/op/complex.cpp @@ -0,0 +1,84 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/complex_type_mark.hpp" +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/split.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_complex(const NodeContext& context) { + num_inputs_check(context, 2, 2); + auto real = context.get_input(0); + auto imag = context.get_input(1); + + auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + real = context.mark_node(std::make_shared(real, const_neg_1)); + imag = context.mark_node(std::make_shared(imag, const_neg_1)); + + auto complex = context.mark_node(std::make_shared(OutputVector{real, imag}, -1)); + + return {context.mark_node(std::make_shared(complex, complex->get_element_type()))}; +}; + +OutputVector translate_imag(const NodeContext& context) { + num_inputs_check(context, 1, 1, true); + auto complex = context.get_input(0); + + auto complex_type_mark = as_type_ptr(complex.get_node_shared_ptr()); + PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::imag operation expects complex type tensor on input."); + + complex = complex_type_mark->input_value(0); + auto axis = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + auto imag = context.mark_node(std::make_shared(complex, axis, 2))->output(1); + + auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + return {context.mark_node(std::make_shared(imag, const_neg_1))}; +}; + +OutputVector translate_real(const NodeContext& context) { + num_inputs_check(context, 1, 1, true); + auto complex = context.get_input(0); + + auto complex_type_mark = as_type_ptr(complex.get_node_shared_ptr()); + PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::real operation expects complex type tensor on input."); + + complex = complex_type_mark->input_value(0); + auto axis = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + auto real = context.mark_node(std::make_shared(complex, axis, 2))->output(0); + + auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + return {context.mark_node(std::make_shared(real, const_neg_1))}; +}; + +OutputVector translate_view_as_real(const NodeContext& context) { + num_inputs_check(context, 1, 1, true); + auto complex = context.get_input(0); + + auto complex_type_mark = as_type_ptr(complex.get_node_shared_ptr()); + PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::real operation expects complex type tensor on input."); + + return {complex_type_mark->input_value(0)}; +}; + +OutputVector translate_view_as_complex(const NodeContext& context) { + num_inputs_check(context, 1, 1); + auto complex = context.get_input(0); + + return {context.mark_node(std::make_shared(complex, complex.get_element_type()))}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp new file mode 100644 index 00000000000000..0c2eb17c49d305 --- /dev/null +++ b/src/frontends/pytorch/src/op/fft.cpp @@ -0,0 +1,208 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/complex_type_mark.hpp" +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/equal.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/irdft.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/rdft.hpp" +#include "openvino/op/reduce_prod.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_update.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/sqrt.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/subtract.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_fft_rfftn(const NodeContext& context) { + // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + num_inputs_check(context, 1, 4); + auto input = context.get_input(0); + + auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); + + Output input_shape; + Output input_rank_scalar; + std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, input, true); + + Output raw_s; + // Inputs can be either none or List. Check whether input values should be used or should be set to default values. + if (!context.input_is_none(1)) { + // s is provided, load from input. + raw_s = get_input_concat_if_list(context, 1); + raw_s = context.mark_node(std::make_shared(raw_s, element::i32)); + } + Output dim; + // Handle dim parameter containing vector of integers indicating dimensions to be transformed. + if (!context.input_is_none(2)) { + // dim is provided, load from input. + dim = get_input_concat_if_list(context, 2); + dim = context.mark_node(std::make_shared(dim, element::i32)); + } else if (!context.input_is_none(1)) { + // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. + auto s_len = context.mark_node(std::make_shared(raw_s, element::i32)); + auto slice_start = context.mark_node(std::make_shared(input_rank_scalar, s_len)); + auto slice_start_scalar = context.mark_node(std::make_shared(slice_start)); + dim = context.mark_node( + std::make_shared(slice_start_scalar, input_rank_scalar, const_1, element::i32)); + } else { + // Dim and s are set to default, use all of dimensions. + dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32)); + } + + Output s; + if (context.input_is_none(1)) { + // Value for s was set to default, use full size for all dimensions. + s = context.mark_node(std::make_shared(input_shape, dim, const_0)); + } else { + // Values for s were provided. Replace -1 values with default full size in given dimension. + auto full_s_cond = context.mark_node(std::make_shared(raw_s, const_neg_1)); + auto full_s_values = context.mark_node(std::make_shared(input_shape, dim, const_0)); + s = context.mark_node(std::make_shared(full_s_cond, full_s_values, raw_s)); + } + + // Handle norm parameter indicating normalization mode to use. Defaults to "backward". + std::string norm = "backward"; + if (!context.input_is_none(3)) { + norm = context.const_input(3); + } + + auto rdft = context.mark_node(std::make_shared(input, dim, s)); + + // Apply normalizations + auto n_int = context.mark_node(std::make_shared(s, const_0)); + auto n = context.mark_node(std::make_shared(n_int, rdft)); + Output normalized_rfftn; + if (norm == "forward") { + // Normalize by 1/n + normalized_rfftn = context.mark_node(std::make_shared(rdft, n)); + } else if (norm == "backward") { + // No normalization + normalized_rfftn = rdft; + } else if (norm == "ortho") { + // Normalize by 1/sqrt(n) + auto sqrt_n = context.mark_node(std::make_shared(n)); + normalized_rfftn = context.mark_node(std::make_shared(rdft, sqrt_n)); + } else { + FRONT_END_THROW( + "aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported."); + } + + return {std::make_shared(normalized_rfftn, normalized_rfftn.get_element_type())}; +} + +OutputVector translate_fft_irfftn(const NodeContext& context) { + // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + num_inputs_check(context, 1, 4, true); + auto input = context.get_input(0); + + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::fft_irfftn operation expects complex type tensor on input."); + input = complex_type_mark->input_value(0); + + auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); + auto const_scalar_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); + auto const_scalar_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); + auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2})); + + // Input shape of complex number (excluding dimension created by concatenation of real and imag) + auto complex_input_shape = get_complex_shape(context, input); + auto input_rank = context.mark_node(std::make_shared(complex_input_shape, element::i32)); + auto input_rank_scalar = context.mark_node(std::make_shared(input_rank)); + + Output raw_s; + // Inputs can be either none or List. Check whether input values should be used or should be set to default values. + if (!context.input_is_none(1)) { + // s is provided, load from input. + raw_s = get_input_concat_if_list(context, 1); + raw_s = context.mark_node(std::make_shared(raw_s, element::i32)); + } + + // Handle dim parameter containing vector of integers indicating dimensions to be transformed. + Output dim; + if (!context.input_is_none(2)) { + // Dim values is provided, load from input. + dim = get_input_concat_if_list(context, 2); + dim = context.mark_node(std::make_shared(dim, element::i32)); + } else if (!context.input_is_none(1)) { + // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. + auto s_len = context.mark_node(std::make_shared(raw_s, element::i32)); + auto range_start = context.mark_node(std::make_shared(input_rank, s_len)); + auto range_start_scalar = context.mark_node(std::make_shared(range_start)); + dim = context.mark_node( + std::make_shared(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32)); + } else { + // Dim and s are set to default, use all of dimensions. + dim = context.mark_node( + std::make_shared(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32)); + } + + // Calculate default s values. Use full available size except last element, which is set to even value in last + // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]]) + auto default_s_raw = context.mark_node(std::make_shared(complex_input_shape, dim, const_0)); + auto last_s = context.mark_node(std::make_shared(default_s_raw, const_neg_1, const_0)); + auto last_s_m_1 = context.mark_node(std::make_shared(last_s, const_1)); + auto s_upd = context.mark_node(std::make_shared(last_s_m_1, const_2)); + auto s_shape = context.mark_node(std::make_shared(default_s_raw, element::i32)); + auto last_s_idx = context.mark_node(std::make_shared(s_shape, const_1)); + auto default_s = context.mark_node(std::make_shared(default_s_raw, last_s_idx, s_upd, const_0)); + + // Handle s parameter containing vector of intigers indicating signal sizes for dimensions. + Output s; + if (!context.input_is_none(1)) { + // Values for s were provided. Replace -1 values with default full size in given dimension. + auto full_s_cond = context.mark_node(std::make_shared(raw_s, const_neg_1)); + s = context.mark_node(std::make_shared(full_s_cond, default_s, raw_s)); + } else { + // Value for s was set to default. + s = default_s; + } + + // Handle norm parameter indicating normalization mode to use. Defaults to "backward". + std::string norm = "backward"; + if (!context.input_is_none(3)) { + norm = context.const_input(3); + } + + auto irdft = context.mark_node(std::make_shared(input, dim, s)); + + // Apply normalizations. + auto n_int = context.mark_node(std::make_shared(s, const_0)); + auto n = context.mark_node(std::make_shared(n_int, irdft)); + Output normalized_irfftn; + if (norm == "forward") { + normalized_irfftn = context.mark_node(std::make_shared(irdft, n)); + } else if (norm == "backward") { + normalized_irfftn = irdft; + } else if (norm == "ortho") { + auto sqrt_n = context.mark_node(std::make_shared(n)); + normalized_irfftn = context.mark_node(std::make_shared(irdft, sqrt_n)); + } else { + FRONT_END_THROW( + "aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported."); + } + return {normalized_irfftn}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/permute.cpp b/src/frontends/pytorch/src/op/permute.cpp index 46016ca8ca16a0..c724e38b8077b2 100644 --- a/src/frontends/pytorch/src/op/permute.cpp +++ b/src/frontends/pytorch/src/op/permute.cpp @@ -3,7 +3,10 @@ // #include "openvino/core/validation_util.hpp" +#include "openvino/frontend/complex_type_mark.hpp" #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "utils.hpp" @@ -12,17 +15,41 @@ namespace frontend { namespace pytorch { namespace op { +using namespace ov::op; + OutputVector translate_permute(const NodeContext& context) { - num_inputs_check(context, 2, 2); + num_inputs_check(context, 2, 2, true); auto data = context.get_input(0); auto order = get_input_concat_if_list(context, 1); - auto rank = std::get<1>(get_shape_rank(context, data)); - auto rank_converted = context.mark_node(std::make_shared(rank, order)); + + Output rank; + auto complex_type_mark = as_type_ptr(data.get_node_shared_ptr()); + if (complex_type_mark) { + data = complex_type_mark->input_value(0); + rank = std::get<1>(get_shape_rank(context, data)); + auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); + rank = context.mark_node(std::make_shared(rank, const_1)); + } else { + rank = std::get<1>(get_shape_rank(context, data)); + } + + auto rank_converted = context.mark_node(std::make_shared(rank, order)); auto order_normalized = normalize_axis(context, order, rank_converted); + + if (complex_type_mark) { + auto to_concat = OutputVector{order_normalized, rank_converted}; + order_normalized = context.mark_node(std::make_shared(to_concat, 0)); + } + if (const auto order_const = ov::util::get_constant_from_source(order_normalized)) { order_normalized = order_const; } - return {context.mark_node(std::make_shared(data, order_normalized))}; + auto permute = context.mark_node(std::make_shared(data, order_normalized)); + if (complex_type_mark) { + const auto& complex_dtype = complex_type_mark->get_complex_part_type(); + permute = context.mark_node(std::make_shared(permute, complex_dtype)); + } + return {permute}; } } // namespace op diff --git a/src/frontends/pytorch/src/op/reshape.cpp b/src/frontends/pytorch/src/op/reshape.cpp index 7524d0e3c4aaf4..b9dcfc8d9afc4a 100644 --- a/src/frontends/pytorch/src/op/reshape.cpp +++ b/src/frontends/pytorch/src/op/reshape.cpp @@ -4,6 +4,7 @@ #include "openvino/op/reshape.hpp" +#include "openvino/frontend/complex_type_mark.hpp" #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/squeeze.hpp" @@ -15,15 +16,34 @@ namespace frontend { namespace pytorch { namespace op { +using namespace ov::op; + OutputVector translate_reshape(const NodeContext& context) { // Translation is used by both aten::view and aten::reshape. // Schema: aten::view(Tensor input, int[] shape) -> Tensor // Schema: aten::reshape(Tensor input, int[] shape) -> Tensor // For shape parameter, int[] is converted into single dimensional Tensor. - num_inputs_check(context, 2, 2); + num_inputs_check(context, 2, 2, true); + auto tensor = context.get_input(0); auto shape = get_input_concat_if_list(context, 1); - auto reshape = std::make_shared(context.get_input(0), shape, false); - return {context.mark_node(reshape)}; + + auto complex_type_mark = as_type_ptr(tensor.get_node_shared_ptr()); + if (complex_type_mark) { + tensor = complex_type_mark->input_value(0); + auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2})); + const_2 = context.mark_node(std::make_shared(const_2, shape)); + + shape = context.mark_node(std::make_shared(OutputVector{shape, const_2}, 0)); + } + + auto reshape = context.mark_node(std::make_shared(tensor, shape, false)); + + if (complex_type_mark) { + const auto& complex_dtype = complex_type_mark->get_complex_part_type(); + return {context.mark_node(std::make_shared(reshape, complex_dtype))}; + } else { + return {reshape}; + } }; } // namespace op diff --git a/src/frontends/pytorch/src/op/size.cpp b/src/frontends/pytorch/src/op/size.cpp index d8f1ee28123c10..2eca5f2707e53d 100644 --- a/src/frontends/pytorch/src/op/size.cpp +++ b/src/frontends/pytorch/src/op/size.cpp @@ -2,10 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/frontend/complex_type_mark.hpp" #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/gather.hpp" #include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" #include "utils.hpp" namespace ov { @@ -16,10 +18,25 @@ namespace op { using namespace ov::op; OutputVector translate_size(const NodeContext& context) { - num_inputs_check(context, 1, 2); - auto shape = context.mark_node(std::make_shared(context.get_input(0), element::i64)); + num_inputs_check(context, 1, 2, true); + auto data = context.get_input(0); + Output shape; + + auto complex_type_mark = as_type_ptr(data.get_node_shared_ptr()); + if (complex_type_mark) { + data = complex_type_mark->input_value(0); + shape = context.mark_node(std::make_shared(data, element::i64)); + + auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); + auto stop = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + auto step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); + shape = context.mark_node(std::make_shared(shape, zero, stop, step, zero)); + } else { + shape = context.mark_node(std::make_shared(data, element::i64)); + } + if (context.input_is_none(1)) { - return shape->outputs(); + return {shape}; } else { auto axis_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); return {context.mark_node(std::make_shared(shape, context.get_input(1), axis_0))}; diff --git a/src/frontends/pytorch/src/op/stft.cpp b/src/frontends/pytorch/src/op/stft.cpp index 8e478835fdcdd6..678f44dcbe1edf 100644 --- a/src/frontends/pytorch/src/op/stft.cpp +++ b/src/frontends/pytorch/src/op/stft.cpp @@ -4,6 +4,7 @@ #include "openvino/op/stft.hpp" +#include "openvino/frontend/complex_type_mark.hpp" #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/broadcast.hpp" #include "openvino/op/constant.hpp" @@ -78,8 +79,6 @@ OutputVector translate_stft(const NodeContext& context) { if (!context.input_is_none(7)) { return_complex = context.const_input(7); } - PYTORCH_OP_CONVERSION_CHECK(!return_complex, - "aten::stft conversion is currently supported with return_complex=False only."); // Perform STFT constexpr bool transpose_frames = true; @@ -88,8 +87,10 @@ OutputVector translate_stft(const NodeContext& context) { if (normalized) { const auto nfft_convert = context.mark_node(std::make_shared(n_fft, stft)); const auto divisor = context.mark_node(std::make_shared(nfft_convert)); - const auto norm_stft = context.mark_node(std::make_shared(stft, divisor)); - return {norm_stft}; + stft = context.mark_node(std::make_shared(stft, divisor)); + } + if (return_complex) { + return {context.mark_node(std::make_shared(stft, stft->get_element_type()))}; } else { return {stft}; } diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index fe4e84bd47d45e..f00391e08e2a32 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -59,6 +59,7 @@ OP_CONVERTER(translate_celu); OP_CONVERTER(translate_channel_shuffle); OP_CONVERTER(translate_clamp); OP_CONVERTER(translate_col2im); +OP_CONVERTER(translate_complex); OP_CONVERTER(translate_constant); OP_CONVERTER(translate_conv_transposend); OP_CONVERTER(translate_convnd); @@ -86,6 +87,8 @@ OP_CONVERTER(translate_expm1); OP_CONVERTER(translate_eye); OP_CONVERTER(translate_fake_quantize_per_channel_affine); OP_CONVERTER(translate_fake_quantize_per_tensor_affine); +OP_CONVERTER(translate_fft_irfftn); +OP_CONVERTER(translate_fft_rfftn); OP_CONVERTER(translate_fill); OP_CONVERTER(translate_fill_diagonal); OP_CONVERTER(translate_flatten); @@ -108,6 +111,7 @@ OP_CONVERTER(translate_hann_window); OP_CONVERTER(translate_hardtanh); OP_CONVERTER(translate_if); OP_CONVERTER(translate_im2col); +OP_CONVERTER(translate_imag); OP_CONVERTER(translate_index); OP_CONVERTER(translate_index_add); OP_CONVERTER(translate_index_copy_); @@ -192,6 +196,7 @@ OP_CONVERTER(translate_randn); OP_CONVERTER(translate_randint); OP_CONVERTER(translate_rand_like); OP_CONVERTER(translate_randn_like); +OP_CONVERTER(translate_real); OP_CONVERTER(translate_reciprocal); OP_CONVERTER(translate_relu6); OP_CONVERTER(translate_remainder); @@ -246,6 +251,8 @@ OP_CONVERTER(translate_upsample_nearest3d); OP_CONVERTER(translate_upsample_trilinear3d); OP_CONVERTER(translate_var); OP_CONVERTER(translate_var_mean); +OP_CONVERTER(translate_view_as_complex); +OP_CONVERTER(translate_view_as_real); OP_CONVERTER(translate_weight_norm); OP_CONVERTER(translate_where); OP_CONVERTER(translate_zeros); @@ -423,7 +430,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::clip", op::translate_clamp}, {"aten::clone", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd {"aten::col2im", op::translate_col2im}, - // aten::complex - Supported in limited set of patterns + {"aten::complex", op::translate_complex}, {"aten::concat", op::translate_cat}, {"aten::contiguous", op::skip_node}, // In openvino how tensors are stored in memory is internal plugin detail, // we assume all tensors are contiguous @@ -468,8 +475,8 @@ const std::unordered_map get_supported_ops_ts() { {"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine}, {"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine}, {"aten::feature_dropout", op::skip_node}, - // aten::fft_irfftn - Supported in limited set of patterns - // aten::fft_rfftn - Supported in limited set of patterns + {"aten::fft_irfftn", op::translate_fft_irfftn}, + {"aten::fft_rfftn", op::translate_fft_rfftn}, {"aten::fill", op::translate_fill}, {"aten::fill_diagonal", op::translate_fill_diagonal}, {"aten::flatten", op::quantizable_op}, @@ -496,7 +503,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::hardswish", op::quantizable_op>}, {"aten::hardtanh", op::quantizable_op}, {"aten::im2col", op::translate_im2col}, - // aten::imag - Supported in limited set of patterns + {"aten::imag", op::translate_imag}, // aten::index - Supported in limited set of patterns {"aten::index_copy_", op::inplace_op}, {"aten::index_fill_", op::inplace_op}, @@ -604,7 +611,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::randint", op::translate_randint}, {"aten::randn", op::translate_randn}, {"aten::randn_like", op::translate_randn_like}, - // aten::real - Supported in limited set of patterns + {"aten::real", op::translate_real}, {"aten::reciprocal", op::optional_out}, {"aten::reciprocal_", op::inplace_op}, // aten::reflection_pad2d - Supported in limited set of patterns @@ -696,6 +703,8 @@ const std::unordered_map get_supported_ops_ts() { {"aten::var_mean", op::translate_var_mean}, {"aten::view", op::quantizable_op}, {"aten::view_as", op::translate_reshape_as}, + {"aten::view_as_complex", op::translate_view_as_complex}, + {"aten::view_as_real", op::translate_view_as_real}, {"aten::wait", op::skip_node}, {"aten::where", op::translate_where}, {"aten::zero", op::translate_zeros_like}, @@ -979,6 +988,8 @@ const std::unordered_map get_supported_ops_fx() { {"aten.var.correction", op::translate_var_fx}, {"aten.var_mean.correction", op::translate_var_mean_fx}, {"aten.view.default", op::translate_reshape}, + {"aten.view_as_complex.default", op::translate_view_as_complex}, + {"aten.view_as_real.default", op::translate_view_as_real}, {"aten.where.self", op::translate_where}, {"aten.zeros.default", op::translate_zeros_fx}, {"aten.zeros.names", op::translate_zeros_fx}, diff --git a/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp b/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp deleted file mode 100644 index cb80987e4511ae..00000000000000 --- a/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "irfftn_complex_replacer.hpp" - -#include "openvino/core/rt_info.hpp" -#include "openvino/op/concat.hpp" -#include "openvino/op/convert.hpp" -#include "openvino/op/convert_like.hpp" -#include "openvino/op/equal.hpp" -#include "openvino/op/gather.hpp" -#include "openvino/op/irdft.hpp" -#include "openvino/op/multiply.hpp" -#include "openvino/op/range.hpp" -#include "openvino/op/reduce_prod.hpp" -#include "openvino/op/scatter_update.hpp" -#include "openvino/op/select.hpp" -#include "openvino/op/shape_of.hpp" -#include "openvino/op/sqrt.hpp" -#include "openvino/op/squeeze.hpp" -#include "openvino/op/subtract.hpp" -#include "openvino/op/unsqueeze.hpp" -#include "openvino/op/util/framework_node.hpp" -#include "openvino/pass/pattern/matcher.hpp" -#include "openvino/pass/pattern/op/wrap_type.hpp" -#include "utils.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace pass { - -using namespace ov::pass; -using namespace ov::op; - -IRFFTNComplexReplacer::IRFFTNComplexReplacer() { - // Transformation used to replace combination of aten::complex -> aten::fft_irfftn torch operators. - // Pattern: aten::complex -> aten::fft_irfftn - auto fft_op = pattern::wrap_type(); - - ov::matcher_pass_callback irfftn_callback = [](pattern::Matcher& m) { - // "aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor" - auto irfftn_op = cast_fw_node(m.get_match_root(), "aten::fft_irfftn"); - if (!irfftn_op) { - return false; - } - auto const_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1}); - auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0}); - auto const_scalar_0 = v0::Constant::create(element::i32, Shape{}, {0}); - auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1}); - auto const_scalar_1 = v0::Constant::create(element::i32, Shape{}, {1}); - auto const_2 = v0::Constant::create(element::i32, Shape{1}, {2}); - - // Check whether input node being aten::complex. - auto fw_node_complex_input = cast_fw_node(irfftn_op->input_value(0).get_node_shared_ptr(), "aten::complex"); - if (!fw_node_complex_input) { - return false; - } - - // Concatenate real and imag parts over additional, last dimension. - auto real = std::make_shared(fw_node_complex_input->input_value(0), const_neg_1); - auto imag = std::make_shared(fw_node_complex_input->input_value(1), const_neg_1); - NodeVector complex = {real, imag}; - auto input = std::make_shared(complex, -1); - - // Input shape of complex number (excluding dimension created by concatenation of real and imag) - auto complex_input_shape = std::make_shared(fw_node_complex_input->input_value(0), element::i32); - auto input_rank = std::make_shared(complex_input_shape, element::i32); - auto input_rank_scalar = std::make_shared(input_rank); - - // Inputs can be either none or ListConstruct. Check whether input values should be used or should be set to - // default values. - bool dim_use_default = is_none_node(irfftn_op->input_value(2)); - bool s_use_default = is_none_node(irfftn_op->input_value(1)); - // Can be None constant, when used check s_use_default. - auto raw_s_input_maybe = concat_list_construct(irfftn_op->input_value(1)); - raw_s_input_maybe = std::make_shared(raw_s_input_maybe, element::i32); - - // Handle dim parameter containing vector of integers indicating dimensions to be transformed. - std::shared_ptr dim; - if (!dim_use_default) { - // Dim values is provided, load from input. - dim = std::make_shared(concat_list_construct(irfftn_op->input_value(2)), element::i32); - } else if (!s_use_default) { - // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. - auto s_len = std::make_shared(raw_s_input_maybe, element::i32); - auto range_start = std::make_shared(input_rank, s_len); - auto range_start_scalar = std::make_shared(range_start); - dim = std::make_shared(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32); - } else { - // Dim and s are set to default, use all of dimensions. - dim = std::make_shared(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32); - } - - // Calculate default s values. Use full available size except last element, which is set to even value in last - // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]]) - auto default_s_raw = std::make_shared(complex_input_shape, dim, const_0); - auto last_s = std::make_shared(default_s_raw, const_neg_1, const_0); - auto last_s_m_1 = std::make_shared(last_s, const_1); - auto s_upd = std::make_shared(last_s_m_1, const_2); - auto s_shape = std::make_shared(default_s_raw, element::i32); - auto last_s_idx = std::make_shared(s_shape, const_1); - auto default_s = std::make_shared(default_s_raw, last_s_idx, s_upd, const_0); - - // Handle s parameter containing vector of intigers indicating signal sizes for dimensions. - std::shared_ptr s; - if (!s_use_default) { - // Values for s were provided. Replace -1 values with default full size in given dimension. - auto full_s_cond = std::make_shared(raw_s_input_maybe, const_neg_1); - s = std::make_shared(full_s_cond, default_s, raw_s_input_maybe); - } else { - // Value for s was set to default. - s = default_s; - } - - // Handle norm parameter indicating normalization mode to use. Defaults to "backward". - std::string norm; - if (const auto& fw_node_mode = - ov::as_type_ptr(irfftn_op->input_value(3).get_node_shared_ptr())) { - const auto& attrs = fw_node_mode->get_attrs(); - if (attrs.find("string_value") != attrs.end()) { - norm = attrs.at("string_value"); - } else { - norm = "backward"; - } - } else { - add_exception_to_fw_node(irfftn_op, "aten::fft_irfftn: could not retrive value for norm attribute."); - return false; - } - - auto irdft = std::make_shared(input, dim, s); - - // Apply normalizations. - auto n_int = std::make_shared(s, const_0); - auto n = std::make_shared(n_int, irdft); - std::shared_ptr normalized_irfftn; - if (norm == "forward") { - normalized_irfftn = std::make_shared(irdft, n); - } else if (norm == "backward") { - normalized_irfftn = irdft; - } else if (norm == "ortho") { - auto sqrt_n = std::make_shared(n); - normalized_irfftn = std::make_shared(irdft, sqrt_n); - } else { - add_exception_to_fw_node( - irfftn_op, - "aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported."); - return false; - } - - copy_runtime_info({irfftn_op, fw_node_complex_input}, normalized_irfftn); - normalized_irfftn->set_friendly_name(irfftn_op->get_friendly_name()); - replace_node(irfftn_op, normalized_irfftn); - return true; - }; - auto m = std::make_shared(fft_op, "ov::frontend::pytorch::pass::IRFFTNComplexReplacer"); - this->register_matcher(m, irfftn_callback); -}; - -} // namespace pass -} // namespace pytorch -} // namespace frontend -} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp b/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp deleted file mode 100644 index c75c6e51f92571..00000000000000 --- a/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pass.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace pass { - -class IRFFTNComplexReplacer : public ov::pass::MatcherPass { -public: - OPENVINO_MATCHER_PASS_RTTI("ov::frontend::pytorch::pass::IRFFTNComplexReplacer"); - IRFFTNComplexReplacer(); -}; - -} // namespace pass -} // namespace pytorch -} // namespace frontend -} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.cpp b/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.cpp deleted file mode 100644 index b90e3121930c71..00000000000000 --- a/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.cpp +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "rfftn_complex_replacer.hpp" - -#include "openvino/core/rt_info.hpp" -#include "openvino/op/convert.hpp" -#include "openvino/op/convert_like.hpp" -#include "openvino/op/divide.hpp" -#include "openvino/op/equal.hpp" -#include "openvino/op/gather.hpp" -#include "openvino/op/range.hpp" -#include "openvino/op/rdft.hpp" -#include "openvino/op/reduce_prod.hpp" -#include "openvino/op/select.hpp" -#include "openvino/op/shape_of.hpp" -#include "openvino/op/slice.hpp" -#include "openvino/op/split.hpp" -#include "openvino/op/sqrt.hpp" -#include "openvino/op/squeeze.hpp" -#include "openvino/op/subtract.hpp" -#include "openvino/op/util/framework_node.hpp" -#include "openvino/pass/pattern/matcher.hpp" -#include "openvino/pass/pattern/op/wrap_type.hpp" -#include "utils.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace pass { - -using namespace ov::pass; -using namespace ov::op; - -RFFTNComplexReplacer::RFFTNComplexReplacer() { - // Transformation used to replace combination of aten::fft_rfftn -> {aten::real, aten::imag} torch operators. - // Pattern: aten::fft_rfftn -> {aten::real, aten::imag} - auto fft_op = pattern::wrap_type(); - ov::matcher_pass_callback rfftn_callback = [](pattern::Matcher& m) { - // Schema: "aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor" - auto rfftn_op = cast_fw_node(m.get_match_root(), "aten::fft_rfftn"); - if (!rfftn_op) { - return false; - } - auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1}); - auto const_0 = v0::Constant::create(element::i32, Shape{}, {0}); - auto const_1 = v0::Constant::create(element::i32, Shape{}, {1}); - - auto input = rfftn_op->input_value(0); - auto input_shape = std::make_shared(input, element::i32); - auto input_rank = std::make_shared(input_shape, element::i32); - auto input_rank_scalar = std::make_shared(input_rank); - - // Inputs can be either none or ListConstruct. Check whether input values should be used or should be set to - // default values. - bool dim_use_default = is_none_node(rfftn_op->input_value(2)); - bool s_use_default = is_none_node(rfftn_op->input_value(1)); - // Can be None constant, when used check s_use_default. - auto raw_s_input_maybe = concat_list_construct(rfftn_op->input_value(1)); - raw_s_input_maybe = std::make_shared(raw_s_input_maybe, element::i32); - - // Handle dim parameter containing vector of intigers indicating dimensions to be transformed. - std::shared_ptr dim; - if (!dim_use_default) { - // Dim values is provided, load from input. - dim = std::make_shared(concat_list_construct(rfftn_op->input_value(2)), element::i32); - } else if (!s_use_default) { - // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. - auto s_len = std::make_shared(raw_s_input_maybe, element::i32); - auto slice_start = std::make_shared(input_rank, s_len); - auto slice_start_scalar = std::make_shared(slice_start); - dim = std::make_shared(slice_start_scalar, input_rank_scalar, const_1, element::i32); - } else { - // Dim and s are set to default, use all of dimensions. - dim = std::make_shared(const_0, input_rank_scalar, const_1, element::i32); - } - - // Handle s parameter containing vector of intigers indicating signal sizes for dimensions. - std::shared_ptr s; - if (!s_use_default) { - // Values for s were provided. Replace -1 values with default full size in given dimension. - auto full_s_cond = std::make_shared(raw_s_input_maybe, const_neg_1); - auto full_s_values = std::make_shared(input_shape, dim, const_0); - s = std::make_shared(full_s_cond, full_s_values, raw_s_input_maybe); - } else { - // Value for s was set to default, use full size for all dimensions. - s = std::make_shared(input_shape, dim, const_0); - } - - // Handle norm parameter indicating normalization mode to use. Defaults to "backward". - std::string norm; - if (const auto& fw_node_mode = - ov::as_type_ptr(rfftn_op->input_value(3).get_node_shared_ptr())) { - const auto& attrs = fw_node_mode->get_attrs(); - if (attrs.find("string_value") != attrs.end()) { - norm = attrs.at("string_value"); - } else { - norm = "backward"; - } - } else { - add_exception_to_fw_node(rfftn_op, "aten::fft_rfftn: could not retrive value for norm attribute."); - return false; - } - - auto rdft = std::make_shared(input, dim, s); - - // Apply normalizations - auto n_int = std::make_shared(s, const_0); - auto n = std::make_shared(n_int, rdft); - std::shared_ptr normalized_rfftn; - if (norm == "forward") { - // Normalize by 1/n - normalized_rfftn = std::make_shared(rdft, n); - } else if (norm == "backward") { - // No normalization - normalized_rfftn = rdft; - } else if (norm == "ortho") { - // Normalize by 1/sqrt(n) - auto sqrt_n = std::make_shared(n); - normalized_rfftn = std::make_shared(rdft, sqrt_n); - } else { - add_exception_to_fw_node( - rfftn_op, - "aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported."); - return false; - } - - // Replace outputs that are either torch operators aten::real or aten::imag. Apply squeeze to remove last - // dimension used to concatenate. - auto normalized_rfftn_splitted = std::make_shared(normalized_rfftn, const_neg_1, 2); - auto rfftn_outs = rfftn_op->get_users(); - bool rval = false; - for (auto& out : rfftn_outs) { - if (auto real_op = cast_fw_node(out, "aten::real")) { - auto squeezed = std::make_shared(normalized_rfftn_splitted->output(0), const_neg_1); - copy_runtime_info({rfftn_op, real_op}, squeezed); - squeezed->set_friendly_name(real_op->get_friendly_name()); - replace_node(real_op, squeezed); - rval = true; - } - if (auto imag_op = cast_fw_node(out, "aten::imag")) { - auto squeezed = std::make_shared(normalized_rfftn_splitted->output(1), const_neg_1); - copy_runtime_info({rfftn_op, imag_op}, squeezed); - squeezed->set_friendly_name(imag_op->get_friendly_name()); - replace_node(imag_op, squeezed); - rval = true; - } - } - add_exception_to_fw_node( - rfftn_op, - "aten::fft_rfftn: Unsupported output node. Only aten::real and aten::imag are supported."); - return rval; - }; - - auto m = std::make_shared(fft_op, "ov::frontend::pytorch::pass::RFFTNComplexReplacer"); - this->register_matcher(m, rfftn_callback); -}; - -} // namespace pass -} // namespace pytorch -} // namespace frontend -} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp b/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp deleted file mode 100644 index 5420b7c9a01a04..00000000000000 --- a/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (C) 2018-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pass.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace pass { - -class RFFTNComplexReplacer : public ov::pass::MatcherPass { -public: - OPENVINO_MATCHER_PASS_RTTI("ov::frontend::pytorch::pass::RFFTNComplexReplacer"); - RFFTNComplexReplacer(); -}; - -} // namespace pass -} // namespace pytorch -} // namespace frontend -} // namespace ov diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index da0b5c5cd24d61..85f9dc55a3b862 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -7,6 +7,7 @@ #include "op_table.hpp" #include "openvino/core/rt_info.hpp" #include "openvino/core/validation_util.hpp" +#include "openvino/frontend/complex_type_mark.hpp" #include "openvino/frontend/pytorch/decoder.hpp" #include "openvino/op/add.hpp" #include "openvino/op/broadcast.hpp" @@ -40,15 +41,23 @@ namespace pytorch { using namespace ov::op; -void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) { +void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs, bool allow_complex) { auto num_inputs = context.get_input_size(); FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs, "Got less inputs ", num_inputs, " than expected ", min_inputs); + if (!allow_complex) { + // verify that no input is complex + for (size_t i = 0; i < std::min(num_inputs, max_inputs); i++) { + auto input = context.get_input(static_cast(i)); + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + PYTORCH_OP_CONVERSION_CHECK(!complex_type_mark, "The operation doesn't allow complex type."); + } + } for (auto i = max_inputs; i < num_inputs; i++) { - FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected."); + FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected: ", i + 1); } } @@ -836,6 +845,16 @@ bool index_tensor_on_list(ov::pass::NodeRegistry& rg, return true; } +Output get_complex_shape(const NodeContext& context, const Output& complex_input) { + auto input_shape = context.mark_node(std::make_shared(complex_input, element::i32)); + + auto zero = v0::Constant::create(element::i32, Shape{1}, {0}); + auto stop = v0::Constant::create(element::i32, Shape{1}, {-1}); + auto step = v0::Constant::create(element::i32, Shape{1}, {1}); + // Removing last dim from shape + return context.mark_node(std::make_shared(input_shape, zero, stop, step, zero)); +} + } // namespace pytorch } // namespace frontend } // namespace ov diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 5eb3f4aa4f64c0..ece73b3ea86ea1 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -35,7 +35,7 @@ const std::string& get_pytorch_prefix(); OPENVINO_ASSERT_HELPER(::ov::frontend::OpConversionFailure, "", (COND), get_pytorch_prefix(), __VA_ARGS__) #endif -void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs); +void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs, bool allow_complex = false); Output make_optional_bias(const Output& base_op, const NodeContext& context, @@ -136,6 +136,8 @@ bool index_tensor_on_list(ov::pass::NodeRegistry& rg, Output& new_output, bool& use_input_as_output); +Output get_complex_shape(const NodeContext& context, const Output& complex_input); + namespace op { template OutputVector inplace_op(const NodeContext& context) { diff --git a/tests/layer_tests/pytorch_tests/test_permute.py b/tests/layer_tests/pytorch_tests/test_permute.py index d8fb94145bada7..efbd77d371eb89 100644 --- a/tests/layer_tests/pytorch_tests/test_permute.py +++ b/tests/layer_tests/pytorch_tests/test_permute.py @@ -11,46 +11,54 @@ def _prepare_input(self): import numpy as np return (np.random.randn(1, 3, 224, 224).astype(np.float32),) - def create_model(self, order): + def create_model(self, order, complex_type): import torch class aten_permute(torch.nn.Module): - def __init__(self, order): - super(aten_permute, self).__init__() + def __init__(self, order, complex_type): + super().__init__() self.order = order + self.complex_type = complex_type def forward(self, x): - return torch.permute(x, self.order) - - ref_net = None - - return aten_permute(order), ref_net, "aten::permute" - - @pytest.mark.parametrize("order", [[0, 2, 3, 1], [0, 3, 1, 2], [0, -1, 1, -2]]) + if self.complex_type: + x = torch.reshape(x, x.shape[:-1] + (-1, 2)) + x = torch.view_as_complex(x) + res = torch.permute(x, self.order) + if self.complex_type: + res = torch.view_as_real(res) + return res + + return aten_permute(order, complex_type), None, "aten::permute" + + @pytest.mark.parametrize("order", [[0, 2, 3, 1], + [0, 3, 1, 2], + [0, -1, 1, -2]]) + @pytest.mark.parametrize("complex_type", [True, False]) @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export - def test_permute(self, order, ie_device, precision, ir_version): - self._test(*self.create_model(order), ie_device, precision, ir_version) + def test_permute(self, order, complex_type, ie_device, precision, ir_version): + self._test(*self.create_model(order, complex_type), ie_device, precision, ir_version) class TestPermuteList(PytorchLayerTest): def _prepare_input(self, permute_shape): import numpy as np - return (np.random.randn(1, 3, 224, 224).astype(np.float32), np.random.randn(*permute_shape).astype(np.float32)) + return (np.random.randn(1, 3, 224, 224).astype(np.float32), + np.random.randn(*permute_shape).astype(np.float32)) def create_model(self): import torch - class aten_permute(torch.nn.Module): - + class aten_permute_list(torch.nn.Module): def forward(self, x, y): y_shape = y.shape return torch.permute(x, [y_shape[0] - 1, y_shape[1] - 1, y_shape[2] - 1, y_shape[3] - 1]) ref_net = None - return aten_permute(), ref_net, ["aten::permute", "prim::ListConstruct"] + return aten_permute_list(), ref_net, ["aten::permute", "prim::ListConstruct"] @pytest.mark.parametrize("order", [[1, 3, 4, 2], [1, 4, 2, 3]]) @pytest.mark.nightly @@ -58,4 +66,5 @@ def forward(self, x, y): @pytest.mark.precommit_torch_export def test_permute_list(self, order, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, - kwargs_to_prepare_input={"permute_shape": order}, dynamic_shapes=ie_device != "GPU") + kwargs_to_prepare_input={"permute_shape": order}, + dynamic_shapes=ie_device != "GPU") diff --git a/tests/layer_tests/pytorch_tests/test_reshape.py b/tests/layer_tests/pytorch_tests/test_reshape.py index 7174d6022b4ca1..5266e8e00c5c1d 100644 --- a/tests/layer_tests/pytorch_tests/test_reshape.py +++ b/tests/layer_tests/pytorch_tests/test_reshape.py @@ -1,31 +1,38 @@ # Copyright (C) 2018-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import random import numpy as np import pytest -import random from pytorch_layer_test_class import PytorchLayerTest class TestReshape(PytorchLayerTest): - def _prepare_input(self): - return (np.random.uniform(0, 50, (1, 12, 12, 24)).astype(np.float32)) + def _prepare_input(self, complex_type): + shape = (1, 12, 12, 24) + if complex_type: + shape += (2,) + return (np.random.uniform(0, 50, shape).astype(np.float32)) - def create_model(self, shape): + def create_model(self, shape, complex_type): import torch class aten_reshape(torch.nn.Module): - def __init__(self, shape): - super(aten_reshape, self).__init__() + def __init__(self, shape, complex_type): + super().__init__() self.shape = shape + self.complex_type = complex_type def forward(self, x): - return torch.reshape(x, self.shape) + if self.complex_type: + x = torch.view_as_complex(x) + res = torch.reshape(x, self.shape) + if self.complex_type: + res = torch.view_as_real(res) + return res - ref_net = None - - return aten_reshape(shape), ref_net, "aten::reshape" + return aten_reshape(shape, complex_type), None, "aten::reshape" @pytest.mark.parametrize(("shape"), [ [-1, 6], @@ -37,16 +44,20 @@ def forward(self, x): [24, 1, -1, 12], [24, 1, 1, -1, 12], ]) + @pytest.mark.parametrize("complex_type", [True, False]) @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export @pytest.mark.precommit_fx_backend - def test_reshape(self, shape, ie_device, precision, ir_version): - self._test(*self.create_model(shape), ie_device, precision, ir_version) + def test_reshape(self, shape, complex_type, ie_device, precision, ir_version): + self._test(*self.create_model(shape, complex_type), + ie_device, precision, ir_version, + kwargs_to_prepare_input={"complex_type": complex_type}) + class TestDynamicReshape(PytorchLayerTest): def _prepare_input(self): - last_dym = random.randint(1,2) + last_dym = random.randint(1, 2) return (np.random.uniform(0, 50, (1, 12, 12, 24)).astype(np.float32), last_dym) def create_model(self, shape): @@ -54,17 +65,14 @@ def create_model(self, shape): class aten_reshape(torch.nn.Module): def __init__(self, shape): - super(aten_reshape, self).__init__() + super().__init__() self.shape = shape def forward(self, x, dym): - #return torch.reshape(x, self.shape) dym2 = int(torch.ops.aten.sym_size(x, 3)/dym) return torch.reshape(x, [12, 12, dym2, dym]) - ref_net = None - - return aten_reshape(shape), ref_net, "aten::reshape" + return aten_reshape(shape), None, "aten::reshape" @pytest.mark.parametrize(("shape"), [ [12, 12, 24, 1], diff --git a/tests/layer_tests/pytorch_tests/test_size.py b/tests/layer_tests/pytorch_tests/test_size.py index 050d1d818df1b2..f3e0e98dccb327 100644 --- a/tests/layer_tests/pytorch_tests/test_size.py +++ b/tests/layer_tests/pytorch_tests/test_size.py @@ -7,24 +7,38 @@ class TestSize(PytorchLayerTest): - def _prepare_input(self, input_shape): + def _prepare_input(self, input_shape, complex_type): import numpy as np + if complex_type: + input_shape += [2] return (np.random.randn(*input_shape).astype(np.float32),) - def create_model(self): + def create_model(self, complex_type): import torch class aten_size(torch.nn.Module): + def __init__(self, complex_type): + super().__init__() + self.complex_type = complex_type + def forward(self, x): + if self.complex_type: + x = torch.view_as_complex(x) return torch.tensor(x.shape) - ref_net = None + op = aten_size(complex_type) - op = aten_size() + return op, None, "aten::size" - return op, ref_net, "aten::size" @pytest.mark.nightly @pytest.mark.precommit - @pytest.mark.parametrize("input_shape", [[1,], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]) - def test_size(self, input_shape, ie_device, precision, ir_version): - self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape}) + @pytest.mark.parametrize("input_shape", [[1,], + [1, 2], + [1, 2, 3], + [1, 2, 3, 4], + [1, 2, 3, 4, 5]]) + @pytest.mark.parametrize("complex_type", [True, False]) + def test_size(self, input_shape, complex_type, ie_device, precision, ir_version): + self._test(*self.create_model(complex_type), ie_device, precision, ir_version, + kwargs_to_prepare_input={"input_shape": input_shape, + "complex_type": complex_type}) diff --git a/tests/layer_tests/pytorch_tests/test_stft.py b/tests/layer_tests/pytorch_tests/test_stft.py index f90962e5f1daa7..a2097b1f1fe453 100644 --- a/tests/layer_tests/pytorch_tests/test_stft.py +++ b/tests/layer_tests/pytorch_tests/test_stft.py @@ -98,7 +98,7 @@ def __init__(self, n_fft, hop_length, win_length, center, pad_mode, normalized, self.return_complex = return_complex def forward(self, x): - return torch.stft( + stft = torch.stft( x, self.n_fft, hop_length=self.hop_length, @@ -110,6 +110,10 @@ def forward(self, x): onesided=self.onesided, return_complex=self.return_complex, ) + if self.return_complex: + return torch.view_as_real(stft) + else: + return stft ref_net = None @@ -128,9 +132,9 @@ def forward(self, x): [16, None, None, False, "reflect", False, True, False], # hop & win length None [16, 4, None, False, "reflect", False, True, False], # win_length None [16, 4, 16, False, "reflect", True, True, False], # normalized True + [16, 4, 16, False, "reflect", False, True, True], # return_complex True # Unsupported cases: [16, 4, 16, False, "reflect", False, False, False], # onesided False - [16, 4, 16, False, "reflect", False, True, True], # reutrn_complex True ]) def test_stft_not_supported_attrs(self, n_fft, hop_length, win_length, center, pad_mode, normalized, onesided, return_complex, ie_device, precision, ir_version, trace_model): if ie_device == "GPU": @@ -144,9 +148,5 @@ def test_stft_not_supported_attrs(self, n_fft, hop_length, win_length, center, p pytest.xfail( reason="aten::stft conversion is currently supported with onesided=True only") - if return_complex is True: - pytest.xfail( - reason="aten::stft conversion is currently supported with return_complex=False only") - self._test(*self.create_model_with_attrs(n_fft, hop_length, win_length, center, pad_mode, normalized, onesided, return_complex), ie_device, precision, ir_version, kwargs_to_prepare_input={}, trace_model=trace_model)