diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index dd045688a..f806122be 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -5771,6 +5771,25 @@ "A": "x" } }, + "torch.linalg.lu_factor": { + "Matcher": "LinalgLufactorMatcher", + "paddle_api": "paddle.linalg.lu", + "args_list": [ + "A", + "pivot", + "out" + ] + }, + "torch.linalg.lu_factor_ex": { + "Matcher": "LinalgLufactorexMatcher", + "paddle_api": "paddle.linalg.lu", + "args_list": [ + "A", + "pivot", + "check_errors", + "out" + ] + }, "torch.linalg.matmul": { "Matcher": "GenericMatcher", "paddle_api": "paddle.matmul", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 4951d4000..7fa23ef69 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -4061,6 +4061,60 @@ def generate_code(self, kwargs): return code +class LinalgLufactorMatcher(BaseMatcher): + def generate_code(self, kwargs): + out_v = kwargs.pop("out") if "out" in kwargs else None + new_kwargs = {} + new_kwargs["x"] = kwargs.pop("A") + new_kwargs.update(kwargs) + if out_v: + API_TEMPLATE = textwrap.dedent( + """ + tmp_lu, tmp_p = {}({}) + paddle.assign(tmp_lu, {}[0]), paddle.assign(tmp_p, {}[1]) + """ + ) + code = API_TEMPLATE.format( + self.get_paddle_api(), self.kwargs_to_str(new_kwargs), out_v, out_v + ) + else: + code = "{}({})".format( + self.get_paddle_api(), self.kwargs_to_str(new_kwargs) + ) + return code + + +class LinalgLufactorexMatcher(BaseMatcher): + def generate_code(self, kwargs): + out_v = kwargs.pop("out") if "out" in kwargs else None + if "check_errors" in kwargs: + kwargs.pop("check_errors") + + new_kwargs = {} + new_kwargs["x"] = kwargs.pop("A") + new_kwargs.update(kwargs) + new_kwargs["get_infos"] = "True" + if out_v: + API_TEMPLATE = textwrap.dedent( + """ + tmp_lu, tmp_p, tmp_info = {}({}) + paddle.assign(tmp_lu, {}[0]), paddle.assign(tmp_p, {}[1]), paddle.assign(tmp_info, {}[2]) + """ + ) + code = API_TEMPLATE.format( + self.get_paddle_api(), + self.kwargs_to_str(new_kwargs), + out_v, + out_v, + out_v, + ) + else: + code = "{}({})".format( + self.get_paddle_api(), self.kwargs_to_str(new_kwargs) + ) + return code + + class QrMatcher(BaseMatcher): def generate_code(self, kwargs): some_v = kwargs.pop("some") if "some" in kwargs else None diff --git a/tests/test_linalg_lu_factor.py b/tests/test_linalg_lu_factor.py new file mode 100644 index 000000000..2e2f83f22 --- /dev/null +++ b/tests/test_linalg_lu_factor.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.lu_factor") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots = torch.linalg.lu_factor(x) + """ + ) + obj.run(pytorch_code, ["LU", "pivots"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots = torch.linalg.lu_factor(A=x) + """ + ) + obj.run(pytorch_code, ["LU", "pivots"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots = torch.linalg.lu_factor(pivot=True, A=x) + """ + ) + obj.run(pytorch_code, ["LU", "pivots"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int)) + LU, pivots = torch.linalg.lu_factor(x, pivot=True, out=out) + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int)) + LU, pivots = torch.linalg.lu_factor(A=x, pivot=True, out=out) + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "out"]) diff --git a/tests/test_linalg_lu_factor_ex.py b/tests/test_linalg_lu_factor_ex.py new file mode 100644 index 000000000..3d4f61357 --- /dev/null +++ b/tests/test_linalg_lu_factor_ex.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.lu_factor_ex") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots, info = torch.linalg.lu_factor_ex(x) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots, info = torch.linalg.lu_factor_ex(A=x) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + LU, pivots, info = torch.linalg.lu_factor_ex(pivot=True, A=x) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int), torch.tensor([], dtype=torch.int)) + LU, pivots, info = torch.linalg.lu_factor_ex(x, pivot=True, check_errors=False, out=out) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float64) + out = (torch.tensor([], dtype=torch.float64), torch.tensor([], dtype=torch.int), torch.tensor([], dtype=torch.int)) + LU, pivots, info = torch.linalg.lu_factor_ex(A=x, pivot=True, check_errors=True, out=out) + info = info.item() + """ + ) + obj.run(pytorch_code, ["LU", "pivots", "info"])