[PT FE] Support different aliases of existing operations
Signed-off-by: Maxim Vafin <[email protected]>
mvafin committed Jan 17, 2025
1 parent 87887f5 commit c6854f6
Showing 5 changed files with 70 additions and 8 deletions.
src/frontends/pytorch/src/op/index_put_.cpp (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ namespace frontend {
namespace pytorch {
namespace op {

-OutputVector translate_index_put_(const NodeContext& context) {
+OutputVector translate_index_put(const NodeContext& context) {
// Pass as PtFrameworkNode to register as `inplace_op`. Conversion to OV operators is done as transformation.
auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
return {context.mark_node(node)};
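For context (an illustration I added, not code from the commit): the in-place and functional index_put spellings in PyTorch differ only in whether the indexed tensor is modified; the variable names below are invented.

import torch

# Illustration only: functional vs. in-place index_put in eager PyTorch.
x = torch.zeros(4)
idx = (torch.tensor([1, 3]),)
vals = torch.tensor([7.0, 9.0])

y = x.index_put(idx, vals)   # aten::index_put: returns a new tensor, x unchanged
x.index_put_(idx, vals)      # aten::index_put_: modifies x in place
print(torch.equal(y, x))     # True: both forms produce the same values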
src/frontends/pytorch/src/op/log.cpp (8 changes: 6 additions & 2 deletions)
@@ -77,16 +77,20 @@ OutputVector translate_log10(const NodeContext& context) {
};

OutputVector translate_logsumexp(const NodeContext& context) {
-num_inputs_check(context, 1, 2);
+num_inputs_check(context, 1, 3);
auto input = context.get_input(0);
ov::Output<ov::Node> dim;
if (!context.input_is_none(1)) {
dim = context.get_input(1);
} else {
dim = context.mark_node(get_axes_range(context, 0));
}
+bool keepdim = false;
+if (!context.input_is_none(2)) {
+    keepdim = context.const_input<bool>(2);
+}
auto exp = context.mark_node(std::make_shared<v0::Exp>(input));
-auto sum = context.mark_node(std::make_shared<v1::ReduceSum>(exp, dim, false));
+auto sum = context.mark_node(std::make_shared<v1::ReduceSum>(exp, dim, keepdim));
auto log = context.mark_node(std::make_shared<v0::Log>(sum));
return {log};
};
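A minimal sketch (added for illustration, not from the commit) of the decomposition that translate_logsumexp implements: exponentiate, reduce-sum over the requested dims with the keepdim flag, then take the log. Shapes and tolerance below are arbitrary.

import torch

# Sketch of the logsumexp decomposition used above: log(sum(exp(x), dim, keepdim)).
x = torch.randn(2, 5, 9, 7)
dim, keepdim = 1, True

reference = torch.logsumexp(x, dim=dim, keepdim=keepdim)
decomposed = torch.log(torch.sum(torch.exp(x), dim=dim, keepdim=keepdim))

print(torch.allclose(reference, decomposed, atol=1e-6))  # expected: True

Note that torch.logsumexp uses a numerically stabilized formulation, so the naive exp/sum/log form can overflow for large-magnitude inputs; for moderate value ranges like those in the layer test below, the two agree to floating-point tolerance.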
src/frontends/pytorch/src/op_table.cpp (7 changes: 5 additions & 2 deletions)
@@ -112,7 +112,7 @@ OP_CONVERTER(translate_index);
OP_CONVERTER(translate_index_add);
OP_CONVERTER(translate_index_copy_);
OP_CONVERTER(translate_index_fill_);
-OP_CONVERTER(translate_index_put_);
+OP_CONVERTER(translate_index_put);
OP_CONVERTER(translate_index_select);
OP_CONVERTER(translate_instance_norm);
OP_CONVERTER(translate_int);
@@ -457,6 +457,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::empty", op::translate_empty},
{"aten::empty_like", op::translate_empty_like},
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::equal", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::erf", op::translate_erf},
{"aten::erfc", op::translate_erfc},
{"aten::exp", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>, 1>},
@@ -500,7 +501,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
// aten::index - Supported in limited set of patterns
{"aten::index_copy_", op::inplace_op<op::translate_index_copy_>},
{"aten::index_fill_", op::inplace_op<op::translate_index_fill_>},
{"aten::index_put_", op::inplace_op<op::translate_index_put_>},
{"aten::index_put", op::translate_index_put},
{"aten::index_add", op::translate_index_add},
{"aten::index_select", op::translate_index_select},
{"aten::instance_norm", op::translate_instance_norm},
@@ -543,6 +544,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::log2_", op::inplace_op<op::translate_log2>},
{"aten::log10", op::optional_out<op::translate_log10, 1>},
{"aten::log10_", op::inplace_op<op::translate_log10>},
{"aten::logsumexp", op::translate_logsumexp},
{"aten::lstm", op::translate_lstm},
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
{"aten::masked_fill", op::translate_masked_fill},
@@ -705,6 +707,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"ov_ext::embedding", op::translate_embedding_ext},
{"ov_ext::conv1d", op::translate_conv1d_ext},
{"ov_ext::linear", op::translate_linear},
{"prim::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
{"prim::Constant", op::translate_constant},
{"prim::device", op::translate_constant},
// prim::DictConstruct - Supported in limited set of patterns
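To make the new table entries concrete, here is an illustrative sketch (not part of the commit; variable names are invented) of eager PyTorch calls that correspond to the added aliases.

import torch

# Illustration only: calls matching the newly registered aliases.
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0])

torch.equal(a, b)    # aten::equal, mapped to the same Equal handler as aten::eq
a.logsumexp(dim=0)   # aten::logsumexp, handled by translate_logsumexp
abs(a)               # builtin abs; the test added below checks this lowers to prim::abs

Whether a particular call actually reaches these converters depends on how the model is captured (TorchScript vs. torch.export), so the mapping comments above are approximate.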
tests/layer_tests/pytorch_tests/test_logsumexp.py (34 changes: 34 additions & 0 deletions)
@@ -0,0 +1,34 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class aten_logsumexp(torch.nn.Module):
    def __init__(self, dim, keepdim) -> None:
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, input_tensor):
        return torch.logsumexp(input_tensor, dim=self.dim, keepdim=self.keepdim)


class TestLogsumexp(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.randn(2, 5, 9, 7),)

    @pytest.mark.parametrize("dim", [
        0, 1, 2, 3, -1, -2, -3, -4
    ])
    @pytest.mark.parametrize("keepdim", [True, False])
    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_fx_backend
    def test_logsumexp(self, dim, keepdim, ie_device, precision, ir_version):
        self._test(aten_logsumexp(dim, keepdim), None, "aten::logsumexp",
                   ie_device, precision, ir_version)
tests/layer_tests/pytorch_tests/test_unary_ops.py (27 changes: 24 additions & 3 deletions)
@@ -75,7 +75,7 @@

class unary_op_net(torch.nn.Module):
    def __init__(self, op, dtype):
-        super(unary_op_net, self).__init__()
+        super().__init__()
        self.dtype = dtype
        self.op = op

@@ -87,7 +87,7 @@ def forward(self, x):

class unary_op_out_net(torch.nn.Module):
    def __init__(self, op, dtype):
-        super(unary_op_out_net, self).__init__()
+        super().__init__()
        self.dtype = dtype
        self.op = op

@@ -101,7 +101,7 @@ def forward(self, x):

class unary_func_op_inplace_net(torch.nn.Module):
    def __init__(self, op, dtype):
-        super(unary_func_op_inplace_net, self).__init__()
+        super().__init__()
        self.dtype = dtype
        self.op = op

@@ -111,6 +111,17 @@ def forward(self, x):
        return y, x1


+class prim_abs_net(torch.nn.Module):
+    def __init__(self, dtype):
+        super().__init__()
+        self.dtype = dtype
+
+    def forward(self, x):
+        x1 = x.to(self.dtype)
+        y = abs(x1)
+        return y, x1


class TestUnaryOp(PytorchLayerTest):
    def _prepare_input(self):
        # random number in range [1, 11)
@@ -265,3 +276,13 @@ def test_unary_func_op_inplace(self, op_type, dtype, ie_device, precision, ir_ve
        self.dtype = dtype
        self._test(unary_func_op_inplace_net(OPS[op_type], dtype), None, op_type + "_",
                   ie_device, precision, ir_version)

+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.precommit_torch_export
+    @pytest.mark.precommit_fx_backend
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.int8, torch.uint8, torch.int32, torch.int64])
+    def test_prim_abs(self, dtype, ie_device, precision, ir_version):
+        self.dtype = dtype
+        self._test(prim_abs_net(dtype), None, "prim::abs",
+                   ie_device, precision, ir_version)
