Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/PaddlePaddle/PaConvert in…
Browse files Browse the repository at this point in the history
…to dev
  • Loading branch information
txyugood committed Jul 3, 2023
2 parents 2395ee0 + b7a6194 commit 2bca96b
Show file tree
Hide file tree
Showing 12 changed files with 654 additions and 13 deletions.
17 changes: 9 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,6 @@ unsupport_args:可选,Paddle API不支持的参数功能,通过该字段配
paddle_default_kwargs :可选,当 paddle 参数更多 或者 参数默认值不一致 时,可以通过该配置,设置参数默认值。
```

**需要注意的是**,如果一个API是别名API(alias API), 比如 `torch.nn.modules.GroupNorm` 与 `torch.nn.GroupNorm` 是同一个API,只是按照模块路径采用了不同的调用方式,那么就无需编写相关 Matcher,只需在 paconvert/api_alias_mapping.json 中增加如下该 API 的配置即可:
```bash
{
"torch.nn.modules.GroupNorm": "torch.nn.GroupNorm"
}
```

对于一个待开发API,首先依据步骤1的映射关系,确定其属于哪种分类情况。

对于以下映射关系的分类,都可以通过框架封装好的通用转换器:`GenericMatcher` 来处理:
Expand Down Expand Up @@ -348,7 +341,7 @@ paddle_default_kwargs :可选,当 paddle 参数更多 或者 参数默认值
}
```

如果不属于上述分类,则需要开发 **自定义的Matcher**,命名标准为:`API名+Matcher` 。例如 `torch.transpose` 可命名为 `TransposeMatcher`,`torch.Tensor.transpose` 可命名为 `TensorTransposeMatcher`。详见下面步骤3。
如果不属于上述分类,则需要开发 **自定义的Matcher**,命名标准为:`API名+Matcher` 。例如 `torch.transpose` 可命名为 `TransposeMatcher`,`torch.Tensor.transpose` 可命名为 `TensorTransposeMatcher`。详见下面步骤4。

## 步骤4:编写Matcher(转换规则)

Expand Down Expand Up @@ -571,6 +564,14 @@ x.reshape(2, 3)

3) API功能缺失。如果是整个API都缺失的,只需在API映射表中标注 **功能缺失** 即可,无需其他开发。如果是API局部功能缺失,则对功能缺失点,在代码中返回None表示不支持,同时在API映射表中说明此功能点 **Paddle暂无转写方式**,同时编写单测但可以注释掉不运行;对其他功能点正常开发即可。

4) 别名实现。如果一个API是别名API(alias API),例如 `torch.nn.modules.GroupNorm` 是 `torch.nn.GroupNorm` 的别名,那么就无需编写相关 Matcher,只需在 `paconvert/api_alias_mapping.json` 中增加该别名 API 的配置,同时也无需增加相应单测文件,只需在主API的单测文件中增加 `test_alias_case_1/test_alias_case_2...` 即可。

```bash
{
"torch.nn.modules.GroupNorm": "torch.nn.GroupNorm"
}
```

### 开发技巧

