[NewIR] No.49 Migrate paddle.tril into pir (PaddlePaddle#57393)
ccsuzzh authored Sep 22, 2023
1 parent 0717491 commit a2bc204
Showing 4 changed files with 18 additions and 6 deletions.
@@ -58,6 +58,8 @@
     'slice_double',
     'poisson',
     'gumbel_softmax',
+    'tril',
+    'triu',
 ]
 vjp_interface_implementation_gen_op_list = [
     "tanh",
@@ -94,4 +96,6 @@
     'slice_double',
     'poisson',
     'gumbel_softmax',
+    'tril',
+    'triu',
 ]
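The two lists above are the registries read by the VJP (reverse-mode gradient) interface generator, so adding 'tril' and 'triu' is what makes their gradients available under the new IR. The end-to-end effect is visible from ordinary autograd usage; a minimal sketch, run in dynamic mode for brevity (the gradient of a summed tril is simply the lower-triangular mask):

import paddle

x = paddle.ones([3, 3])
x.stop_gradient = False
paddle.tril(x).sum().backward()
print(x.grad)
# Lower-triangular mask of ones:
# [[1., 0., 0.],
#  [1., 1., 0.],
#  [1., 1., 1.]]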
4 changes: 4 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -38,6 +38,8 @@


 VJPS = [
+    'tril_grad',
+    'triu_grad',
     'tanh_grad',
     'mean_grad',
     'add_grad',
@@ -92,6 +94,8 @@
 VJP_COMPS = PRIM_VJP + CUSTOM_VJP

 BACKENDS = [
+    'tril_grad',
+    'triu_grad',
     'add_n',
     'mean',
     'sum',
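VJPS and BACKENDS are the op lists consumed by the primitive code generator in gen.py; tril_grad and triu_grad are added to both so their backward rules get generated. The backward rule itself is just the same triangular masking applied to the upstream gradient, which a quick NumPy check illustrates (illustrative only, independent of the generated code):

import numpy as np

g = np.arange(1.0, 10.0).reshape(3, 3)   # stand-in upstream gradient
mask = np.tril(np.ones_like(g))          # pattern that tril keeps
assert np.allclose(np.tril(g), g * mask)  # i.e. tril_grad(g) == tril(g)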
4 changes: 2 additions & 2 deletions python/paddle/tensor/creation.py
@@ -1504,7 +1504,7 @@ def tril(x, diagonal=0, name=None):
             [5 , 0 , 0 , 0 ],
             [9 , 10, 0 , 0 ]])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.tril(x, diagonal)
     else:
         return _tril_triu_op(LayerHelper('tril', **locals()))
@@ -1581,7 +1581,7 @@ def triu(x, diagonal=0, name=None):
             [0 , 10, 11, 12]])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.triu(x, diagonal)
     else:
         return _tril_triu_op(LayerHelper('triu', **locals()))
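With in_dynamic_or_pir_mode() the same _C_ops.tril / _C_ops.triu call now serves both dynamic graph and the new IR, and only legacy static graph falls back to _tril_triu_op. The public API is unchanged; a quick usage sketch (outputs follow directly from the definition of the triangular parts):

import paddle

x = paddle.arange(1, 13).reshape([3, 4])
paddle.tril(x)               # [[1, 0, 0, 0], [5, 6, 0, 0], [9, 10, 11, 0]]
paddle.triu(x)               # [[1, 2, 3, 4], [0, 6, 7, 8], [0, 0, 11, 12]]
paddle.tril(x, diagonal=-1)  # [[0, 0, 0, 0], [5, 0, 0, 0], [9, 10, 0, 0]]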
12 changes: 8 additions & 4 deletions test/legacy_test/test_tril_triu_op.py
@@ -45,10 +45,10 @@ def setUp(self):
         }

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_new_ir=True)

     def test_check_grad_normal(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_new_ir=True)

     def init_dtype(self):
         self.dtype = np.float64
@@ -86,11 +86,15 @@ def initTestCase(self):
         self.X = np.arange(1, 101, dtype="float32").reshape([10, -1])

     def test_check_output(self):
-        self.check_output_with_place(core.CUDAPlace(0))
+        self.check_output_with_place(core.CUDAPlace(0), check_new_ir=True)

     def test_check_grad_normal(self):
         self.check_grad_with_place(
-            core.CUDAPlace(0), ['X'], 'Out', numeric_grad_delta=0.05
+            core.CUDAPlace(0),
+            ['X'],
+            'Out',
+            numeric_grad_delta=0.05,
+            check_new_ir=True,
         )


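check_new_ir=True turns on the new-IR check in OpTest, so each case is also executed under PIR and compared against the existing result. To exercise the updated cases locally, a minimal runner sketch (assumes it is launched from the repository root; the legacy tests import helpers such as op_test from their own directory):

import sys
import unittest

sys.path.insert(0, "test/legacy_test")  # so op_test and friends are importable

suite = unittest.defaultTestLoader.discover(
    "test/legacy_test", pattern="test_tril_triu_op.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)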
