Skip to content

Commit

Permalink
[PT FE] Support aten::atan2 for pytorch models (openvinotoolkit#27026)
Browse files Browse the repository at this point in the history
### Details:
- add atan2 operator and unit tests

### Tickets:
- [[Good First Issue]: Support aten::atan2 for pytorch
models](openvinotoolkit#20575)

---------

Co-authored-by: Michal Lukaszewski <[email protected]>
Co-authored-by: Maxim Vafin <[email protected]>
  • Loading branch information
3 people authored Oct 23, 2024
1 parent dfa6235 commit c5e16fc
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
99 changes: 99 additions & 0 deletions src/frontends/pytorch/src/op/atan2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#define _USE_MATH_DEFINES

#include <math.h>

#include <memory>

#include "openvino/core/type/element_type.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/atan.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_atan2(const NodeContext& context) {
// atan2(input, other, *) → Tensor
num_inputs_check(context, 2, 2);
Output<Node> lhs;
Output<Node> rhs;

std::tie(lhs, rhs) = get_inputs_with_promoted_types(context, 0, 1);

auto div = context.mark_node(std::make_shared<v1::Divide>(lhs, rhs));

auto atan = context.mark_node(std::make_shared<v0::Atan>(div));

// create some constants to adjust result according to quadrant.
auto zero = context.mark_node(v0::Constant::create(ov::element::i32, Shape{}, {0}));
auto pi = context.mark_node(v0::Constant::create(ov::element::f64, Shape{}, {M_PI}));
auto half_pi = context.mark_node(v0::Constant::create(ov::element::f64, Shape{}, {M_PI_2}));
auto neg_half_pi = context.mark_node(v0::Constant::create(ov::element::f64, Shape{}, {-M_PI_2}));

zero = context.mark_node(std::make_shared<v1::ConvertLike>(zero, rhs));
pi = context.mark_node(std::make_shared<v1::ConvertLike>(pi, rhs));
half_pi = context.mark_node(std::make_shared<v1::ConvertLike>(half_pi, rhs));
neg_half_pi = context.mark_node(std::make_shared<v1::ConvertLike>(neg_half_pi, rhs));

// x > 0, no adjustment needed
auto x_greater_than_zero = context.mark_node(std::make_shared<v1::Greater>(rhs, zero));

// x < 0 and y >= 0, need to plus pi
auto y_greater_equal_zero = context.mark_node(std::make_shared<v1::GreaterEqual>(lhs, zero));
auto x_less_than_zero = context.mark_node(std::make_shared<v1::Less>(rhs, zero));
auto add_pi_condition = context.mark_node(std::make_shared<v1::LogicalAnd>(x_less_than_zero, y_greater_equal_zero));

// x < 0 and y < 0, need to minus pi
auto y_less_than_zero = std::make_shared<v1::Less>(lhs, zero);
auto subtract_pi_condition =
context.mark_node(std::make_shared<v1::LogicalAnd>(x_less_than_zero, y_less_than_zero));

// x = 0 and y > 0, pi/2
auto x_equal_zero = std::make_shared<v1::Equal>(rhs, zero);
auto y_greater_than_zero = std::make_shared<v1::Greater>(lhs, zero);
auto half_pi_condition = context.mark_node(std::make_shared<v1::LogicalAnd>(x_equal_zero, y_greater_than_zero));

// x = 0 and y < 0, -pi/2
auto neg_half_pi_condition = context.mark_node(std::make_shared<v1::LogicalAnd>(x_equal_zero, y_less_than_zero));

auto special_case_condition =
context.mark_node(std::make_shared<v1::LogicalOr>(half_pi_condition, neg_half_pi_condition));

// do adjustment
auto atan_plus_pi = context.mark_node(std::make_shared<v1::Add>(atan, pi));
auto atan_minus_pi = context.mark_node(std::make_shared<v1::Subtract>(atan, pi));

// select result
auto ajusted_case = context.mark_node(std::make_shared<v1::Select>(add_pi_condition, atan_plus_pi, atan_minus_pi));
auto special_case = context.mark_node(std::make_shared<v1::Select>(half_pi_condition, half_pi, neg_half_pi));
auto adjusted_atan = context.mark_node(std::make_shared<v1::Select>(x_greater_than_zero, atan, ajusted_case));
auto result = context.mark_node(std::make_shared<v1::Select>(special_case_condition, special_case, adjusted_atan));

return {result};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
3 changes: 3 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ OP_CONVERTER(translate_argmax);
OP_CONVERTER(translate_argmin);
OP_CONVERTER(translate_as_strided);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_atan2);
OP_CONVERTER(translate_avg_pool1d);
OP_CONVERTER(translate_avg_pool2d);
OP_CONVERTER(translate_avg_pool3d);
Expand Down Expand Up @@ -385,6 +386,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::atanh",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>, 1>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
{"aten::atan2", op::translate_atan2},
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_pool1d>},
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_pool2d>},
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_pool3d>},
Expand Down Expand Up @@ -776,6 +778,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.asinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asinh>},
{"aten.atan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>},
{"aten.atanh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>},
{"aten.atan2.default", op::translate_atan2},
{"aten.avg_pool2d.default", op::translate_avg_pool2d},
{"aten.avg_pool3d.default", op::translate_avg_pool3d},
{"aten.baddbmm.default", op::translate_addmm_fx},
Expand Down
80 changes: 80 additions & 0 deletions tests/layer_tests/pytorch_tests/test_atan2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest

@pytest.mark.parametrize("input_shape_rhs", [
[2, 5, 3, 4],
[1, 5, 3, 4],
[1]
])
class TestAtan2(PytorchLayerTest):

def _prepare_input(self):
return (np.random.randn(2, 5, 3, 4).astype(np.float32), self.input_rhs)

def create_model(self):

class aten_atan2(torch.nn.Module):
def __init__(self):
super(aten_atan2, self).__init__()

def forward(self, lhs, rhs):
return torch.arctan2(lhs, rhs)

ref_net = None

return aten_atan2(), ref_net, "aten::atan2"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
@pytest.mark.precommit_fx_backend
def test_atan2(self, ie_device, precision, ir_version, input_shape_rhs):
self.input_rhs = np.random.randn(*input_shape_rhs).astype(np.float32)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)

class TestAtan2Types(PytorchLayerTest):

def _prepare_input(self):
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
torch.randn(self.rhs_shape).to(self.rhs_type).numpy())

def create_model(self, lhs_type, rhs_type):

class aten_atan2(torch.nn.Module):
def __init__(self, lhs_type, rhs_type):
super(aten_atan2, self).__init__()
self.lhs_type = lhs_type
self.rhs_type = rhs_type

def forward(self, lhs, rhs):
return torch.arctan2(lhs.to(self.lhs_type), rhs.to(self.rhs_type))

ref_net = None

return aten_atan2(lhs_type, rhs_type), ref_net, "aten::atan2"

@pytest.mark.parametrize(("lhs_type", "rhs_type"),
[[torch.int, torch.float32],
[torch.int, torch.float64],
[torch.float32, torch.float64],
[torch.int64, torch.float32]
])
@pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
([2, 3], [1, 3]),
([3, 2, 3], [2, 3]),
])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
def test_atan2_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape):
self.lhs_type = lhs_type
self.lhs_shape = lhs_shape
self.rhs_type = rhs_type
self.rhs_shape = rhs_shape
self._test(*self.create_model(lhs_type, rhs_type),
ie_device, precision, ir_version, freeze_model=False, trace_model=True)

0 comments on commit c5e16fc

Please sign in to comment.