1)可以参考一些写的较为规范的Matcher:
Expand Down
2 changes: 2 additions & 0 deletions paconvert/api_alias_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"torch.nn.modules.conv.Conv2d": "torch.nn.Conv2d",
"torch.nn.modules.module.Module": "torch.nn.Module",
"torch.nn.parameter.Parameter": "torch.nn.Parameter",
"torch.utils.data._utils.collate.default_collate": "torch.utils.data.default_collate",
"torch.utils.data.dataloader.default_collate": "torch.utils.data.default_collate",
"torch.utils.data.sampler.BatchSampler": "torch.utils.data.BatchSampler",
"torch.utils.data.sampler.RandomSampler": "torch.utils.data.RandomSampler",
"torch.utils.data.sampler.Sampler": "torch.utils.data.Sampler",
Expand Down
87 changes: 84 additions & 3 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -5163,6 +5163,25 @@
"out"
]
},
"torch.multiprocessing.spawn": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distributed.spawn",
"args_list": [
"fn",
"args",
"nprocs",
"join",
"daemon",
"start_method"
],
"kwargs_change": {
"fn": "func",
"start_method": ""
},
"paddle_default_kwargs": {
"nprocs": "1"
}
},
"torch.mv": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.mv",
Expand Down Expand Up @@ -7897,7 +7916,10 @@
"defaults"
]
},
"torch.optim.Optimizer.add_param_group": {},
"torch.optim.Optimizer.add_param_group": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.optimizer.Optimizer._add_param_group"
},
"torch.optim.Optimizer.load_state_dict": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.optimizer.Optimizer.set_state_dict",
Expand Down Expand Up @@ -8279,6 +8301,24 @@
"pickle_protocol": "protocol"
}
},
"torch.scalar_tensor": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.to_tensor",
"args_list": [
"s",
"dtype",
"layout",
"device",
"requires_grad",
"pin_memory"
],
"kwargs_change": {
"s": "data"
},
"paddle_default_kwargs": {
"dtype": "paddle.float32"
}
},
"torch.seed": {
"Matcher": "SeedMatcher"
},
Expand Down Expand Up @@ -8736,6 +8776,22 @@
"dims": "axes"
}
},
"torch.testing.assert_allclose": {
"Matcher": "Assert_AllcloseMatcher",
"paddle_api": "paddle.allclose",
"args_list": [
"actual",
"expected",
"rtol",
"atol",
"equal_nan",
"msg"
],
"kwargs_change": {
"expected": "y",
"actual": "x"
}
},
"torch.tile": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.tile",
Expand Down Expand Up @@ -8963,10 +9019,28 @@
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.Dataset"
},
"torch.utils.data.DistributedSampler": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.DistributedBatchSampler",
"args_list": [
"dataset",
"num_replicas",
"rank",
"shuffle",
"seed",
"drop_last"
],
"kwargs_change": {
"seed": ""
},
"paddle_default_kwargs": {
"shuffle": "True",
"batch_size": "1"
}
},
"torch.utils.data.IterableDataset": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.IterableDataset",
"args_list": []
"paddle_api": "paddle.io.IterableDataset"
},
"torch.utils.data.RandomSampler": {
"Matcher": "RandomSamplerMatcher",
Expand Down Expand Up @@ -8997,6 +9071,13 @@
"Matcher": "TensorDatasetMatcher",
"paddle_api": "paddle.io.TensorDataset"
},
"torch.utils.data.default_collate": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.dataloader.collate.default_collate_fn",
"args_list": [
"batch"
]
},
"torch.utils.data.random_split": {
"Matcher": "RandomSplitMatcher",
"paddle_api": "paddle.io.random_split",
Expand Down
12 changes: 12 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,18 @@ def generate_code(self, kwargs):
return code


class Assert_AllcloseMatcher(BaseMatcher):
    """Emit an ``assert`` on ``paddle.allclose`` in place of an assert_allclose call."""

    def generate_code(self, kwargs):
        """Build the replacement ``assert`` statement as source text.

        Args:
            kwargs: Mapping of argument name to the text that is interpolated
                into the generated code. It must contain ``"actual"`` and
                ``"expected"``; ``"msg"`` is optional. The mapping is
                modified in place.

        Returns:
            A string of the form
            ``assert paddle.allclose(<args>).item(), <msg>``.

        Raises:
            KeyError: If ``"actual"`` or ``"expected"`` is missing (assuming
                ``kwargs`` is a plain dict).
        """
        # Remove both operands first, then re-add them under the new names,
        # so they end up after whatever other entries remain in ``kwargs``.
        first_operand = kwargs.pop("actual")
        second_operand = kwargs.pop("expected")
        kwargs["x"] = first_operand
        kwargs["y"] = second_operand

        # Without an explicit message, the literal text '' is used as the
        # assertion message expression.
        message = kwargs.pop("msg", "''")

        # NOTE(review): any argument still in ``kwargs`` is forwarded as-is;
        # whether the source API's defaults (e.g. for ``equal_nan``) match
        # ``paddle.allclose`` is not checked here — confirm against both docs.
        call_args = self.kwargs_to_str(kwargs)
        return f"assert paddle.allclose({call_args}).item(), {message}"


class Num2TensorBinaryMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" in kwargs:
Expand Down
3 changes: 2 additions & 1 deletion paconvert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def get_full_attr(self, node):
# 4. x[0].transpose(1, 0) -> 'torchClass'
# 5. (-x).transpose(1, 0) -> 'torchClass'
elif isinstance(
node, (ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript)
node,
(ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript, ast.Assert),
):
node_str = astor.to_source(node).strip("\n")
for item in self.black_list:
Expand Down
5 changes: 4 additions & 1 deletion paconvert/transformer/basic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def visit_Attribute(self, node):
# 4. (-x).transpose(1, 0)
# 5. x[0].transpose(1, 0)
if isinstance(
node.value, (ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript)
node.value,
(ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript, ast.Assert),
):
super(BasicTransformer, self).generic_visit(node)

Expand Down Expand Up @@ -320,6 +321,7 @@ def visit_Call(self, node):
ast.BinOp,
ast.UnaryOp,
ast.Tuple,
ast.Assert,
),
):
self.insert_multi_node(node_list[0:-1])
Expand Down Expand Up @@ -458,6 +460,7 @@ def trans_class_method(self, node, torch_api):
ast.Attribute,
ast.Subscript,
ast.BinOp,
ast.Assert,
ast.Tuple,
),
):
Expand Down
55 changes: 55 additions & 0 deletions tests/test_multiprocess_spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.multiprocessing.spawn")


