New GenAI API
Generative AI in only a few lines of code!
- Check out our guide
+ Check out our guide
OpenVINO models on Hugging Face!
@@ -194,6 +194,7 @@ Key Features
GET STARTED
LEARN OPENVINO
- OPENVINO WORKFLOW
+ HOW TO USE - MAIN WORKFLOW
+ HOW TO USE - GENERATIVE AI WORKFLOW
DOCUMENTATION
ABOUT OPENVINO
\ No newline at end of file
diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp
index af609088679e14..006a4e22e06304 100644
--- a/src/frontends/tensorflow/src/frontend.cpp
+++ b/src/frontends/tensorflow/src/frontend.cpp
@@ -471,7 +471,7 @@ std::shared_ptr FrontEnd::convert(const ov::frontend::InputModel::Ptr
"provides conversion extension(s): "
<< unsupported_ops_from_tokenizers
<< ". Install OpenVINO Tokenizers, refer to the documentation: "
- "https://docs.openvino.ai/2024/learn-openvino/llm_inference_guide/ov-tokenizers.html \n";
+ "https://docs.openvino.ai/2024/openvino-workflow-generative/ov-tokenizers.html \n";
}
}
From 0848f8630aca8e33bfbf56b68809d81c3a906c21 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Fri, 17 Jan 2025 15:57:06 +0100
Subject: [PATCH 4/4] [PT FE] Improve support for complex data type (#28482)
### Details:
- *Remove transformations for FFT*
- *Use `ComplexTypeMark` to provide information about a complex type*
### Tickets:
- *CVS-159375*
---------
Signed-off-by: Maxim Vafin
Co-authored-by: Roman Kazantsev
---
src/frontends/pytorch/src/frontend.cpp | 9 +-
src/frontends/pytorch/src/op/complex.cpp | 84 +++++++
src/frontends/pytorch/src/op/fft.cpp | 208 ++++++++++++++++++
src/frontends/pytorch/src/op/permute.cpp | 35 ++-
src/frontends/pytorch/src/op/reshape.cpp | 26 ++-
src/frontends/pytorch/src/op/size.cpp | 23 +-
src/frontends/pytorch/src/op/stft.cpp | 9 +-
src/frontends/pytorch/src/op_table.cpp | 21 +-
.../transforms/irfftn_complex_replacer.cpp | 164 --------------
.../transforms/irfftn_complex_replacer.hpp | 24 --
.../src/transforms/rfftn_complex_replacer.cpp | 163 --------------
.../src/transforms/rfftn_complex_replacer.hpp | 24 --
src/frontends/pytorch/src/utils.cpp | 24 +-
src/frontends/pytorch/src/utils.hpp | 4 +-
.../layer_tests/pytorch_tests/test_permute.py | 43 ++--
.../layer_tests/pytorch_tests/test_reshape.py | 44 ++--
tests/layer_tests/pytorch_tests/test_size.py | 30 ++-
tests/layer_tests/pytorch_tests/test_stft.py | 12 +-
18 files changed, 497 insertions(+), 450 deletions(-)
create mode 100644 src/frontends/pytorch/src/op/complex.cpp
create mode 100644 src/frontends/pytorch/src/op/fft.cpp
delete mode 100644 src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp
delete mode 100644 src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp
delete mode 100644 src/frontends/pytorch/src/transforms/rfftn_complex_replacer.cpp
delete mode 100644 src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp
diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index 04ba9a9c92c281..bb69e8fa313130 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -30,7 +30,6 @@
#include "transforms/dict_resolver.hpp"
#include "transforms/einsum_list_construct.hpp"
#include "transforms/index_loop_getitem_replacer.hpp"
-#include "transforms/irfftn_complex_replacer.hpp"
#include "transforms/listconstruct_replacer.hpp"
#include "transforms/min_max_prim_list_construct_replacer.hpp"
#include "transforms/prim_list_construct_pad.hpp"
@@ -40,7 +39,6 @@
#include "transforms/quantized_node_remover.hpp"
#include "transforms/remove_packing_ops.hpp"
#include "transforms/reverseprop_resolver.hpp"
-#include "transforms/rfftn_complex_replacer.hpp"
#include "transforms/softmax_reshape_elimination.hpp"
#include "transforms/string_equality_replacer.hpp"
#include "transforms/torchfx_gptq_pattern_replacer.hpp"
@@ -69,6 +67,11 @@ std::map get_unconverted_types_from_model(const std::s
if (!unconverted_ops_types.count(op_type_it->second)) {
unconverted_ops_types.emplace(op_type_it->second, std::move(exception_msg));
}
+ } else if (const auto& fw_node = ov::as_type_ptr(node)) {
+ auto op_type = std::string(fw_node->get_type_name());
+ if (!unconverted_ops_types.count(op_type)) {
+ unconverted_ops_types.emplace(op_type, "This is OpenVINO internal type.");
+ }
}
if (const auto& fw_node = ov::as_type_ptr(node)) {
for (size_t i = 0; i < fw_node->get_internal_subgraphs_size(); ++i) {
@@ -283,8 +286,6 @@ void FrontEnd::normalize(const std::shared_ptr& model) const {
manager.register_pass();
manager.register_pass();
manager.register_pass();
- manager.register_pass();
- manager.register_pass();
manager.register_pass();
manager.register_pass();
manager.register_pass();
diff --git a/src/frontends/pytorch/src/op/complex.cpp b/src/frontends/pytorch/src/op/complex.cpp
new file mode 100644
index 00000000000000..8ec0f5435e358b
--- /dev/null
+++ b/src/frontends/pytorch/src/op/complex.cpp
@@ -0,0 +1,84 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/complex_type_mark.hpp"
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/split.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/unsqueeze.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_complex(const NodeContext& context) {
+ num_inputs_check(context, 2, 2);
+ auto real = context.get_input(0);
+ auto imag = context.get_input(1);
+
+ auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+ real = context.mark_node(std::make_shared(real, const_neg_1));
+ imag = context.mark_node(std::make_shared(imag, const_neg_1));
+
+ auto complex = context.mark_node(std::make_shared(OutputVector{real, imag}, -1));
+
+ return {context.mark_node(std::make_shared(complex, complex->get_element_type()))};
+};
+
+OutputVector translate_imag(const NodeContext& context) {
+ num_inputs_check(context, 1, 1, true);
+ auto complex = context.get_input(0);
+
+ auto complex_type_mark = as_type_ptr(complex.get_node_shared_ptr());
+ PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::imag operation expects complex type tensor on input.");
+
+ complex = complex_type_mark->input_value(0);
+ auto axis = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+ auto imag = context.mark_node(std::make_shared(complex, axis, 2))->output(1);
+
+ auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+ return {context.mark_node(std::make_shared(imag, const_neg_1))};
+};
+
+OutputVector translate_real(const NodeContext& context) {
+ num_inputs_check(context, 1, 1, true);
+ auto complex = context.get_input(0);
+
+ auto complex_type_mark = as_type_ptr(complex.get_node_shared_ptr());
+ PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::real operation expects complex type tensor on input.");
+
+ complex = complex_type_mark->input_value(0);
+ auto axis = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+ auto real = context.mark_node(std::make_shared(complex, axis, 2))->output(0);
+
+ auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+ return {context.mark_node(std::make_shared(real, const_neg_1))};
+};
+
+OutputVector translate_view_as_real(const NodeContext& context) {
+ num_inputs_check(context, 1, 1, true);
+ auto complex = context.get_input(0);
+
+ auto complex_type_mark = as_type_ptr(complex.get_node_shared_ptr());
+ PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::real operation expects complex type tensor on input.");
+
+ return {complex_type_mark->input_value(0)};
+};
+
+OutputVector translate_view_as_complex(const NodeContext& context) {
+ num_inputs_check(context, 1, 1);
+ auto complex = context.get_input(0);
+
+ return {context.mark_node(std::make_shared(complex, complex.get_element_type()))};
+};
+
+} // namespace op
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/pytorch/src/op/fft.cpp b/src/frontends/pytorch/src/op/fft.cpp
new file mode 100644
index 00000000000000..0c2eb17c49d305
--- /dev/null
+++ b/src/frontends/pytorch/src/op/fft.cpp
@@ -0,0 +1,208 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/complex_type_mark.hpp"
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/equal.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/irdft.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/rdft.hpp"
+#include "openvino/op/reduce_prod.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/scatter_update.hpp"
+#include "openvino/op/select.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/sqrt.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/subtract.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_fft_rfftn(const NodeContext& context) {
+ // aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+ num_inputs_check(context, 1, 4);
+ auto input = context.get_input(0);
+
+ auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+ auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+ auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+
+ Output input_shape;
+ Output input_rank_scalar;
+ std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, input, true);
+
+ Output raw_s;
+ // Inputs can be either none or List. Check whether input values should be used or should be set to default values.
+ if (!context.input_is_none(1)) {
+ // s is provided, load from input.
+ raw_s = get_input_concat_if_list(context, 1);
+ raw_s = context.mark_node(std::make_shared(raw_s, element::i32));
+ }
+ Output dim;
+ // Handle dim parameter containing vector of integers indicating dimensions to be transformed.
+ if (!context.input_is_none(2)) {
+ // dim is provided, load from input.
+ dim = get_input_concat_if_list(context, 2);
+ dim = context.mark_node(std::make_shared(dim, element::i32));
+ } else if (!context.input_is_none(1)) {
+ // If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
+ auto s_len = context.mark_node(std::make_shared(raw_s, element::i32));
+ auto slice_start = context.mark_node(std::make_shared(input_rank_scalar, s_len));
+ auto slice_start_scalar = context.mark_node(std::make_shared(slice_start));
+ dim = context.mark_node(
+ std::make_shared(slice_start_scalar, input_rank_scalar, const_1, element::i32));
+ } else {
+ // Dim and s are set to default, use all of dimensions.
+ dim = context.mark_node(std::make_shared(const_0, input_rank_scalar, const_1, element::i32));
+ }
+
+ Output s;
+ if (context.input_is_none(1)) {
+ // Value for s was set to default, use full size for all dimensions.
+ s = context.mark_node(std::make_shared(input_shape, dim, const_0));
+ } else {
+ // Values for s were provided. Replace -1 values with default full size in given dimension.
+ auto full_s_cond = context.mark_node(std::make_shared(raw_s, const_neg_1));
+ auto full_s_values = context.mark_node(std::make_shared(input_shape, dim, const_0));
+ s = context.mark_node(std::make_shared(full_s_cond, full_s_values, raw_s));
+ }
+
+ // Handle norm parameter indicating normalization mode to use. Defaults to "backward".
+ std::string norm = "backward";
+ if (!context.input_is_none(3)) {
+ norm = context.const_input(3);
+ }
+
+ auto rdft = context.mark_node(std::make_shared(input, dim, s));
+
+ // Apply normalizations
+ auto n_int = context.mark_node(std::make_shared(s, const_0));
+ auto n = context.mark_node(std::make_shared(n_int, rdft));
+ Output normalized_rfftn;
+ if (norm == "forward") {
+ // Normalize by 1/n
+ normalized_rfftn = context.mark_node(std::make_shared(rdft, n));
+ } else if (norm == "backward") {
+ // No normalization
+ normalized_rfftn = rdft;
+ } else if (norm == "ortho") {
+ // Normalize by 1/sqrt(n)
+ auto sqrt_n = context.mark_node(std::make_shared(n));
+ normalized_rfftn = context.mark_node(std::make_shared(rdft, sqrt_n));
+ } else {
+ FRONT_END_THROW(
+ "aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
+ }
+
+ return {std::make_shared(normalized_rfftn, normalized_rfftn.get_element_type())};
+}
+
+OutputVector translate_fft_irfftn(const NodeContext& context) {
+ // aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+ num_inputs_check(context, 1, 4, true);
+ auto input = context.get_input(0);
+
+ auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr());
+ PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::fft_irfftn operation expects complex type tensor on input.");
+ input = complex_type_mark->input_value(0);
+
+ auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+ auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
+ auto const_scalar_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+ auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
+ auto const_scalar_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+ auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
+
+ // Input shape of complex number (excluding dimension created by concatenation of real and imag)
+ auto complex_input_shape = get_complex_shape(context, input);
+ auto input_rank = context.mark_node(std::make_shared(complex_input_shape, element::i32));
+ auto input_rank_scalar = context.mark_node(std::make_shared(input_rank));
+
+ Output raw_s;
+ // Inputs can be either none or List. Check whether input values should be used or should be set to default values.
+ if (!context.input_is_none(1)) {
+ // s is provided, load from input.
+ raw_s = get_input_concat_if_list(context, 1);
+ raw_s = context.mark_node(std::make_shared(raw_s, element::i32));
+ }
+
+ // Handle dim parameter containing vector of integers indicating dimensions to be transformed.
+ Output dim;
+ if (!context.input_is_none(2)) {
+ // Dim values is provided, load from input.
+ dim = get_input_concat_if_list(context, 2);
+ dim = context.mark_node(std::make_shared(dim, element::i32));
+ } else if (!context.input_is_none(1)) {
+ // If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
+ auto s_len = context.mark_node(std::make_shared(raw_s, element::i32));
+ auto range_start = context.mark_node(std::make_shared(input_rank, s_len));
+ auto range_start_scalar = context.mark_node(std::make_shared(range_start));
+ dim = context.mark_node(
+ std::make_shared(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32));
+ } else {
+ // Dim and s are set to default, use all of dimensions.
+ dim = context.mark_node(
+ std::make_shared(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32));
+ }
+
+ // Calculate default s values. Use full available size except last element, which is set to even value in last
+ // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]])
+ auto default_s_raw = context.mark_node(std::make_shared(complex_input_shape, dim, const_0));
+ auto last_s = context.mark_node(std::make_shared(default_s_raw, const_neg_1, const_0));
+ auto last_s_m_1 = context.mark_node(std::make_shared(last_s, const_1));
+ auto s_upd = context.mark_node(std::make_shared(last_s_m_1, const_2));
+ auto s_shape = context.mark_node(std::make_shared(default_s_raw, element::i32));
+ auto last_s_idx = context.mark_node(std::make_shared(s_shape, const_1));
+ auto default_s = context.mark_node(std::make_shared(default_s_raw, last_s_idx, s_upd, const_0));
+
+ // Handle s parameter containing vector of intigers indicating signal sizes for dimensions.
+ Output s;
+ if (!context.input_is_none(1)) {
+ // Values for s were provided. Replace -1 values with default full size in given dimension.
+ auto full_s_cond = context.mark_node(std::make_shared(raw_s, const_neg_1));
+ s = context.mark_node(std::make_shared(full_s_cond, default_s, raw_s));
+ } else {
+ // Value for s was set to default.
+ s = default_s;
+ }
+
+ // Handle norm parameter indicating normalization mode to use. Defaults to "backward".
+ std::string norm = "backward";
+ if (!context.input_is_none(3)) {
+ norm = context.const_input(3);
+ }
+
+ auto irdft = context.mark_node(std::make_shared(input, dim, s));
+
+ // Apply normalizations.
+ auto n_int = context.mark_node(std::make_shared(s, const_0));
+ auto n = context.mark_node(std::make_shared(n_int, irdft));
+ Output normalized_irfftn;
+ if (norm == "forward") {
+ normalized_irfftn = context.mark_node(std::make_shared(irdft, n));
+ } else if (norm == "backward") {
+ normalized_irfftn = irdft;
+ } else if (norm == "ortho") {
+ auto sqrt_n = context.mark_node(std::make_shared(n));
+ normalized_irfftn = context.mark_node(std::make_shared(irdft, sqrt_n));
+ } else {
+ FRONT_END_THROW(
+ "aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
+ }
+ return {normalized_irfftn};
+}
+
+} // namespace op
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/pytorch/src/op/permute.cpp b/src/frontends/pytorch/src/op/permute.cpp
index 46016ca8ca16a0..c724e38b8077b2 100644
--- a/src/frontends/pytorch/src/op/permute.cpp
+++ b/src/frontends/pytorch/src/op/permute.cpp
@@ -3,7 +3,10 @@
//
#include "openvino/core/validation_util.hpp"
+#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "utils.hpp"
@@ -12,17 +15,41 @@ namespace frontend {
namespace pytorch {
namespace op {
+using namespace ov::op;
+
OutputVector translate_permute(const NodeContext& context) {
- num_inputs_check(context, 2, 2);
+ num_inputs_check(context, 2, 2, true);
auto data = context.get_input(0);
auto order = get_input_concat_if_list(context, 1);
- auto rank = std::get<1>(get_shape_rank(context, data));
- auto rank_converted = context.mark_node(std::make_shared(rank, order));
+
+ Output rank;
+ auto complex_type_mark = as_type_ptr(data.get_node_shared_ptr());
+ if (complex_type_mark) {
+ data = complex_type_mark->input_value(0);
+ rank = std::get<1>(get_shape_rank(context, data));
+ auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+ rank = context.mark_node(std::make_shared(rank, const_1));
+ } else {
+ rank = std::get<1>(get_shape_rank(context, data));
+ }
+
+ auto rank_converted = context.mark_node(std::make_shared(rank, order));
auto order_normalized = normalize_axis(context, order, rank_converted);
+
+ if (complex_type_mark) {
+ auto to_concat = OutputVector{order_normalized, rank_converted};
+ order_normalized = context.mark_node(std::make_shared(to_concat, 0));
+ }
+
if (const auto order_const = ov::util::get_constant_from_source(order_normalized)) {
order_normalized = order_const;
}
- return {context.mark_node(std::make_shared(data, order_normalized))};
+ auto permute = context.mark_node(std::make_shared(data, order_normalized));
+ if (complex_type_mark) {
+ const auto& complex_dtype = complex_type_mark->get_complex_part_type();
+ permute = context.mark_node(std::make_shared(permute, complex_dtype));
+ }
+ return {permute};
}
} // namespace op
diff --git a/src/frontends/pytorch/src/op/reshape.cpp b/src/frontends/pytorch/src/op/reshape.cpp
index 7524d0e3c4aaf4..b9dcfc8d9afc4a 100644
--- a/src/frontends/pytorch/src/op/reshape.cpp
+++ b/src/frontends/pytorch/src/op/reshape.cpp
@@ -4,6 +4,7 @@
#include "openvino/op/reshape.hpp"
+#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/squeeze.hpp"
@@ -15,15 +16,34 @@ namespace frontend {
namespace pytorch {
namespace op {
+using namespace ov::op;
+
OutputVector translate_reshape(const NodeContext& context) {
// Translation is used by both aten::view and aten::reshape.
// Schema: aten::view(Tensor input, int[] shape) -> Tensor
// Schema: aten::reshape(Tensor input, int[] shape) -> Tensor
// For shape parameter, int[] is converted into single dimensional Tensor.
- num_inputs_check(context, 2, 2);
+ num_inputs_check(context, 2, 2, true);
+ auto tensor = context.get_input(0);
auto shape = get_input_concat_if_list(context, 1);
- auto reshape = std::make_shared(context.get_input(0), shape, false);
- return {context.mark_node(reshape)};
+
+ auto complex_type_mark = as_type_ptr(tensor.get_node_shared_ptr());
+ if (complex_type_mark) {
+ tensor = complex_type_mark->input_value(0);
+ auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
+ const_2 = context.mark_node(std::make_shared(const_2, shape));
+
+ shape = context.mark_node(std::make_shared(OutputVector{shape, const_2}, 0));
+ }
+
+ auto reshape = context.mark_node(std::make_shared(tensor, shape, false));
+
+ if (complex_type_mark) {
+ const auto& complex_dtype = complex_type_mark->get_complex_part_type();
+ return {context.mark_node(std::make_shared(reshape, complex_dtype))};
+ } else {
+ return {reshape};
+ }
};
} // namespace op
diff --git a/src/frontends/pytorch/src/op/size.cpp b/src/frontends/pytorch/src/op/size.cpp
index d8f1ee28123c10..2eca5f2707e53d 100644
--- a/src/frontends/pytorch/src/op/size.cpp
+++ b/src/frontends/pytorch/src/op/size.cpp
@@ -2,10 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
//
+#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
#include "utils.hpp"
namespace ov {
@@ -16,10 +18,25 @@ namespace op {
using namespace ov::op;
OutputVector translate_size(const NodeContext& context) {
- num_inputs_check(context, 1, 2);
- auto shape = context.mark_node(std::make_shared(context.get_input(0), element::i64));
+ num_inputs_check(context, 1, 2, true);
+ auto data = context.get_input(0);
+ Output shape;
+
+ auto complex_type_mark = as_type_ptr(data.get_node_shared_ptr());
+ if (complex_type_mark) {
+ data = complex_type_mark->input_value(0);
+ shape = context.mark_node(std::make_shared(data, element::i64));
+
+ auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
+ auto stop = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+ auto step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
+ shape = context.mark_node(std::make_shared(shape, zero, stop, step, zero));
+ } else {
+ shape = context.mark_node(std::make_shared(data, element::i64));
+ }
+
if (context.input_is_none(1)) {
- return shape->outputs();
+ return {shape};
} else {
auto axis_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
return {context.mark_node(std::make_shared(shape, context.get_input(1), axis_0))};
diff --git a/src/frontends/pytorch/src/op/stft.cpp b/src/frontends/pytorch/src/op/stft.cpp
index 8e478835fdcdd6..678f44dcbe1edf 100644
--- a/src/frontends/pytorch/src/op/stft.cpp
+++ b/src/frontends/pytorch/src/op/stft.cpp
@@ -4,6 +4,7 @@
#include "openvino/op/stft.hpp"
+#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
@@ -78,8 +79,6 @@ OutputVector translate_stft(const NodeContext& context) {
if (!context.input_is_none(7)) {
return_complex = context.const_input(7);
}
- PYTORCH_OP_CONVERSION_CHECK(!return_complex,
- "aten::stft conversion is currently supported with return_complex=False only.");
// Perform STFT
constexpr bool transpose_frames = true;
@@ -88,8 +87,10 @@ OutputVector translate_stft(const NodeContext& context) {
if (normalized) {
const auto nfft_convert = context.mark_node(std::make_shared(n_fft, stft));
const auto divisor = context.mark_node(std::make_shared(nfft_convert));
- const auto norm_stft = context.mark_node(std::make_shared(stft, divisor));
- return {norm_stft};
+ stft = context.mark_node(std::make_shared(stft, divisor));
+ }
+ if (return_complex) {
+ return {context.mark_node(std::make_shared(stft, stft->get_element_type()))};
} else {
return {stft};
}
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index fe4e84bd47d45e..f00391e08e2a32 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -59,6 +59,7 @@ OP_CONVERTER(translate_celu);
OP_CONVERTER(translate_channel_shuffle);
OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_col2im);
+OP_CONVERTER(translate_complex);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_conv_transposend);
OP_CONVERTER(translate_convnd);
@@ -86,6 +87,8 @@ OP_CONVERTER(translate_expm1);
OP_CONVERTER(translate_eye);
OP_CONVERTER(translate_fake_quantize_per_channel_affine);
OP_CONVERTER(translate_fake_quantize_per_tensor_affine);
+OP_CONVERTER(translate_fft_irfftn);
+OP_CONVERTER(translate_fft_rfftn);
OP_CONVERTER(translate_fill);
OP_CONVERTER(translate_fill_diagonal);
OP_CONVERTER(translate_flatten);
@@ -108,6 +111,7 @@ OP_CONVERTER(translate_hann_window);
OP_CONVERTER(translate_hardtanh);
OP_CONVERTER(translate_if);
OP_CONVERTER(translate_im2col);
+OP_CONVERTER(translate_imag);
OP_CONVERTER(translate_index);
OP_CONVERTER(translate_index_add);
OP_CONVERTER(translate_index_copy_);
@@ -192,6 +196,7 @@ OP_CONVERTER(translate_randn);
OP_CONVERTER(translate_randint);
OP_CONVERTER(translate_rand_like);
OP_CONVERTER(translate_randn_like);
+OP_CONVERTER(translate_real);
OP_CONVERTER(translate_reciprocal);
OP_CONVERTER(translate_relu6);
OP_CONVERTER(translate_remainder);
@@ -246,6 +251,8 @@ OP_CONVERTER(translate_upsample_nearest3d);
OP_CONVERTER(translate_upsample_trilinear3d);
OP_CONVERTER(translate_var);
OP_CONVERTER(translate_var_mean);
+OP_CONVERTER(translate_view_as_complex);
+OP_CONVERTER(translate_view_as_real);
OP_CONVERTER(translate_weight_norm);
OP_CONVERTER(translate_where);
OP_CONVERTER(translate_zeros);
@@ -423,7 +430,7 @@ const std::unordered_map get_supported_ops_ts() {
{"aten::clip", op::translate_clamp},
{"aten::clone", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
{"aten::col2im", op::translate_col2im},
- // aten::complex - Supported in limited set of patterns
+ {"aten::complex", op::translate_complex},
{"aten::concat", op::translate_cat},
{"aten::contiguous", op::skip_node}, // In openvino how tensors are stored in memory is internal plugin detail,
// we assume all tensors are contiguous
@@ -468,8 +475,8 @@ const std::unordered_map get_supported_ops_ts() {
{"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine},
{"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine},
{"aten::feature_dropout", op::skip_node},
- // aten::fft_irfftn - Supported in limited set of patterns
- // aten::fft_rfftn - Supported in limited set of patterns
+ {"aten::fft_irfftn", op::translate_fft_irfftn},
+ {"aten::fft_rfftn", op::translate_fft_rfftn},
{"aten::fill", op::translate_fill},
{"aten::fill_diagonal", op::translate_fill_diagonal},
{"aten::flatten", op::quantizable_op},
@@ -496,7 +503,7 @@ const std::unordered_map get_supported_ops_ts() {
{"aten::hardswish", op::quantizable_op>},
{"aten::hardtanh", op::quantizable_op},
{"aten::im2col", op::translate_im2col},
- // aten::imag - Supported in limited set of patterns
+ {"aten::imag", op::translate_imag},
// aten::index - Supported in limited set of patterns
{"aten::index_copy_", op::inplace_op},
{"aten::index_fill_", op::inplace_op},
@@ -604,7 +611,7 @@ const std::unordered_map get_supported_ops_ts() {
{"aten::randint", op::translate_randint},
{"aten::randn", op::translate_randn},
{"aten::randn_like", op::translate_randn_like},
- // aten::real - Supported in limited set of patterns
+ {"aten::real", op::translate_real},
{"aten::reciprocal", op::optional_out},
{"aten::reciprocal_", op::inplace_op},
// aten::reflection_pad2d - Supported in limited set of patterns
@@ -696,6 +703,8 @@ const std::unordered_map get_supported_ops_ts() {
{"aten::var_mean", op::translate_var_mean},
{"aten::view", op::quantizable_op},
{"aten::view_as", op::translate_reshape_as},
+ {"aten::view_as_complex", op::translate_view_as_complex},
+ {"aten::view_as_real", op::translate_view_as_real},
{"aten::wait", op::skip_node},
{"aten::where", op::translate_where},
{"aten::zero", op::translate_zeros_like},
@@ -979,6 +988,8 @@ const std::unordered_map get_supported_ops_fx() {
{"aten.var.correction", op::translate_var_fx},
{"aten.var_mean.correction", op::translate_var_mean_fx},
{"aten.view.default", op::translate_reshape},
+ {"aten.view_as_complex.default", op::translate_view_as_complex},
+ {"aten.view_as_real.default", op::translate_view_as_real},
{"aten.where.self", op::translate_where},
{"aten.zeros.default", op::translate_zeros_fx},
{"aten.zeros.names", op::translate_zeros_fx},
diff --git a/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp b/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp
deleted file mode 100644
index cb80987e4511ae..00000000000000
--- a/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp
+++ /dev/null
@@ -1,164 +0,0 @@
-// Copyright (C) 2018-2025 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "irfftn_complex_replacer.hpp"
-
-#include "openvino/core/rt_info.hpp"
-#include "openvino/op/concat.hpp"
-#include "openvino/op/convert.hpp"
-#include "openvino/op/convert_like.hpp"
-#include "openvino/op/equal.hpp"
-#include "openvino/op/gather.hpp"
-#include "openvino/op/irdft.hpp"
-#include "openvino/op/multiply.hpp"
-#include "openvino/op/range.hpp"
-#include "openvino/op/reduce_prod.hpp"
-#include "openvino/op/scatter_update.hpp"
-#include "openvino/op/select.hpp"
-#include "openvino/op/shape_of.hpp"
-#include "openvino/op/sqrt.hpp"
-#include "openvino/op/squeeze.hpp"
-#include "openvino/op/subtract.hpp"
-#include "openvino/op/unsqueeze.hpp"
-#include "openvino/op/util/framework_node.hpp"
-#include "openvino/pass/pattern/matcher.hpp"
-#include "openvino/pass/pattern/op/wrap_type.hpp"
-#include "utils.hpp"
-
-namespace ov {
-namespace frontend {
-namespace pytorch {
-namespace pass {
-
-using namespace ov::pass;
-using namespace ov::op;
-
-IRFFTNComplexReplacer::IRFFTNComplexReplacer() {
- // Transformation used to replace combination of aten::complex -> aten::fft_irfftn torch operators.
- // Pattern: aten::complex -> aten::fft_irfftn
- auto fft_op = pattern::wrap_type();
-
- ov::matcher_pass_callback irfftn_callback = [](pattern::Matcher& m) {
- // "aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor"
- auto irfftn_op = cast_fw_node(m.get_match_root(), "aten::fft_irfftn");
- if (!irfftn_op) {
- return false;
- }
- auto const_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
- auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0});
- auto const_scalar_0 = v0::Constant::create(element::i32, Shape{}, {0});
- auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
- auto const_scalar_1 = v0::Constant::create(element::i32, Shape{}, {1});
- auto const_2 = v0::Constant::create(element::i32, Shape{1}, {2});
-
- // Check whether input node being aten::complex.
- auto fw_node_complex_input = cast_fw_node(irfftn_op->input_value(0).get_node_shared_ptr(), "aten::complex");
- if (!fw_node_complex_input) {
- return false;
- }
-
- // Concatenate real and imag parts over additional, last dimension.
- auto real = std::make_shared(fw_node_complex_input->input_value(0), const_neg_1);
- auto imag = std::make_shared(fw_node_complex_input->input_value(1), const_neg_1);
- NodeVector complex = {real, imag};
- auto input = std::make_shared(complex, -1);
-
- // Input shape of complex number (excluding dimension created by concatenation of real and imag)
- auto complex_input_shape = std::make_shared(fw_node_complex_input->input_value(0), element::i32);
- auto input_rank = std::make_shared(complex_input_shape, element::i32);
- auto input_rank_scalar = std::make_shared(input_rank);
-
- // Inputs can be either none or ListConstruct. Check whether input values should be used or should be set to
- // default values.
- bool dim_use_default = is_none_node(irfftn_op->input_value(2));
- bool s_use_default = is_none_node(irfftn_op->input_value(1));
- // Can be None constant, when used check s_use_default.
- auto raw_s_input_maybe = concat_list_construct(irfftn_op->input_value(1));
- raw_s_input_maybe = std::make_shared(raw_s_input_maybe, element::i32);
-
- // Handle dim parameter containing vector of integers indicating dimensions to be transformed.
- std::shared_ptr dim;
- if (!dim_use_default) {
- // Dim values is provided, load from input.
- dim = std::make_shared(concat_list_construct(irfftn_op->input_value(2)), element::i32);
- } else if (!s_use_default) {
- // If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
- auto s_len = std::make_shared(raw_s_input_maybe, element::i32);
- auto range_start = std::make_shared(input_rank, s_len);
- auto range_start_scalar = std::make_shared(range_start);
- dim = std::make_shared(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32);
- } else {
- // Dim and s are set to default, use all of dimensions.
- dim = std::make_shared(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32);
- }
-
- // Calculate default s values. Use full available size except last element, which is set to even value in last
- // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]])
- auto default_s_raw = std::make_shared(complex_input_shape, dim, const_0);
- auto last_s = std::make_shared(default_s_raw, const_neg_1, const_0);
- auto last_s_m_1 = std::make_shared(last_s, const_1);
- auto s_upd = std::make_shared(last_s_m_1, const_2);
- auto s_shape = std::make_shared(default_s_raw, element::i32);
- auto last_s_idx = std::make_shared(s_shape, const_1);
- auto default_s = std::make_shared(default_s_raw, last_s_idx, s_upd, const_0);
-
- // Handle s parameter containing vector of intigers indicating signal sizes for dimensions.
- std::shared_ptr s;
- if (!s_use_default) {
- // Values for s were provided. Replace -1 values with default full size in given dimension.
- auto full_s_cond = std::make_shared(raw_s_input_maybe, const_neg_1);
- s = std::make_shared(full_s_cond, default_s, raw_s_input_maybe);
- } else {
- // Value for s was set to default.
- s = default_s;
- }
-
- // Handle norm parameter indicating normalization mode to use. Defaults to "backward".
- std::string norm;
- if (const auto& fw_node_mode =
- ov::as_type_ptr(irfftn_op->input_value(3).get_node_shared_ptr())) {
- const auto& attrs = fw_node_mode->get_attrs();
- if (attrs.find("string_value") != attrs.end()) {
- norm = attrs.at("string_value");
- } else {
- norm = "backward";
- }
- } else {
- add_exception_to_fw_node(irfftn_op, "aten::fft_irfftn: could not retrive value for norm attribute.");
- return false;
- }
-
- auto irdft = std::make_shared(input, dim, s);
-
- // Apply normalizations.
- auto n_int = std::make_shared(s, const_0);
- auto n = std::make_shared(n_int, irdft);
- std::shared_ptr normalized_irfftn;
- if (norm == "forward") {
- normalized_irfftn = std::make_shared