From 7c8b5d6a81f8f852377cf4eceaac5313b08461fb Mon Sep 17 00:00:00 2001
From: co63oc
Date: Mon, 28 Aug 2023 20:14:53 +0800
Subject: [PATCH] Conversion rules No.228/229 (#248)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add tests
---
 paconvert/api_mapping.json |  27 +++++++++
 tests/test_sparse_addmm.py | 109 +++++++++++++++++++++++++++++++++++++
 tests/test_sparse_mm.py    |  67 +++++++++++++++++++++++
 3 files changed, 203 insertions(+)
 create mode 100644 tests/test_sparse_addmm.py
 create mode 100644 tests/test_sparse_mm.py

diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index 50f5b35bd..dd045688a 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -10427,6 +10427,33 @@
             "stable": ""
         }
     },
+    "torch.sparse.addmm": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.sparse.addmm",
+        "args_list": [
+            "input",
+            "mat1",
+            "mat2",
+            "alpha",
+            "beta"
+        ],
+        "kwargs_change": {
+            "mat1": "x",
+            "mat2": "y"
+        }
+    },
+    "torch.sparse.mm": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.sparse.matmul",
+        "args_list": [
+            "sparse",
+            "dense"
+        ],
+        "kwargs_change": {
+            "sparse": "x",
+            "dense": "y"
+        }
+    },
     "torch.sparse.sum": {
         "Matcher": "GenericMatcher",
         "paddle_api": "paddle.sparse.sum",
diff --git a/tests/test_sparse_addmm.py b/tests/test_sparse_addmm.py
new file mode 100644
index 000000000..2f1722a4f
--- /dev/null
+++ b/tests/test_sparse_addmm.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.sparse.addmm")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        i = torch.tensor([[0, 1, 2],
+                          [1, 0, 1]])
+        v = torch.tensor([3, 4, 5], dtype=torch.float32)
+        mat1 = torch.sparse_coo_tensor(i, v, [3, 3])
+        mat2 = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        result = None
+        if torch.cuda.is_available():
+            result = torch.sparse.addmm(x, mat1, mat2)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        i = torch.tensor([[0, 1, 2],
+                          [1, 0, 1]])
+        v = torch.tensor([3, 4, 5], dtype=torch.float32)
+        mat1 = torch.sparse_coo_tensor(i, v, [3, 3])
+        mat2 = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        result = None
+        if torch.cuda.is_available():
+            result = torch.sparse.addmm(input=x, mat1=mat1, mat2=mat2, beta=0.6, alpha=0.7)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        i = torch.tensor([[0, 1, 2],
+                          [1, 0, 1]])
+        v = torch.tensor([3, 4, 5], dtype=torch.float32)
+        mat1 = torch.sparse_coo_tensor(i, v, [3, 3])
+        mat2 = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        result = None
+        if torch.cuda.is_available():
+            result = torch.sparse.addmm(beta=0.6, alpha=0.7, input=x, mat1=mat1, mat2=mat2)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_4():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        i = torch.tensor([[0, 1, 2],
+                          [1, 0, 1]])
+        v = torch.tensor([3, 4, 5], dtype=torch.float32)
+        mat1 = torch.sparse_coo_tensor(i, v, [3, 3])
+        mat2 = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        result = None
+        if torch.cuda.is_available():
+            result = torch.sparse.addmm(x, mat1, mat2, beta=0.6, alpha=0.7)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        i = torch.tensor([[0, 1, 2],
+                          [1, 0, 1]])
+        v = torch.tensor([3, 4, 5], dtype=torch.float32)
+        mat1 = torch.sparse_coo_tensor(i, v, [3, 3])
+        mat2 = torch.tensor([[1., 2, 3], [3, 4, 5], [3, 4, 5]])
+        result = None
+        if torch.cuda.is_available():
+            result = torch.sparse.addmm(input=x, mat1=mat1, mat2=mat2, beta=0.6, alpha=0.7)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_sparse_mm.py b/tests/test_sparse_mm.py
new file mode 100644
index 000000000..b3fe441d7
--- /dev/null
+++ b/tests/test_sparse_mm.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.sparse.mm")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        indices = [[0, 1, 2], [1, 2, 0]]
+        values = [1., 2., 3.]
+        x = torch.sparse_coo_tensor(indices, values, [3, 3])
+        dense = torch.ones([3, 2])
+        result = None
+        if torch.cuda.is_available():
+            result = torch.sparse.mm(x, dense)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        indices = [[0, 1, 2], [1, 2, 0]]
+        values = [1., 2., 3.]
+        x = torch.sparse_coo_tensor(indices, values, [3, 3])
+        dense = torch.ones([3, 2])
+        result = None
+        if torch.cuda.is_available():
+            result = torch.sparse.mm(sparse=x, dense=dense)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        indices = [[0, 1, 2], [1, 2, 0]]
+        values = [1., 2., 3.]
+        x = torch.sparse_coo_tensor(indices, values, [3, 3])
+        dense = torch.ones([3, 2])
+        result = None
+        if torch.cuda.is_available():
+            result = torch.sparse.mm(dense=dense, sparse=x)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
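
For reference, a minimal sketch of what converted code is expected to look like under the two mappings above, assuming paddle.sparse.addmm takes (input, x, y, beta, alpha) and paddle.sparse.matmul takes (x, y) as the kwargs_change entries suggest; the sparse kernels may require a GPU build, which is why every test guards on torch.cuda.is_available():

    # Sketch of the Paddle-side equivalents of the torch.sparse calls in the
    # tests; the parameter renames only mirror the kwargs_change rules and the
    # exact converter output is an assumption, not a verified result.
    import paddle

    # COO sparse matrix mirroring torch.sparse_coo_tensor(i, v, [3, 3]) above
    indices = [[0, 1, 2], [1, 0, 1]]
    values = [3.0, 4.0, 5.0]
    sp = paddle.sparse.sparse_coo_tensor(indices, values, [3, 3])

    inp = paddle.to_tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0]])
    dn = paddle.to_tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0]])

    # torch.sparse.addmm(input, mat1, mat2, beta=..., alpha=...)
    #   -> paddle.sparse.addmm(input, x, y, beta=..., alpha=...)
    out_addmm = paddle.sparse.addmm(inp, sp, dn, beta=0.6, alpha=0.7)

    # torch.sparse.mm(sparse, dense) -> paddle.sparse.matmul(x, y)
    out_mm = paddle.sparse.matmul(sp, dn)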