From 25e89167e4d0777c555cbcdbf3370657b56d9bbf Mon Sep 17 00:00:00 2001
From: ljgray
Date: Mon, 20 Mar 2023 14:49:52 -0700
Subject: [PATCH] fix(memh5): fix a bug where shared dataset names weren't
 correct

---
 caput/memh5.py | 34 ++++++++++++++++++++++++++++------
 1 file changed, 28 insertions(+), 6 deletions(-)

diff --git a/caput/memh5.py b/caput/memh5.py
index c15d37b7..d3ecf2c6 100644
--- a/caput/memh5.py
+++ b/caput/memh5.py
@@ -2524,7 +2524,12 @@ def deep_group_copy(
 ):
     """Copy full data tree from one group to another.
 
-    Copies from g1 to g2. An axis downselection can be specified by supplying the
+    Copies from g1 to g2:
+    - The default behaviour creates a shallow copy of each dataset
+    - When deep_copy_dsets is True, datasets are fully deep copied
+    - Any datasets in `shared` will point to the object in g1 storage
+
+    An axis downselection can be specified by supplying the
     parameter 'selections'. For example to select the first two indexes in
     g1["foo"]["bar"], do
 
@@ -2566,6 +2571,8 @@
         enabled could break legacy code, so be very careful
     shared : list, optional
         List of datasets to share, if `deep_copy_dsets` is True. Otherwise, no effect.
+        Shared datasets just point to the existing object in g1 storage, and override
+        any other behaviour
 
     Returns
     -------
@@ -2578,6 +2585,7 @@
     # only the case if zarr is not installed
     if file_format.module is None:
         raise RuntimeError("Can't deep_group_copy zarr file. Please install zarr.")
+
     to_file = isinstance(g2, file_format.module.Group)
 
     # Prepare a dataset for writing out, applying selections and transforming any
@@ -2634,6 +2642,10 @@ def _prepare_dataset(dset):
             # needed until fixed: https://github.com/mpi4py/mpi4py/issues/177
             data = ensure_native_byteorder(data)
 
+        if deep_copy_dsets:
+            # Make sure that we get a deep copy of this dataset
+            data = deep_copy_dataset(data)
+
         dset_args = {"dtype": data.dtype, "shape": data.shape, "data": data}
         # If we're copying memory to memory we can allow distributed datasets
         if not to_file and isinstance(dset, MemDatasetDistributed):
@@ -2675,9 +2687,13 @@ def _prepare_compression_args(dset):
 
         return compression_kwargs
 
+    # Make sure shared dataset names are properly formatted
+    shared = {"/" + k if k[0] != "/" else k for k in shared}
+
     # Do a non-recursive traversal of the tree, recreating the structure and attributes,
     # and copying over any non-distributed datasets
     stack = [g1]
+
     while stack:
         entry = stack.pop()
         key = entry.name
@@ -2684,15 +2700,21 @@
         if key != g1.name:
             # Only create group if we are above the starting level
             g2.create_group(key)
+
             stack += [entry[k] for k in sorted(entry, reverse=True)]
-        else:  # Is a dataset
+
+        elif key in shared:
+            # Just point to the existing dataset
+            parent_name, name = posixpath.split(posixpath.join(g2.name, key))
+            parent_name = format_abs_path(parent_name)
+            # Get the proper storage location for this dataset
+            g2[parent_name]._get_storage()[name] = g1._get_storage()[key]
+
+        else:
+            # Copy over this dataset
            dset_args = _prepare_dataset(entry)
             compression_kwargs = _prepare_compression_args(entry)
 
-            if deep_copy_dsets and key not in shared:
-                # Make a deep copy of the dataset
-                dset_args["data"] = deep_copy_dataset(dset_args["data"])
-
             g2.create_dataset(
                 key,
                 **dset_args,
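
Note for reviewers (not part of the patch): a minimal sketch of the behaviour
this fixes, under some assumptions. The group and dataset names below are made
up, the keyword arguments are taken from the docstring above, and the in-memory
memh5.MemGroup API is assumed to behave like h5py here. As far as the diff
shows, a shared name passed without a leading "/" (e.g. "obs/vis") never
matched the absolute names produced by the tree traversal, so the dataset was
silently deep-copied instead of shared; the fix normalises every entry of
`shared` to an absolute path and points shared entries directly at g1's
storage.

import numpy as np

from caput import memh5
from caput.memh5 import deep_group_copy

# Hypothetical source tree with two datasets under one group
g1 = memh5.MemGroup()
g1.create_dataset("obs/vis", data=np.arange(10.0))
g1.create_dataset("obs/weight", data=np.ones(10))

g2 = memh5.MemGroup()

# "obs/vis" has no leading "/"; before this fix it would have been
# deep-copied rather than shared
deep_group_copy(g1, g2, deep_copy_dsets=True, shared=["obs/vis"])

# The shared dataset points at the same storage object as in g1, so
# writes through g1 should be visible through g2 ...
g1["obs/vis"][:] = 0.0
assert (g2["obs/vis"][:] == 0.0).all()

# ... while the deep-copied dataset stays independent
g1["obs/weight"][:] = 0.0
assert (g2["obs/weight"][:] == 1.0).all()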