From dbb84095b85aa9fa16977660f1ab5c3a7f3502e6 Mon Sep 17 00:00:00 2001 From: Ruibin Cheung Date: Tue, 26 Sep 2023 10:25:35 +0800 Subject: [PATCH] [PIR] NO.33 Migrate paddle.equal into pir (#57678) * [PIR] Migrate paddle.equal into pir * add ut --- python/paddle/tensor/logic.py | 8 ++++---- test/legacy_test/test_compare_op.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index e0d051ea509f59..2755f9b75df21a 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -533,16 +533,16 @@ def equal(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [True , False, False]) """ - if not isinstance(y, (int, bool, float, Variable)): + if not isinstance(y, (int, bool, float, Variable, paddle.pir.OpResult)): raise TypeError( "Type of input args must be float, bool, int or Tensor, but received type {}".format( type(y) ) ) - if not isinstance(y, Variable): + if not isinstance(y, (Variable, paddle.pir.OpResult)): y = full(shape=[], dtype=x.dtype, fill_value=y) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.equal(x, y) else: check_variable_and_dtype( @@ -598,7 +598,7 @@ def equal_(x, y, name=None): out_shape, x.shape ) ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.equal_(x, y) diff --git a/test/legacy_test/test_compare_op.py b/test/legacy_test/test_compare_op.py index 3d29e25248554e..2bae19d180e2c6 100755 --- a/test/legacy_test/test_compare_op.py +++ b/test/legacy_test/test_compare_op.py @@ -64,7 +64,7 @@ def test_errors(self): create_test_class( 'greater_equal', _type_name, lambda _a, _b: _a >= _b, True ) - create_test_class('equal', _type_name, lambda _a, _b: _a == _b) + create_test_class('equal', _type_name, lambda _a, _b: _a == _b, True) create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) @@ -473,7 +473,7 @@ def test_check_output(self): create_bf16_case('less_equal', lambda _a, _b: _a <= _b) create_bf16_case('greater_than', lambda _a, _b: _a > _b) create_bf16_case('greater_equal', lambda _a, _b: _a >= _b, True) -create_bf16_case('equal', lambda _a, _b: _a == _b) +create_bf16_case('equal', lambda _a, _b: _a == _b, True) create_bf16_case('not_equal', lambda _a, _b: _a != _b)