From 5807f5df1de11387f3e7f0fecbdb1cc53a0a2a36 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 10 Jan 2022 14:38:59 +0000 Subject: [PATCH] 3628 0-d array to_contiguous (#3629) * 0-d array Signed-off-by: Wenqi Li * handling string types Signed-off-by: Wenqi Li --- monai/transforms/utils.py | 8 ++-- .../utils_pytorch_numpy_unification.py | 6 ++- tests/test_to_contiguous.py | 44 +++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 tests/test_to_contiguous.py diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index f81099c48a..fd795f6ca9 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -14,7 +14,7 @@ import warnings from contextlib import contextmanager from inspect import getmembers, isclass -from typing import Any, Callable, Hashable, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Hashable, Iterable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -1508,11 +1508,11 @@ def convert_to_contiguous(data, **kwargs): https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous. """ - if isinstance(data, (np.ndarray, torch.Tensor)): + if isinstance(data, (np.ndarray, torch.Tensor, str, bytes)): return ascontiguousarray(data, **kwargs) - if isinstance(data, dict): + if isinstance(data, Mapping): return {k: convert_to_contiguous(v, **kwargs) for k, v in data.items()} - if isinstance(data, (list, tuple)): + if isinstance(data, Sequence): return [convert_to_contiguous(i, **kwargs) for i in data] return data diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 93066a132e..5085ded2b5 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -334,5 +334,9 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs): """ if isinstance(x, np.ndarray): + if x.ndim == 0: + return x return np.ascontiguousarray(x) - return x.contiguous(**kwargs) + if isinstance(x, torch.Tensor): + return x.contiguous(**kwargs) + return x diff --git a/tests/test_to_contiguous.py b/tests/test_to_contiguous.py new file mode 100644 index 0000000000..a9c2a78278 --- /dev/null +++ b/tests/test_to_contiguous.py @@ -0,0 +1,44 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.transforms import convert_to_contiguous +from tests.utils import assert_allclose + + +class TestToContiguous(unittest.TestCase): + def test_decollation_dict(self): + tochange = np.moveaxis(np.zeros((2, 3, 4)), 0, -1) + test_dict = {"test_key": [[1]], 0: np.array(0), 1: np.array([0]), "nested": {"nested": [tochange]}} + output = convert_to_contiguous(test_dict) + self.assertEqual(output["test_key"], [[1]]) + assert_allclose(output[0], np.array(0)) + assert_allclose(output[1], np.array([0])) + self.assertTrue(output["nested"]["nested"][0].flags.c_contiguous) + + def test_decollation_seq(self): + tochange = torch.zeros(2, 3, 4).transpose(0, 1) + test_dict = [[[1]], np.array(0), np.array([0]), torch.tensor(1.0), [[tochange]], "test_string"] + output = convert_to_contiguous(test_dict) + self.assertEqual(output[0], [[1]]) + assert_allclose(output[1], np.array(0)) + assert_allclose(output[2], np.array([0])) + assert_allclose(output[3], torch.tensor(1.0)) + self.assertTrue(output[4][0][0].is_contiguous()) + self.assertEqual(output[5], "test_string") + + +if __name__ == "__main__": + unittest.main()