From 8a1f48c5c86e5dec18e5436b51c3b49ebb06e61d Mon Sep 17 00:00:00 2001 From: ljgray Date: Wed, 1 Mar 2023 14:05:10 -0800 Subject: [PATCH] test(test_memh5): add test for memh5 copying and equality --- caput/tests/test_memh5.py | 29 +++++++++++++++++++++++++++++ caput/tests/test_memh5_parallel.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/caput/tests/test_memh5.py b/caput/tests/test_memh5.py index 72ad6dbf..c09bb06c 100644 --- a/caput/tests/test_memh5.py +++ b/caput/tests/test_memh5.py @@ -12,6 +12,7 @@ import pytest from pytest_lazyfixture import lazy_fixture import zarr +import copy from caput import memh5, fileformats @@ -28,6 +29,34 @@ def test_ro_dict(): a["b"] = 6 +# Unit test for MemDataset + + +def test_dataset_copy(): + # Check for string types + x = memh5.MemDatasetCommon(shape=(4, 5), dtype=np.float32) + x[:] = 0 + + # Check a deepcopy using .copy + y = x.copy() + assert x == y + y[:] = 1 + # Check this this is in fact a deep copy + assert x != y + + # This is a shallow copy + y = x.copy(shallow=True) + assert x == y + y[:] = 1 + assert x == y + + # Check a deepcopy using copy.deepcopy + y = copy.deepcopy(x) + assert x == y + y[:] = 2 + assert x != y + + # Unit tests for MemGroup. diff --git a/caput/tests/test_memh5_parallel.py b/caput/tests/test_memh5_parallel.py index 22bb61e0..5d1667a3 100644 --- a/caput/tests/test_memh5_parallel.py +++ b/caput/tests/test_memh5_parallel.py @@ -5,6 +5,7 @@ import numpy as np import h5py import zarr +import copy from caput import fileformats, memh5, mpiarray, mpiutil @@ -228,3 +229,31 @@ def test_redistribute(): assert g["data"].distributed_axis == 0 g.redistribute(1) assert g["data"].distributed_axis == 1 + + +# Unit test for MemDataset + + +def test_dataset_copy(): + # Check for string types + x = memh5.MemDatasetDistributed(shape=(4, 5), dtype=np.float32) + x[:] = 0 + + # Check a deepcopy using .copy + y = x.copy() + assert x == y + y[:] = 1 + # Check this this is in fact a deep copy + assert x != y + + # This is a shallow copy + y = x.copy(shallow=True) + assert x == y + y[:] = 1 + assert x == y + + # Check a deepcopy using copy.deepcopy + y = copy.deepcopy(x) + assert x == y + y[:] = 2 + assert x != y