Skip to content

Commit

Permalink
[PIR] NO.33 Migrate paddle.equal into pir (PaddlePaddle#57678)
Browse files Browse the repository at this point in the history
* [PIR] Migrate paddle.equal into pir

* add ut
  • Loading branch information
BeingGod authored Sep 26, 2023
1 parent b63b983 commit dbb8409
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,16 +533,16 @@ def equal(x, y, name=None):
Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True,
[True , False, False])
"""
if not isinstance(y, (int, bool, float, Variable)):
if not isinstance(y, (int, bool, float, Variable, paddle.pir.OpResult)):
raise TypeError(
"Type of input args must be float, bool, int or Tensor, but received type {}".format(
type(y)
)
)
if not isinstance(y, Variable):
if not isinstance(y, (Variable, paddle.pir.OpResult)):
y = full(shape=[], dtype=x.dtype, fill_value=y)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.equal(x, y)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -598,7 +598,7 @@ def equal_(x, y, name=None):
out_shape, x.shape
)
)
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.equal_(x, y)


Expand Down
4 changes: 2 additions & 2 deletions test/legacy_test/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_errors(self):
create_test_class(
'greater_equal', _type_name, lambda _a, _b: _a >= _b, True
)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b, True)
create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)


Expand Down Expand Up @@ -473,7 +473,7 @@ def test_check_output(self):
create_bf16_case('less_equal', lambda _a, _b: _a <= _b)
create_bf16_case('greater_than', lambda _a, _b: _a > _b)
create_bf16_case('greater_equal', lambda _a, _b: _a >= _b, True)
create_bf16_case('equal', lambda _a, _b: _a == _b)
create_bf16_case('equal', lambda _a, _b: _a == _b, True)
create_bf16_case('not_equal', lambda _a, _b: _a != _b)


Expand Down

0 comments on commit dbb8409

Please sign in to comment.