diff --git a/README.md b/README.md index f6040ab6b..551ca1016 100644 --- a/README.md +++ b/README.md @@ -312,13 +312,6 @@ unsupport_args:可选,Paddle API不支持的参数功能,通过该字段配 paddle_default_kwargs :可选,当 paddle 参数更多 或者 参数默认值不一致 时,可以通过该配置,设置参数默认值。 ``` -**需要注意的是**,如果一个API是别名API(alias API), 比如 `torch.nn.modules.GroupNorm` 和 `torch.nn.GroupNorm` 是同一个API,只是按照模块路径采用了不同的调用方式,那么就无需编写相关 Matcher,只需在 paconvert/api_alias_mapping.json 中增加如下该 API 的配置即可: -```bash -{ - "torch.nn.modules.GroupNorm": "torch.nn.GroupNorm" -} -``` - 对于一个待开发API,首先依据步骤1的映射关系,确定其属于哪种分类情况。 对于以下映射关系的分类,都可以通过框架封装好的通用转换器:`GenericMatcher` 来处理: @@ -348,7 +341,7 @@ paddle_default_kwargs :可选,当 paddle 参数更多 或者 参数默认值 } ``` -如果不属于上述分类,则需要开发 **自定义的Matcher**,命名标准为:`API名+Matcher` 。例如 `torch.transpose` 可命名为`TransposeMatcher` ,`torch.Tensor.transpose` 可命名为 `TensorTransposeMatcher`。详见下面步骤3。 +如果不属于上述分类,则需要开发 **自定义的Matcher**,命名标准为:`API名+Matcher` 。例如 `torch.transpose` 可命名为`TransposeMatcher` ,`torch.Tensor.transpose` 可命名为 `TensorTransposeMatcher`。详见下面步骤。 ## 步骤4:编写Matcher(转换规则) @@ -571,6 +564,14 @@ x.reshape(2, 3) 3) API功能缺失。如果是整个API都缺失的,只需在API映射表中标注 **功能缺失** 即可,无需其他开发。如果是API局部功能缺失,则对功能缺失点,在代码中返回None表示不支持,同时在API映射表中说明此功能点 **Paddle暂无转写方式**,同时编写单测但可以注释掉不运行;对其他功能点正常开发即可。 +4) 别名实现。如果一个API是别名API(alias API),例如 `torch.nn.modules.GroupNorm` 是 `torch.nn.GroupNorm` 的别名,那么就无需编写相关 Matcher,只需在 `paconvert/api_alias_mapping.json` 中增加该别名 API 的配置,同时也无需增加相应单测文件,只需在主API的单测文件中增加 `test_alias_case_1/test_alias_case_2...` 即可。 + + ```bash + { + "torch.nn.modules.GroupNorm": "torch.nn.GroupNorm" + } + ``` + ### 开发技巧 1)可以参考一些写的较为规范的Matcher: diff --git a/paconvert/api_alias_mapping.json b/paconvert/api_alias_mapping.json index 819aea3ce..199106dc4 100644 --- a/paconvert/api_alias_mapping.json +++ b/paconvert/api_alias_mapping.json @@ -4,6 +4,8 @@ "torch.nn.modules.conv.Conv2d": "torch.nn.Conv2d", "torch.nn.modules.module.Module": "torch.nn.Module", "torch.nn.parameter.Parameter": "torch.nn.Parameter", + "torch.utils.data._utils.collate.default_collate": "torch.utils.data.default_collate", + "torch.utils.data.dataloader.default_collate": "torch.utils.data.default_collate", "torch.utils.data.sampler.BatchSampler": "torch.utils.data.BatchSampler", "torch.utils.data.sampler.RandomSampler": "torch.utils.data.RandomSampler", "torch.utils.data.sampler.Sampler": "torch.utils.data.Sampler", diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 70444a0a5..6306a357a 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -5163,6 +5163,25 @@ "out" ] }, + "torch.multiprocessing.spawn": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distributed.spawn", + "args_list": [ + "fn", + "args", + "nprocs", + "join", + "daemon", + "start_method" + ], + "kwargs_change": { + "fn": "func", + "start_method": "" + }, + "paddle_default_kwargs": { + "nprocs": "1" + } + }, "torch.mv": { "Matcher": "GenericMatcher", "paddle_api": "paddle.mv", @@ -7897,7 +7916,10 @@ "defaults" ] }, - "torch.optim.Optimizer.add_param_group": {}, + "torch.optim.Optimizer.add_param_group": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.optimizer.Optimizer._add_param_group" + }, "torch.optim.Optimizer.load_state_dict": { "Matcher": "GenericMatcher", "paddle_api": "paddle.optimizer.Optimizer.set_state_dict", @@ -8279,6 +8301,24 @@ "pickle_protocol": "protocol" } }, + "torch.scalar_tensor": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.to_tensor", + "args_list": [ + "s", + "dtype", + "layout", + "device", + "requires_grad", + "pin_memory" + ], + "kwargs_change": { + "s": "data" + }, + "paddle_default_kwargs": { + "dtype": "paddle.float32" + } + }, "torch.seed": { "Matcher": "SeedMatcher" }, @@ -8736,6 +8776,22 @@ "dims": "axes" } }, + "torch.testing.assert_allclose": { + "Matcher": "Assert_AllcloseMatcher", + "paddle_api": "paddle.allclose", + "args_list": [ + "actual", + "expected", + "rtol", + "atol", + "equal_nan", + "msg" + ], + "kwargs_change": { + "expected": "y", + "acltual": "x" + } + }, "torch.tile": { "Matcher": "GenericMatcher", "paddle_api": "paddle.tile", @@ -8963,10 +9019,28 @@ "Matcher": "GenericMatcher", "paddle_api": "paddle.io.Dataset" }, + "torch.utils.data.DistributedSampler": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.io.DistributedBatchSampler", + "args_list": [ + "dataset", + "num_replicas", + "rank", + "shuffle", + "seed", + "drop_last" + ], + "kwargs_change": { + "seed": "" + }, + "paddle_default_kwargs": { + "shuffle": "True", + "batch_size": "1" + } + }, "torch.utils.data.IterableDataset": { "Matcher": "GenericMatcher", - "paddle_api": "paddle.io.IterableDataset", - "args_list": [] + "paddle_api": "paddle.io.IterableDataset" }, "torch.utils.data.RandomSampler": { "Matcher": "RandomSamplerMatcher", @@ -8997,6 +9071,13 @@ "Matcher": "TensorDatasetMatcher", "paddle_api": "paddle.io.TensorDataset" }, + "torch.utils.data.default_collate": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.io.dataloader.collate.default_collate_fn", + "args_list": [ + "batch" + ] + }, "torch.utils.data.random_split": { "Matcher": "RandomSplitMatcher", "paddle_api": "paddle.io.random_split", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 7826b9b00..60c22c933 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -2699,6 +2699,18 @@ def generate_code(self, kwargs): return code +class Assert_AllcloseMatcher(BaseMatcher): + def generate_code(self, kwargs): + kwargs["x"], kwargs["y"] = kwargs.pop("actual"), kwargs.pop("expected") + msg = "''" + if "msg" in kwargs: + msg = kwargs.pop("msg") + code = "assert paddle.allclose({}).item(), {}".format( + self.kwargs_to_str(kwargs), msg + ) + return code + + class Num2TensorBinaryMatcher(BaseMatcher): def generate_code(self, kwargs): if "input" in kwargs: diff --git a/paconvert/base.py b/paconvert/base.py index 28cc2bad3..45de8b91c 100644 --- a/paconvert/base.py +++ b/paconvert/base.py @@ -183,7 +183,8 @@ def get_full_attr(self, node): # 4. x[0].transpose(1, 0) -> 'torchClass' # 5. (-x).transpose(1, 0) -> 'torchClass' elif isinstance( - node, (ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript) + node, + (ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript, ast.Assert), ): node_str = astor.to_source(node).strip("\n") for item in self.black_list: diff --git a/paconvert/transformer/basic_transformer.py b/paconvert/transformer/basic_transformer.py index 6182ac02b..68063af3d 100644 --- a/paconvert/transformer/basic_transformer.py +++ b/paconvert/transformer/basic_transformer.py @@ -68,7 +68,8 @@ def visit_Attribute(self, node): # 4. (-x).transpose(1, 0) # 5. x[0].transpose(1, 0) if isinstance( - node.value, (ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript) + node.value, + (ast.Call, ast.Compare, ast.BinOp, ast.UnaryOp, ast.Subscript, ast.Assert), ): super(BasicTransformer, self).generic_visit(node) @@ -320,6 +321,7 @@ def visit_Call(self, node): ast.BinOp, ast.UnaryOp, ast.Tuple, + ast.Assert, ), ): self.insert_multi_node(node_list[0:-1]) @@ -458,6 +460,7 @@ def trans_class_method(self, node, torch_api): ast.Attribute, ast.Subscript, ast.BinOp, + ast.Assert, ast.Tuple, ), ): diff --git a/tests/test_multiprocess_spawn.py b/tests/test_multiprocess_spawn.py new file mode 100644 index 000000000..b8b0beaef --- /dev/null +++ b/tests/test_multiprocess_spawn.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.multiprocessing.spawn") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + def train(): + return torch.tensor([1]) + torch.multiprocessing.spawn(train) + """ + ) + obj.run(pytorch_code) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + def train(): + return torch.tensor([1]) + torch.multiprocessing.spawn(fn=train, args=(True,), nprocs=2, join=True, daemon=False, start_method='spawn') + """ + ) + obj.run(pytorch_code) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + def train(): + return torch.tensor([1]) + torch.multiprocessing.spawn(train, args=(True,), nprocs=2, join=True, daemon=False, start_method='spawn') + """ + ) + obj.run(pytorch_code) diff --git a/tests/test_scalar_tensor.py b/tests/test_scalar_tensor.py new file mode 100644 index 000000000..27fef26f1 --- /dev/null +++ b/tests/test_scalar_tensor.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.scalar_tensor") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = y = torch.scalar_tensor(False, dtype=torch.int32) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = y = torch.scalar_tensor(1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = y = torch.scalar_tensor(s=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = y = torch.scalar_tensor(s=1, dtype=torch.bool) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = y = torch.scalar_tensor(s=1, dtype=torch.bool, pin_memory=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + result = y = torch.scalar_tensor(s=1, dtype=torch.float32, pin_memory=False, requires_grad=True) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_testing_assert_allclose.py b/tests/test_testing_assert_allclose.py new file mode 100644 index 000000000..688f41198 --- /dev/null +++ b/tests/test_testing_assert_allclose.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.histc") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2., 3.]) + torch.testing.assert_allclose(x, x) + """ + ) + obj.run(pytorch_code) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2., 3.]) + y = x + 1 + torch.testing.assert_allclose(actual=x, expected=y) + """ + ) + obj.run(pytorch_code) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2., float('nan')]) + y = x + torch.testing.assert_allclose(x, y, equal_nan=True) + """ + ) + obj.run(pytorch_code) diff --git a/tests/test_utils_data_IterableDataset.py b/tests/test_utils_data_IterableDataset.py new file mode 100644 index 000000000..8f4aa4e92 --- /dev/null +++ b/tests/test_utils_data_IterableDataset.py @@ -0,0 +1,65 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.utils.data.IterableDataset") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data import IterableDataset + + class MyIterableDataset(IterableDataset): + def __init__(self, start, end): + super(MyIterableDataset).__init__() + assert end > start, "this example code only works with end >= start" + self.start = start + self.end = end + + def __iter__(self): + return iter(range(self.start, self.end)) + + ds = MyIterableDataset(start=3, end=7) + result = [] + for i in ds: + result.append(i) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data import IterableDataset + + class MyIterableDataset(IterableDataset): + def __init__(self, start, end): + super(MyIterableDataset).__init__() + assert end > start, "this example code only works with end >= start" + self.start = start + self.end = end + + def __iter__(self): + return iter(range(self.start, self.end)) + + ds = MyIterableDataset(start=3, end=7) + result = next(ds.__iter__()) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_utils_data_default_collate.py b/tests/test_utils_data_default_collate.py new file mode 100644 index 000000000..2b0618e65 --- /dev/null +++ b/tests/test_utils_data_default_collate.py @@ -0,0 +1,247 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.utils.data.default_collate") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data import default_collate + result = torch.tensor(default_collate([0, 1, 2, 3])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data import default_collate + result = default_collate(['a', 'b', 'c']) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data import default_collate + result = default_collate([torch.tensor([0, 1, 2, 3])]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data import default_collate + result = default_collate((torch.tensor([1, 3, 3]), torch.tensor([3, 1, 1]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data import default_collate + result = default_collate(batch=(torch.tensor([1, 3, 3]), torch.tensor([3, 1, 1]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Pytorch returns Tensor by default, while paddle returns narray. +def _test_case_6(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data import default_collate + result = default_collate(batch=default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Pytorch returns Tensor by default, while paddle returns narray. +def _test_case_7(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data import default_collate + result = default_collate([0, 1, 2, 3]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data.dataloader import default_collate + result = torch.tensor(default_collate([0, 1, 2, 3])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_2(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data.dataloader import default_collate + result = default_collate(['a', 'b', 'c']) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data.dataloader import default_collate + result = default_collate([torch.tensor([0, 1, 2, 3])]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data.dataloader import default_collate + result = default_collate((torch.tensor([1, 3, 3]), torch.tensor([3, 1, 1]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data.dataloader import default_collate + result = default_collate(batch=(torch.tensor([1, 3, 3]), torch.tensor([3, 1, 1]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Pytorch returns Tensor by default, while paddle returns narray. +def _test_alias_case_6(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data.dataloader import default_collate + result = default_collate(batch=default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Pytorch returns Tensor by default, while paddle returns narray. +def _test_alias_case_7(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data.dataloader import default_collate + result = default_collate([0, 1, 2, 3]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data._utils.collate import default_collate + result = torch.tensor(default_collate([0, 1, 2, 3])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_9(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data._utils.collate import default_collate + result = default_collate(['a', 'b', 'c']) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data._utils.collate import default_collate + result = default_collate([torch.tensor([0, 1, 2, 3])]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_11(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data._utils.collate import default_collate + result = default_collate((torch.tensor([1, 3, 3]), torch.tensor([3, 1, 1]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_alias_case_12(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch.utils.data._utils.collate import default_collate + result = default_collate(batch=(torch.tensor([1, 3, 3]), torch.tensor([3, 1, 1]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Pytorch returns Tensor by default, while paddle returns narray. +def _test_alias_case_13(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data._utils.collate import default_collate + result = default_collate(batch=default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]))) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Pytorch returns Tensor by default, while paddle returns narray. +def _test_alias_case_14(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data._utils.collate import default_collate + result = default_collate([0, 1, 2, 3]) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_utils_date_DistributedSampler.py b/tests/test_utils_date_DistributedSampler.py new file mode 100644 index 000000000..4a9a043bd --- /dev/null +++ b/tests/test_utils_date_DistributedSampler.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.utils.data.DistributedSampler") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + from torch.utils.data import Dataset, DistributedSampler + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + dataset = RandomDataset(100) + dataset = DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False) + """ + ) + obj.run(pytorch_code)