# Minimal call: the snippet defines train() and hands it to
# torch.multiprocessing.spawn positionally, with no other arguments.
def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
def train():
return torch.tensor([1])
torch.multiprocessing.spawn(train)
"""
)
# Only the snippet is passed to run(); no variable names are supplied.
# NOTE(review): what run() checks in that case lives in APIBase — confirm.
obj.run(pytorch_code)


# All-keyword call: every argument (fn, args, nprocs, join, daemon,
# start_method) is passed by name.
def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
def train():
return torch.tensor([1])
torch.multiprocessing.spawn(fn=train, args=(True,), nprocs=2, join=True, daemon=False, start_method='spawn')
"""
)
# Only the snippet is passed to run(); no variable names are supplied.
# NOTE(review): what run() checks in that case lives in APIBase — confirm.
obj.run(pytorch_code)


# Mixed call: the function is passed positionally and the remaining
# arguments (args, nprocs, join, daemon, start_method) by keyword.
def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
def train():
return torch.tensor([1])
torch.multiprocessing.spawn(train, args=(True,), nprocs=2, join=True, daemon=False, start_method='spawn')
"""
)
# Only the snippet is passed to run(); no variable names are supplied.
# NOTE(review): what run() checks in that case lives in APIBase — confirm.
obj.run(pytorch_code)
79 changes: 79 additions & 0 deletions tests/test_scalar_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.scalar_tensor")


# Positional scalar (a bool) with an explicit integer dtype keyword.
def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(False, dtype=torch.int32)
"""
)
# ["result"] names the variable assigned in the snippet.
# NOTE(review): presumably run() compares that variable across frameworks — confirm in APIBase.
obj.run(pytorch_code, ["result"])


# Positional scalar (an int) with no dtype given, so the default dtype applies.
def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(1)
"""
)
# ["result"] names the variable assigned in the snippet.
# NOTE(review): presumably run() compares that variable across frameworks — confirm in APIBase.
obj.run(pytorch_code, ["result"])


# Scalar passed by keyword (s=False) with no dtype given.
def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(s=False)
"""
)
# ["result"] names the variable assigned in the snippet.
# NOTE(review): presumably run() compares that variable across frameworks — confirm in APIBase.
obj.run(pytorch_code, ["result"])


# Keyword scalar (an int) combined with an explicit bool dtype.
def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(s=1, dtype=torch.bool)
"""
)
# ["result"] names the variable assigned in the snippet.
# NOTE(review): presumably run() compares that variable across frameworks — confirm in APIBase.
obj.run(pytorch_code, ["result"])


# Same as the bool-dtype keyword case, with pin_memory=False added.
def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(s=1, dtype=torch.bool, pin_memory=False)
"""
)
# ["result"] names the variable assigned in the snippet.
# NOTE(review): presumably run() compares that variable across frameworks — confirm in APIBase.
obj.run(pytorch_code, ["result"])


# Float32 dtype with both pin_memory=False and requires_grad=True supplied.
def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(s=1, dtype=torch.float32, pin_memory=False, requires_grad=True)
"""
)
# ["result"] names the variable assigned in the snippet.
# NOTE(review): presumably run() compares that variable across frameworks — confirm in APIBase.
obj.run(pytorch_code, ["result"])
Loading

0 comments on commit 2bca96b

Please sign in to comment.