[PIR] No.43 Migrate paddle.where into pir (PaddlePaddle#57667)
longranger2 authored and jiahy0825 committed Oct 16, 2023
1 parent e6ac3a0 commit 305e2d8
Showing 4 changed files with 14 additions and 4 deletions.
@@ -24,6 +24,7 @@


 vjp_interface_declare_gen_op_list = [
+    'where',
     "tanh",
     "mean",
     "divide",
@@ -66,6 +67,7 @@
     'triu',
 ]
 vjp_interface_implementation_gen_op_list = [
+    'where',
     "tanh",
     "mean",
     "divide",
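The two lists above (the file header for this first diff is cut off in this view; they belong to the PIR op-generator configuration) drive code generation: adding 'where' to both makes the generated dialect declare and implement the VJP (reverse-mode gradient) interface for the op. A minimal dynamic-mode sketch of what that VJP enables, using only the public paddle API:

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
y = paddle.to_tensor([4.0, 5.0, 6.0], stop_gradient=False)
cond = paddle.to_tensor([True, False, True])

out = paddle.where(cond, x, y)
out.sum().backward()
print(x.grad)  # [1., 0., 1.] -- gradient flows to x where cond is True
print(y.grad)  # [0., 1., 0.] -- and to y where cond is False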
2 changes: 2 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -38,6 +38,7 @@


 VJPS = [
+    'where_grad',
     'tril_grad',
     'triu_grad',
     'tanh_grad',
@@ -152,6 +153,7 @@
 VJP_COMPS = PRIM_VJP + CUSTOM_VJP

 BACKENDS = [
+    'where_grad',
     'tril_grad',
     'triu_grad',
     'add_n',
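VJPS and BACKENDS are allowlists in the primitive-dialect code generator: listing 'where_grad' emits its VJP binding and backend registration. The gradient rule itself is simple; here is a minimal NumPy sketch of the math where_grad computes (an illustration, not the generated code):

import numpy as np

def where_grad(cond, out_grad):
    # The incoming gradient is masked by the condition and
    # split between the two branch inputs.
    x_grad = np.where(cond, out_grad, 0.0)
    y_grad = np.where(cond, 0.0, out_grad)
    return x_grad, y_grad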
3 changes: 2 additions & 1 deletion python/paddle/tensor/search.py
@@ -27,6 +27,7 @@
     convert_np_dtype_to_dtype_,
     core,
     in_dynamic_mode,
+    in_dynamic_or_pir_mode,
 )

 # from ..base.layers import has_inf #DEFINE_ALIAS
@@ -686,7 +687,7 @@ def where(condition, x=None, y=None, name=None):
         broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
         broadcast_condition = paddle.cast(broadcast_condition, 'bool')

-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
     else:
         check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
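Swapping in_dynamic_mode() for in_dynamic_or_pir_mode() is the standard migration step in these PRs: the _C_ops.where call now also serves PIR static graphs instead of falling through to the legacy append_op path. The Python-level behavior of the API is unchanged; for example:

import paddle

x = paddle.to_tensor([0.9, 0.1, 3.2, 1.2])
y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])
out = paddle.where(x > 1.0, x, y)
print(out)  # [1.0, 1.0, 3.2, 1.2]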
11 changes: 8 additions & 3 deletions test/legacy_test/test_where_op.py
@@ -33,10 +33,12 @@ def setUp(self):
         self.outputs = {'Out': np.where(self.cond, self.x, self.y)}

     def test_check_output(self):
-        self.check_output(check_cinn=self.check_cinn)
+        self.check_output(check_cinn=self.check_cinn, check_new_ir=True)

     def test_check_grad(self):
-        self.check_grad(['X', 'Y'], 'Out', check_cinn=self.check_cinn)
+        self.check_grad(
+            ['X', 'Y'], 'Out', check_cinn=self.check_cinn, check_new_ir=True
+        )

     def init_config(self):
         self.x = np.random.uniform((-3), 5, 100).astype('float64')
@@ -82,7 +84,9 @@ def setUp(self):

     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place, check_cinn=self.check_cinn)
+        self.check_output_with_place(
+            place, check_cinn=self.check_cinn, check_new_ir=True
+        )

     def test_check_grad(self):
         place = core.CUDAPlace(0)
@@ -92,6 +96,7 @@ def test_check_grad(self):
             'Out',
             numeric_grad_delta=0.05,
             check_cinn=self.check_cinn,
+            check_new_ir=True,
        )

     def init_config(self):
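Passing check_new_ir=True makes the OpTest harness run the same forward and gradient checks through the new-IR (PIR) executor in addition to the legacy paths. A condensed sketch of the test pattern follows; the harness module name and the input key layout ('Condition', 'X', 'Y') are assumptions inferred from this diff, not copied from the file:

import numpy as np
import paddle
from op_test import OpTest  # test/legacy_test harness; module name assumed


class TestWhereOpSketch(OpTest):
    def setUp(self):
        self.op_type = 'where'
        self.python_api = paddle.where
        x = np.random.uniform(-3, 5, 100).astype('float64')
        y = np.random.uniform(-3, 5, 100).astype('float64')
        cond = np.random.rand(100) > 0.5
        self.inputs = {'Condition': cond, 'X': x, 'Y': y}
        self.outputs = {'Out': np.where(cond, x, y)}

    def test_check_output(self):
        # Also exercises the PIR execution path.
        self.check_output(check_new_ir=True)

    def test_check_grad(self):
        self.check_grad(['X', 'Y'], 'Out', check_new_ir=True)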
