From 5477e642f66cf19cb28ee6bc009041c71741b86c Mon Sep 17 00:00:00 2001
From: co63oc
Date: Tue, 5 Sep 2023 07:51:46 +0800
Subject: [PATCH] Fix lu_factor_ex

---
 paconvert/api_mapping.json        |  2 +-
 paconvert/api_matcher.py          | 41 +++++++++++++++++++++++++++++++
 tests/test_linalg_lu_factor_ex.py |  5 ----
 3 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index e4ae5183f..4a626e87b 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -5816,7 +5816,7 @@
         }
     },
     "torch.linalg.lu_factor_ex": {
-        "Matcher": "TripleAssignMatcher",
+        "Matcher": "LinalgLufactorexMatcher",
         "paddle_api": "paddle.linalg.lu",
         "args_list": [
             "A",
diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index 953b3c7c2..438aa7489 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -3492,6 +3492,47 @@ def generate_code(self, kwargs):
         return code.strip("\n")
 
 
+class LinalgLufactorexMatcher(BaseMatcher):
+    def generate_code(self, kwargs):
+        kwargs = self.set_paddle_default_kwargs(kwargs)
+        kwargs_change = {}
+        if "kwargs_change" in self.api_mapping:
+            kwargs_change = self.api_mapping["kwargs_change"]
+
+        for k in kwargs_change:
+            if k in kwargs:
+                if kwargs_change[k]:
+                    kwargs[kwargs_change[k]] = kwargs.pop(k)
+                else:
+                    kwargs.pop(k)
+
+        if "out" in kwargs:
+            out_v = kwargs.pop("out")
+            API_TEMPLATE = textwrap.dedent(
+                """
+                out1, out2, out3 = {}({})
+                out3 = paddle.to_tensor(out3.item(), dtype='int32')
+                paddle.assign(out1, {}[0]), paddle.assign(out2, {}[1]), paddle.assign(out3, {}[2])
+                """
+            )
+            code = API_TEMPLATE.format(
+                self.get_paddle_api(), self.kwargs_to_str(kwargs), out_v, out_v, out_v
+            )
+            return code.strip("\n")
+        else:
+            API_TEMPLATE = textwrap.dedent(
+                """
+                out1, out2, out3 = {}({})
+                out3 = paddle.to_tensor(out3.item(), dtype='int32')
+                (out1, out2, out3)
+                """
+            )
+            code = API_TEMPLATE.format(
+                self.get_paddle_api(), self.kwargs_to_str(kwargs)
+            )
+            return code.strip("\n")
+
+
 class RoundMatcher(BaseMatcher):
     def generate_code(self, kwargs):
         if "input" not in kwargs:
diff --git a/tests/test_linalg_lu_factor_ex.py b/tests/test_linalg_lu_factor_ex.py
index 3d4f61357..f39168564 100644
--- a/tests/test_linalg_lu_factor_ex.py
+++ b/tests/test_linalg_lu_factor_ex.py
@@ -25,7 +25,6 @@ def test_case_1():
         import torch
         x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64)
         LU, pivots, info = torch.linalg.lu_factor_ex(x)
-        info = info.item()
         """
     )
     obj.run(pytorch_code, ["LU", "pivots", "info"])
@@ -37,7 +36,6 @@ def test_case_2():
         import torch
         x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64)
         LU, pivots, info = torch.linalg.lu_factor_ex(A=x)
-        info = info.item()
         """
     )
     obj.run(pytorch_code, ["LU", "pivots", "info"])
@@ -49,7 +47,6 @@ def test_case_3():
         import torch
         x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64)
         LU, pivots, info = torch.linalg.lu_factor_ex(pivot=True, A=x)
-        info = info.item()
         """
     )
     obj.run(pytorch_code, ["LU", "pivots", "info"])
@@ -62,7 +59,6 @@ def test_case_4():
         x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64)
         out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int), torch.tensor([], dtype=torch.int))
         LU, pivots, info = torch.linalg.lu_factor_ex(x, pivot=True, check_errors=False, out=out)
-        info = info.item()
         """
     )
     obj.run(pytorch_code, ["LU", "pivots", "info"])
@@ -75,7 +71,6 @@ def test_case_5():
         x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64)
         out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int), torch.tensor([], dtype=torch.int))
         LU, pivots, info = torch.linalg.lu_factor_ex(A=x, pivot=True, check_errors=True, out=out)
-        info = info.item()
         """
     )
     obj.run(pytorch_code, ["LU", "pivots", "info"])
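
As an illustration, the no-`out` template above is meant to turn a converted call into roughly the following Paddle code. This is a sketch only: it assumes the mapping supplies get_infos=True as a Paddle default kwarg so paddle.linalg.lu returns the info tensor, and that paconvert splices the generated block back onto the original assignment targets.

    # PyTorch source being converted (same call as in the tests)
    LU, pivots, info = torch.linalg.lu_factor_ex(x)

    # Sketch of the code generated by LinalgLufactorexMatcher
    out1, out2, out3 = paddle.linalg.lu(x, get_infos=True)   # get_infos=True assumed to come from the mapping defaults
    out3 = paddle.to_tensor(out3.item(), dtype='int32')      # make info a 0-d int32 tensor, matching torch's output
    LU, pivots, info = out1, out2, out3                       # assumed splice back into the original targets

Because the matcher now returns info as an int32 tensor rather than a Python scalar, the tests no longer need the extra `info = info.item()` normalization, which is why those lines are removed above.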