[NewIR] No.30 Migrate paddle.matmul into pir (PaddlePaddle#57277)
* fix

* fix

* fix
enkilee authored Sep 15, 2023
1 parent a46ed2a commit ffbb446
Showing 2 changed files with 15 additions and 5 deletions.
python/paddle/tensor/linalg.py (4 changes: 2 additions & 2 deletions)

@@ -20,7 +20,7 @@
 
 from ..base.data_feeder import check_dtype, check_type, check_variable_and_dtype
 from ..common_ops_import import Variable
-from ..framework import LayerHelper, in_dynamic_mode
+from ..framework import LayerHelper, in_dynamic_mode, in_dynamic_or_pir_mode
 from .creation import full
 from .manipulation import cast
 from .math import _get_reduce_axis
@@ -225,7 +225,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
             [10, 3, 5, 5]
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.matmul(x, y, transpose_x, transpose_y)
     else:
         attrs = {
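After this change, `paddle.matmul` takes the `_C_ops.matmul` fast path in both dynamic mode and the new PIR mode; only legacy static-graph programs fall through to the `attrs`/`LayerHelper` branch. A minimal usage sketch of the public API, with shapes chosen to reproduce the docstring output shown in the hunk above:

```python
import paddle

# Batched matmul: the last two dims multiply, leading dims must match or broadcast.
x = paddle.rand([10, 3, 5, 2])
y = paddle.rand([10, 3, 2, 5])

out = paddle.matmul(x, y)
print(out.shape)  # [10, 3, 5, 5], matching the docstring example above

# transpose_y=True uses y with its last two dims transposed.
out_t = paddle.matmul(x, x, transpose_y=True)
print(out_t.shape)  # [10, 3, 5, 5]
```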
test/legacy_test/test_matmul_v2_op.py (16 changes: 13 additions & 3 deletions)

@@ -98,7 +98,8 @@ def setUp(self):
 
     def test_check_output(self):
         self.check_output(
-            check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True
+            check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True,
+            check_new_ir=True,
         )
 
     def test_check_grad(self):
@@ -110,6 +111,7 @@ def test_check_grad(self):
                 check_cinn=self.check_cinn
                 if hasattr(self, 'check_cinn')
                 else True,
+                check_new_ir=True,
             )
         else:
             self.check_grad(
@@ -118,6 +120,7 @@ def test_check_grad(self):
                 check_cinn=self.check_cinn
                 if hasattr(self, 'check_cinn')
                 else True,
+                check_new_ir=True,
             )
 
 
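An aside, not part of the commit: the recurring default `self.check_cinn if hasattr(self, 'check_cinn') else True` is equivalent to a `getattr` with a default, as in this sketch (`resolve_check_cinn` is a hypothetical helper, not in the file):

```python
def resolve_check_cinn(test_case) -> bool:
    # Same result as: test_case.check_cinn if hasattr(test_case, 'check_cinn') else True
    return getattr(test_case, 'check_cinn', True)
```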
@@ -359,6 +362,7 @@ def test_check_output(self):
                         check_cinn=self.check_cinn
                         if hasattr(self, 'check_cinn')
                         else True,
+                        check_new_ir=True,
                     )
 
         def test_check_grad(self):
@@ -372,6 +376,7 @@ def test_check_grad(self):
                         check_cinn=self.check_cinn
                         if hasattr(self, 'check_cinn')
                         else True,
+                        check_new_ir=True,
                     )
 
     cls_name = "{}_{}".format(parent.__name__, "Fp16")
@@ -431,6 +436,7 @@ def test_check_output(self):
                 check_cinn=self.check_cinn
                 if hasattr(self, 'check_cinn')
                 else True,
+                check_new_ir=True,
             )
 
         def test_check_grad_x(self):
@@ -447,6 +453,7 @@ def test_check_grad_x(self):
                 check_cinn=self.check_cinn
                 if hasattr(self, 'check_cinn')
                 else True,
+                check_new_ir=True,
             )
 
         def test_check_grad_y(self):
@@ -463,6 +470,7 @@ def test_check_grad_y(self):
                 check_cinn=self.check_cinn
                 if hasattr(self, 'check_cinn')
                 else True,
+                check_new_ir=True,
             )
 
         def test_check_grad(self):
@@ -499,6 +507,7 @@ def setUp(self):
             self.places.append(base.CUDAPlace(0))
 
     def check_static_result(self, place):
+        paddle.enable_static()
         with base.program_guard(base.Program(), base.Program()):
             input_x = paddle.static.data(
                 name="input_x", shape=[4, 3], dtype="float32"
@@ -518,6 +527,7 @@ def check_static_result(self, place):
                 feed={"input_x": x_np, "input_y": y_np},
                 fetch_list=[result],
             )
+        paddle.disable_static()
 
     def test_static(self):
         for place in self.places:
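The last two hunks bracket `check_static_result` with explicit mode switches so the static-graph check does not leave the process in static mode for the surrounding dynamic-mode tests. A standalone sketch of the same pattern, assuming the public `paddle.static` API in place of the test file's `base` aliases:

```python
import numpy as np
import paddle

paddle.enable_static()  # switch to static-graph mode
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="input_x", shape=[4, 3], dtype="float32")
    y = paddle.static.data(name="input_y", shape=[3, 4], dtype="float32")
    result = paddle.matmul(x, y)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
(out,) = exe.run(
    main,
    feed={
        "input_x": np.random.rand(4, 3).astype("float32"),
        "input_y": np.random.rand(3, 4).astype("float32"),
    },
    fetch_list=[result],
)
paddle.disable_static()  # restore dynamic mode
```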
@@ -735,7 +745,7 @@ def init_input_output(self):
         self.out = np.matmul(self.x, self.y)
 
     def test_check_output(self):
-        self.check_output(check_cinn=False)
+        self.check_output(check_cinn=False, check_new_ir=True)
 
 
 class TestInt32MatMulOpBroadcast(OpTest):
@@ -787,7 +797,7 @@ def init_input_output(self):
         self.out = np.matmul(self.x, self.y)
 
     def test_check_output(self):
-        self.check_output(check_cinn=False)
+        self.check_output(check_cinn=False, check_new_ir=True)
 
 
 class TestInt64MatMulOpBroadcast(OpTest):
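Each `check_new_ir=True` asks the OpTest harness to additionally run the check under the new IR. For orientation, a minimal sketch of the OpTest pattern this file follows, assuming the `matmul_v2` op schema (inputs `X`/`Y`, attrs `trans_x`/`trans_y`, output `Out`); the class below is illustrative, not part of the commit:

```python
import numpy as np
import paddle
from op_test import OpTest  # test helper that ships with these legacy tests


class TestMatMulV2Sketch(OpTest):
    def setUp(self):
        self.op_type = "matmul_v2"
        self.python_api = paddle.tensor.matmul
        x = np.random.random((2, 3)).astype("float64")
        y = np.random.random((3, 4)).astype("float64")
        self.inputs = {'X': x, 'Y': y}
        self.attrs = {'trans_x': False, 'trans_y': False}
        self.outputs = {'Out': np.matmul(x, y)}

    def test_check_output(self):
        self.check_output(check_new_ir=True)  # also verify under the new IR

    def test_check_grad(self):
        self.check_grad(['X', 'Y'], 'Out', check_new_ir=True)
```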
