From f00ac415cf1bd3f7896076aba4f89b8fb827077e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hubert=20B=C5=82aszczyk?= <56601011+hub-bla@users.noreply.github.com> Date: Sun, 22 Sep 2024 17:25:03 +0200 Subject: [PATCH] [JAX FE]: Support jax.lax.eq and jax.lax.ne operation for JAX (#26719) ### Details: - support jax.lax.eq and jax.lax.ne operation - enable unit tests for new operations ### Tickets: - [None](https://github.com/openvinotoolkit/openvino/issues/26571) --- src/frontends/jax/src/op/binary_op.cpp | 4 ++++ src/frontends/jax/src/op_table.cpp | 4 ++++ tests/layer_tests/jax_tests/test_binary_comparison.py | 6 ++++-- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/frontends/jax/src/op/binary_op.cpp b/src/frontends/jax/src/op/binary_op.cpp index 9fd5cc12525b89..11e26166cfcf47 100644 --- a/src/frontends/jax/src/op/binary_op.cpp +++ b/src/frontends/jax/src/op/binary_op.cpp @@ -3,8 +3,10 @@ // #include "openvino/frontend/jax/node_context.hpp" +#include "openvino/op/equal.hpp" #include "openvino/op/greater.hpp" #include "openvino/op/greater_eq.hpp" +#include "openvino/op/not_equal.hpp" #include "utils.hpp" using namespace std; @@ -25,8 +27,10 @@ OutputVector translate_binary_op(const NodeContext& context) { return {binary_op}; } +template OutputVector translate_binary_op(const NodeContext& context); template OutputVector translate_binary_op(const NodeContext& context); template OutputVector translate_binary_op(const NodeContext& context); +template OutputVector translate_binary_op(const NodeContext& context); } // namespace op } // namespace jax diff --git a/src/frontends/jax/src/op_table.cpp b/src/frontends/jax/src/op_table.cpp index c0bcb85c724dbb..dcd54203f5b9bc 100644 --- a/src/frontends/jax/src/op_table.cpp +++ b/src/frontends/jax/src/op_table.cpp @@ -6,12 +6,14 @@ #include "openvino/op/add.hpp" #include "openvino/op/divide.hpp" +#include "openvino/op/equal.hpp" #include "openvino/op/erf.hpp" #include "openvino/op/exp.hpp" #include "openvino/op/greater.hpp" #include "openvino/op/greater_eq.hpp" #include "openvino/op/maximum.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/not_equal.hpp" #include "openvino/op/reduce_max.hpp" #include "openvino/op/reduce_sum.hpp" #include "openvino/op/sqrt.hpp" @@ -63,6 +65,7 @@ const std::map get_supported_ops_jaxpr() { {"device_put", op::skip_node}, {"div", op::translate_1to1_match_2_inputs}, {"dot_general", op::translate_dot_general}, + {"eq", op::translate_binary_op}, {"erf", op::translate_1to1_match_1_input}, {"exp", op::translate_1to1_match_1_input}, {"ge", op::translate_binary_op}, @@ -70,6 +73,7 @@ const std::map get_supported_ops_jaxpr() { {"integer_pow", op::translate_integer_pow}, {"max", op::translate_1to1_match_2_inputs}, {"mul", op::translate_1to1_match_2_inputs}, + {"ne", op::translate_binary_op}, {"reduce_max", op::translate_reduce_op}, {"reduce_sum", op::translate_reduce_op}, {"reduce_window_max", op::translate_reduce_window_max}, diff --git a/tests/layer_tests/jax_tests/test_binary_comparison.py b/tests/layer_tests/jax_tests/test_binary_comparison.py index 0a85b1e4a551c2..2e2f7d917d68a7 100644 --- a/tests/layer_tests/jax_tests/test_binary_comparison.py +++ b/tests/layer_tests/jax_tests/test_binary_comparison.py @@ -27,8 +27,10 @@ def _prepare_input(self): def create_model(self, input_shapes, binary_op, input_type): reduce_map = { + 'eq': lax.eq, 'ge': lax.ge, - 'gt': lax.gt + 'gt': lax.gt, + 'ne': lax.ne } self.input_shapes = input_shapes @@ -42,7 +44,7 @@ def jax_binary(x, y): @pytest.mark.parametrize('input_shapes', [[[5], [1]], [[1], [5]], [[2, 2, 4], [1, 1, 4]], [[5, 10], [5, 10]], [[2, 4, 6], [1, 4, 6]], [[5, 8, 10, 128], [5, 1, 10, 128]]]) - @pytest.mark.parametrize('binary_op', ['ge', 'gt']) + @pytest.mark.parametrize('binary_op', ['eq', 'ge', 'gt', 'ne']) @pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32, np.int64, np.uint64, np.float16, np.float32, np.float64])