From 2bcbe4c53c622f6acda39bcda5296ba68f413fb6 Mon Sep 17 00:00:00 2001 From: co63oc Date: Tue, 29 Aug 2023 15:21:10 +0800 Subject: [PATCH] Add tests --- paconvert/api_mapping.json | 29 +++++++++++ paconvert/api_matcher.py | 31 ++++++++++++ tests/test_linalg_lu_factor.py | 76 +++++++++++++++++++++++++++++ tests/test_linalg_lu_factor_ex.py | 81 +++++++++++++++++++++++++++++++ 4 files changed, 217 insertions(+) create mode 100644 tests/test_linalg_lu_factor.py create mode 100644 tests/test_linalg_lu_factor_ex.py diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index dd045688a..adb41357a 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -5771,6 +5771,35 @@ "A": "x" } }, + "torch.linalg.lu_factor": { + "Matcher": "TupleAssignMatcher", + "paddle_api": "paddle.linalg.lu", + "args_list": [ + "A", + "pivot", + "out" + ], + "kwargs_change": { + "A": "x" + } + }, + "torch.linalg.lu_factor_ex": { + "Matcher": "TripleAssignMatcher", + "paddle_api": "paddle.linalg.lu", + "args_list": [ + "A", + "pivot", + "check_errors", + "out" + ], + "kwargs_change": { + "A": "x", + "check_errors": "" + }, + "paddle_default_kwargs": { + "get_infos": "True" + } + }, "torch.linalg.matmul": { "Matcher": "GenericMatcher", "paddle_api": "paddle.matmul", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 4951d4000..9d14da0ad 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3443,6 +3443,37 @@ def generate_code(self, kwargs): return code.strip("\n") +class TripleAssignMatcher(BaseMatcher): + def generate_code(self, kwargs): + kwargs = self.set_paddle_default_kwargs(kwargs) + kwargs_change = {} + if "kwargs_change" in self.api_mapping: + kwargs_change = self.api_mapping["kwargs_change"] + + for k in kwargs_change: + if k in kwargs: + if kwargs_change[k]: + kwargs[kwargs_change[k]] = kwargs.pop(k) + else: + kwargs.pop(k) + + if "out" in kwargs: + out_v = kwargs.pop("out") + API_TEMPLATE = textwrap.dedent( + """ + out1, out2, out3 = {}({}) + paddle.assign(out1, {}[0]), paddle.assign(out2, {}[1]), paddle.assign(out3, {}[2]) + """ + ) + code = API_TEMPLATE.format( + self.get_paddle_api(), self.kwargs_to_str(kwargs), out_v, out_v, out_v + ) + return code.strip("\n") + else: + code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(kwargs)) + return code.strip("\n") + + class RoundMatcher(BaseMatcher): def generate_code(self, kwargs): if "input" not in kwargs: diff --git a/tests/test_linalg_lu_factor.py b/tests/test_linalg_lu_factor.py new file mode 100644 index 000000000..2e2f83f22 --- /dev/null +++ b/tests/test_linalg_lu_factor.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.lu_factor") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots = torch.linalg.lu_factor(x) + """ + ) + obj.run(pytorch_code, ["LU", "pivots"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots = torch.linalg.lu_factor(A=x) + """ + ) + obj.run(pytorch_code, ["LU", "pivots"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots = torch.linalg.lu_factor(pivot=True, A=x) + """ + ) + obj.run(pytorch_code, ["LU", "pivots"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int)) + LU, pivots = torch.linalg.lu_factor(x, pivot=True, out=out) + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int)) + LU, pivots = torch.linalg.lu_factor(A=x, pivot=True, out=out) + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "out"]) diff --git a/tests/test_linalg_lu_factor_ex.py b/tests/test_linalg_lu_factor_ex.py new file mode 100644 index 000000000..3d4f61357 --- /dev/null +++ b/tests/test_linalg_lu_factor_ex.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.lu_factor_ex") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots, info = torch.linalg.lu_factor_ex(x) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots, info = torch.linalg.lu_factor_ex(A=x) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots, info = torch.linalg.lu_factor_ex(pivot=True, A=x) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int), torch.tensor([], dtype=torch.int)) + LU, pivots, info = torch.linalg.lu_factor_ex(x, pivot=True, check_errors=False, out=out) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int), torch.tensor([], dtype=torch.int)) + LU, pivots, info = torch.linalg.lu_factor_ex(A=x, pivot=True, check_errors=True, out=out) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"])