Skip to content

Commit

Permalink
add scalar_tensor etc. (#136)
Browse files Browse the repository at this point in the history
* add torch.is_nonzero .etc

* add torch.is_nonzero .etc

* add torch.is_nonzero .etc

* add torch.is_nonzero .etc

* add is_nonzero .etc

* add pytest is_nonzero etc.

* add pytest is_nonzero etc.

* add is_nonzero etc.

* pull

* add xlogy etc.

* add xlogy etc.

* add tensor logaddexp2 .etc

* add vdot etc.

* fix bug

* fix bug

* add chain_matmul etc.

* del sys

* change chain_matmul

* add sinc etc.

* add sinc etc.

* add sinc etc.

* add sinc etc.

* modify tensor strategy

* add cov etc.

* add cov etc.

* add cov etc.

* test

* add cov etc.

* add cov etc.

* develop

* add cov etc.

* add cov etc.

* develop

* add cov etc.

* add cov etc.

* develop

* develop

* add cov etc.

* develop

* add cov etc.

* add cov etc.

* test

* test

* add unique etc.

* add unique etc.

* add unique etc.

* add regiser etc.

* test

* test

* add backward etc.

* add backward etc.

* modify get_cuda_rng_state to get_rng_state

* add scalar_tensor etc.

* add scalar_tensor etc.

* modify default_collate alias

* modify default_collate alias

---------

Co-authored-by: zengpengcheng01 <[email protected]>
  • Loading branch information
zpceng314 and zengpengcheng01 authored Jul 3, 2023
1 parent 0407751 commit b7a6194
Show file tree
Hide file tree
Showing 11 changed files with 648 additions and 3 deletions.
2 changes: 2 additions & 0 deletions paconvert/api_alias_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"torch.nn.modules.conv.Conv2d": "torch.nn.Conv2d",
"torch.nn.modules.module.Module": "torch.nn.Module",
"torch.nn.parameter.Parameter": "torch.nn.Parameter",
"torch.utils.data._utils.collate.default_collate": "torch.utils.data.default_collate",
"torch.utils.data.dataloader.default_collate": "torch.utils.data.default_collate",
"torch.utils.data.sampler.BatchSampler": "torch.utils.data.BatchSampler",
"torch.utils.data.sampler.RandomSampler": "torch.utils.data.RandomSampler",
"torch.utils.data.sampler.Sampler": "torch.utils.data.Sampler",
Expand Down
88 changes: 87 additions & 1 deletion paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -5163,6 +5163,25 @@
"out"
]
},
"torch.multiprocessing.spawn": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distributed.spawn",
"args_list": [
"fn",
"args",
"nprocs",
"join",
"daemon",
"start_method"
],
"kwargs_change": {
"fn": "func",
"start_method": ""
},
"paddle_default_kwargs": {
"nprocs": "1"
}
},
"torch.mv": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.mv",
Expand Down Expand Up @@ -7897,7 +7916,10 @@
"defaults"
]
},
"torch.optim.Optimizer.add_param_group": {},
"torch.optim.Optimizer.add_param_group": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.optimizer.Optimizer._add_param_group"
},
"torch.optim.Optimizer.load_state_dict": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.optimizer.Optimizer.set_state_dict",
Expand Down Expand Up @@ -8279,6 +8301,24 @@
"pickle_protocol": "protocol"
}
},
"torch.scalar_tensor": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.to_tensor",
"args_list": [
"s",
"dtype",
"layout",
"device",
"requires_grad",
"pin_memory"
],
"kwargs_change": {
"s": "data"
},
"paddle_default_kwargs": {
"dtype": "paddle.float32"
}
},
"torch.seed": {
"Matcher": "SeedMatcher"
},
Expand Down Expand Up @@ -8736,6 +8776,22 @@
"dims": "axes"
}
},
"torch.testing.assert_allclose": {
"Matcher": "Assert_AllcloseMatcher",
"paddle_api": "paddle.allclose",
"args_list": [
"actual",
"expected",
"rtol",
"atol",
"equal_nan",
"msg"
],
"kwargs_change": {
"expected": "y",
"acltual": "x"
}
},
"torch.tile": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.tile",
Expand Down Expand Up @@ -8956,6 +9012,29 @@
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.Dataset"
},
"torch.utils.data.DistributedSampler": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.DistributedBatchSampler",
"args_list": [
"dataset",
"num_replicas",
"rank",
"shuffle",
"seed",
"drop_last"
],
"kwargs_change": {
"seed": ""
},
"paddle_default_kwargs": {
"shuffle": "True",
"batch_size": "1"
}
},
"torch.utils.data.IterableDataset": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.IterableDataset"
},
"torch.utils.data.RandomSampler": {
"Matcher": "RandomSamplerMatcher",
"paddle_api": "paddle.io.RandomSampler",
Expand All @@ -8973,6 +9052,13 @@
"data_source"
]
},
"torch.utils.data.default_collate": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.dataloader.collate.default_collate_fn",
"args_list": [
"batch"
]
},
"torch.utils.data.random_split": {
"Matcher": "RandomSplitMatcher",
"paddle_api": "paddle.io.random_split",
Expand Down
12 changes: 12 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,18 @@ def generate_code(self, kwargs):
return code


class Assert_AllcloseMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs["x"], kwargs["y"] = kwargs.pop("actual"), kwargs.pop("expected")
msg = "''"
if "msg" in kwargs:
msg = kwargs.pop("msg")
code = "assert paddle.allclose({}).item(), {}".format(
self.kwargs_to_str(kwargs), msg
)
return code


class Num2TensorBinaryMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" in kwargs:
Expand Down
3 changes: 2 additions & 1 deletion paconvert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def get_full_attr(self, node):
# 4. x[0].transpose(1, 0) -> 'torchClass'
# 5. (-x).transpose(1, 0) -> 'torchClass'
elif isinstance(
node, (ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript)
node,
(ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript, ast.Assert),
):
node_str = astor.to_source(node).strip("\n")
for item in self.black_list:
Expand Down
5 changes: 4 additions & 1 deletion paconvert/transformer/basic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def visit_Attribute(self, node):
# 4. (-x).transpose(1, 0)
# 5. x[0].transpose(1, 0)
if isinstance(
node.value, (ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript)
node.value,
(ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript, ast.Assert),
):
super(BasicTransformer, self).generic_visit(node)

Expand Down Expand Up @@ -320,6 +321,7 @@ def visit_Call(self, node):
ast.BinOp,
ast.UnaryOp,
ast.Tuple,
ast.Assert,
),
):
self.insert_multi_node(node_list[0:-1])
Expand Down Expand Up @@ -458,6 +460,7 @@ def trans_class_method(self, node, torch_api):
ast.Attribute,
ast.Subscript,
ast.BinOp,
ast.Assert,
ast.Tuple,
),
):
Expand Down
55 changes: 55 additions & 0 deletions tests/test_multiprocess_spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.multiprocessing.spawn")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
def train():
return torch.tensor([1])
torch.multiprocessing.spawn(train)
"""
)
obj.run(pytorch_code)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
def train():
return torch.tensor([1])
torch.multiprocessing.spawn(fn=train, args=(True,), nprocs=2, join=True, daemon=False, start_method='spawn')
"""
)
obj.run(pytorch_code)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
def train():
return torch.tensor([1])
torch.multiprocessing.spawn(train, args=(True,), nprocs=2, join=True, daemon=False, start_method='spawn')
"""
)
obj.run(pytorch_code)
79 changes: 79 additions & 0 deletions tests/test_scalar_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.scalar_tensor")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(False, dtype=torch.int32)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(1)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(s=False)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(s=1, dtype=torch.bool)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(s=1, dtype=torch.bool, pin_memory=False)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
result = y = torch.scalar_tensor(s=1, dtype=torch.float32, pin_memory=False, requires_grad=True)
"""
)
obj.run(pytorch_code, ["result"])
53 changes: 53 additions & 0 deletions tests/test_testing_assert_allclose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.histc")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
torch.testing.assert_allclose(x, x)
"""
)
obj.run(pytorch_code)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
y = x + 1
torch.testing.assert_allclose(actual=x, expected=y)
"""
)
obj.run(pytorch_code)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., float('nan')])
y = x
torch.testing.assert_allclose(x, y, equal_nan=True)
"""
)
obj.run(pytorch_code)
Loading

0 comments on commit b7a6194

Please sign in to comment.