Skip to content

Commit

Permalink
[NewIR] No.13 Migrate paddle.multiply into pir (PaddlePaddle#57175)
Browse files Browse the repository at this point in the history
  • Loading branch information
enkilee authored Sep 14, 2023
1 parent b1ad1ec commit 96983a7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,7 +1081,7 @@ def multiply(x, y, name=None):
[2, 4, 6]]])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.multiply(x, y)
else:
if x.dtype != y.dtype:
Expand Down
18 changes: 15 additions & 3 deletions test/legacy_test/test_elementwise_mul_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def setUp(self):

def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(not self.use_mkldnn))
self.check_output(
check_dygraph=(not self.use_mkldnn),
check_new_ir=(not self.use_mkldnn),
)

def test_check_grad_normal(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
Expand All @@ -56,6 +59,7 @@ def test_check_grad_normal(self):
'Out',
check_dygraph=(not self.use_mkldnn),
check_prim=True,
check_new_ir=(not self.use_mkldnn),
)

def test_check_grad_ingore_x(self):
Expand All @@ -66,6 +70,7 @@ def test_check_grad_ingore_x(self):
no_grad_set=set("X"),
check_dygraph=(not self.use_mkldnn),
check_prim=True,
check_new_ir=(not self.use_mkldnn),
)

def test_check_grad_ingore_y(self):
Expand All @@ -76,6 +81,7 @@ def test_check_grad_ingore_y(self):
no_grad_set=set('Y'),
check_dygraph=(not self.use_mkldnn),
check_prim=True,
check_new_ir=(not self.use_mkldnn),
)

def init_input_output(self):
Expand Down Expand Up @@ -193,14 +199,15 @@ def test_check_output(self):
self.check_output()

def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', check_prim=True)
self.check_grad(['X', 'Y'], 'Out', check_prim=True, check_new_ir=True)

def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
check_prim=True,
check_new_ir=True,
)

def test_check_grad_ingore_y(self):
Expand All @@ -209,6 +216,7 @@ def test_check_grad_ingore_y(self):
'Out',
no_grad_set=set('Y'),
check_prim=True,
check_new_ir=True,
)

def if_enable_cinn(self):
Expand Down Expand Up @@ -264,7 +272,8 @@ def setUp(self):

def test_check_output(self):
self.check_output(
check_dygraph=self.check_dygraph, check_prim=self.check_prim
check_dygraph=self.check_dygraph,
check_prim=self.check_prim,
)

def test_check_grad_normal(self):
Expand Down Expand Up @@ -418,6 +427,7 @@ def test_check_grad_normal(self):
'Out',
check_dygraph=(not self.use_mkldnn),
check_prim=True,
check_new_ir=(not self.use_mkldnn),
)

def test_check_grad_ingore_x(self):
Expand All @@ -428,6 +438,7 @@ def test_check_grad_ingore_x(self):
no_grad_set=set("X"),
check_dygraph=(not self.use_mkldnn),
check_prim=True,
check_new_ir=(not self.use_mkldnn),
)

def test_check_grad_ingore_y(self):
Expand All @@ -438,6 +449,7 @@ def test_check_grad_ingore_y(self):
no_grad_set=set('Y'),
check_dygraph=(not self.use_mkldnn),
check_prim=True,
check_new_ir=(not self.use_mkldnn),
)


Expand Down

0 comments on commit 96983a7

Please sign in to comment.