Skip to content

Commit

Permalink
[JAX FE]: Support jax.lax.eq and jax.lax.ne operation for JAX (openvi…
Browse files Browse the repository at this point in the history
…notoolkit#26719)

### Details:
 - support jax.lax.eq and jax.lax.ne operation
 - enable unit tests for new operations
### Tickets:
 - [None](openvinotoolkit#26571)
  • Loading branch information
hub-bla authored Sep 22, 2024
1 parent 94749db commit f00ac41
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/frontends/jax/src/op/binary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
//

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/not_equal.hpp"
#include "utils.hpp"

using namespace std;
Expand All @@ -25,8 +27,10 @@ OutputVector translate_binary_op(const NodeContext& context) {
return {binary_op};
}

template OutputVector translate_binary_op<v1::Equal>(const NodeContext& context);
template OutputVector translate_binary_op<v1::GreaterEqual>(const NodeContext& context);
template OutputVector translate_binary_op<v1::Greater>(const NodeContext& context);
template OutputVector translate_binary_op<v1::NotEqual>(const NodeContext& context);

} // namespace op
} // namespace jax
Expand Down
4 changes: 4 additions & 0 deletions src/frontends/jax/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

#include "openvino/op/add.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/erf.hpp"
#include "openvino/op/exp.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/not_equal.hpp"
#include "openvino/op/reduce_max.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/sqrt.hpp"
Expand Down Expand Up @@ -63,13 +65,15 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"device_put", op::skip_node},
{"div", op::translate_1to1_match_2_inputs<v1::Divide>},
{"dot_general", op::translate_dot_general},
{"eq", op::translate_binary_op<v1::Equal>},
{"erf", op::translate_1to1_match_1_input<v0::Erf>},
{"exp", op::translate_1to1_match_1_input<v0::Exp>},
{"ge", op::translate_binary_op<v1::GreaterEqual>},
{"gt", op::translate_binary_op<v1::Greater>},
{"integer_pow", op::translate_integer_pow},
{"max", op::translate_1to1_match_2_inputs<v1::Maximum>},
{"mul", op::translate_1to1_match_2_inputs<v1::Multiply>},
{"ne", op::translate_binary_op<v1::NotEqual>},
{"reduce_max", op::translate_reduce_op<v1::ReduceMax>},
{"reduce_sum", op::translate_reduce_op<v1::ReduceSum>},
{"reduce_window_max", op::translate_reduce_window_max},
Expand Down
6 changes: 4 additions & 2 deletions tests/layer_tests/jax_tests/test_binary_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ def _prepare_input(self):

def create_model(self, input_shapes, binary_op, input_type):
reduce_map = {
'eq': lax.eq,
'ge': lax.ge,
'gt': lax.gt
'gt': lax.gt,
'ne': lax.ne
}

self.input_shapes = input_shapes
Expand All @@ -42,7 +44,7 @@ def jax_binary(x, y):
@pytest.mark.parametrize('input_shapes', [[[5], [1]], [[1], [5]], [[2, 2, 4], [1, 1, 4]],
[[5, 10], [5, 10]], [[2, 4, 6], [1, 4, 6]],
[[5, 8, 10, 128], [5, 1, 10, 128]]])
@pytest.mark.parametrize('binary_op', ['ge', 'gt'])
@pytest.mark.parametrize('binary_op', ['eq', 'ge', 'gt', 'ne'])
@pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
np.int32, np.uint32, np.int64, np.uint64,
np.float16, np.float32, np.float64])
Expand Down

0 comments on commit f00ac41

Please sign in to comment.