Skip to content

Commit

Permalink
Merge branch 'master' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
txyugood authored Jul 2, 2023
2 parents 233b489 + 389b839 commit 2395ee0
Show file tree
Hide file tree
Showing 42 changed files with 1,887 additions and 15 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,13 @@ unsupport_args:可选,Paddle API不支持的参数功能,通过该字段配
paddle_default_kwargs :可选,当 paddle 参数更多 或者 参数默认值不一致 时,可以通过该配置,设置参数默认值。
```

**需要注意的是**,如果一个API是别名API(alias API), 比如 `torch.nn.modules.GroupNorm``torch.nn.GroupNorm` 是同一个API,只是按照模块路径采用了不同的调用方式,那么就无需编写相关 Matcher,只需在 paconvert/api_alias_mapping.json 中增加如下该 API 的配置即可:
```bash
{
"torch.nn.modules.GroupNorm": "torch.nn.GroupNorm"
}
```

对于一个待开发API,首先依据步骤1的映射关系,确定其属于哪种分类情况。

对于以下映射关系的分类,都可以通过框架封装好的通用转换器:`GenericMatcher` 来处理:
Expand Down
32 changes: 22 additions & 10 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@
]
},
"torch.Tensor.coalesce": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.coalesce"
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.sparse.coalesce"
},
"torch.Tensor.conj": {
"Matcher": "TensorUnchangeMatcher"
Expand Down Expand Up @@ -632,8 +632,11 @@
]
},
"torch.Tensor.diagflat": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.diagflat"
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.diagflat",
"args_list": [
"offset"
]
},
"torch.Tensor.diagonal": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -2002,8 +2005,11 @@
},
"torch.Tensor.slice_scatter": {},
"torch.Tensor.slogdet": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.slogdet"
"Matcher": "TensorSLogDetMatcher",
"paddle_api": "paddle.linalg.slogdet",
"args_list": [
"out"
]
},
"torch.Tensor.smm": {},
"torch.Tensor.softmax": {
Expand Down Expand Up @@ -2210,13 +2216,19 @@
"torch.Tensor.transpose_": {},
"torch.Tensor.triangular_solve": {},
"torch.Tensor.tril": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.tril"
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.tril",
"args_list": [
"diagonal"
]
},
"torch.Tensor.tril_": {},
"torch.Tensor.triu": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.triu"
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.triu",
"args_list": [
"diagonal"
]
},
"torch.Tensor.triu_": {},
"torch.Tensor.true_divide": {
Expand Down
24 changes: 24 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3642,3 +3642,27 @@ def get_paddle_nodes(self, args, kwargs):
code = "{}({})".format(self.get_paddle_api(), code)
node = ast.parse(code.strip("\n")).body
return node


class TensorSLogDetMatcher(BaseMatcher):
def generate_code(self, kwargs):
out_v = kwargs.pop("out") if "out" in kwargs else None

if out_v:
API_TEMPLATE = textwrap.dedent(
"""
res = paddle.linalg.slogdet({})
paddle.assign(res[0], {}[0]), paddle.assign(res[1], {}[1])
"""
)
code = API_TEMPLATE.format(self.paddleClass, out_v, out_v)
else:
API_TEMPLATE = textwrap.dedent(
"""
res = paddle.linalg.slogdet({})
res[0], res[1]
"""
)
code = API_TEMPLATE.format(self.paddleClass)

return code
1 change: 1 addition & 0 deletions paconvert/transformer/basic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def trans_class_method(self, node, torch_api):
ast.Attribute,
ast.Subscript,
ast.BinOp,
ast.Tuple,
),
):
self.insert_multi_node(node_list[0:-1])
Expand Down
169 changes: 169 additions & 0 deletions tests/test_NLLLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.NLLLoss")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss()
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight)
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight, size_average=True)
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight, size_average=False)
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight, size_average=True, ignore_index=1)
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight, reduce=True)
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight, reduce=False)
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight, reduction='none')
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight, reduction='mean')
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_10():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
m = torch.tensor([[0.2, 0.1, 0.7], [0.2, 0.1, 0.7]])
weight = torch.tensor([3.,1.,1.])
target = torch.tensor([1,2])
loss = nn.NLLLoss(weight=weight, reduction='sum')
result = loss(m, target)
"""
)
obj.run(pytorch_code, ["result"])
52 changes: 52 additions & 0 deletions tests/test_Tensor_bernoulli_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.bernoulli_")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
src = torch.tensor([1., 2., 3., 4., 5., 6.])
result = src.bernoulli_(0.5)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
src = torch.tensor([1., 2., 3., 4., 5., 6.])
result = src.bernoulli_(0.5, generator=None)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
src = torch.tensor([1., 2., 3., 4., 5., 6.])
result = src.bernoulli_(p=0.5, generator=None)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)
41 changes: 41 additions & 0 deletions tests/test_Tensor_bfloat16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.bfloat16")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
src = torch.tensor([1., 2., 3., 4., 5., 6.])
result = src.bfloat16().float()
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
src = torch.tensor([1., 2., 3., 4., 5., 6.])
result = src.bfloat16(memory_format=torch.preserve_format).float()
"""
)
obj.run(pytorch_code, ["result"])
Loading

0 comments on commit 2395ee0

Please sign in to comment.