diff --git a/caput/memh5.py b/caput/memh5.py
index 6a7d553c..fcdef4f0 100644
--- a/caput/memh5.py
+++ b/caput/memh5.py
@@ -2492,7 +2492,12 @@ def deep_group_copy(
     """
     Copy full data tree from one group to another.
 
-    Copies from g1 to g2. An axis downselection can be specified by supplying the
+    Copies from g1 to g2:
+    - The default behaviour creates a shallow copy of each dataset
+    - When deep_copy_dsets is True, datasets are fully deep copied
+    - Any datasets in `shared` will point to the object in g1 storage
+
+    An axis downselection can be specified by supplying the
     parameter 'selections'. For example to select the first two indexes in
     g1["foo"]["bar"], do
 
@@ -2534,6 +2539,8 @@ def deep_group_copy(
         enabled could break legacy code, so be very careful
     shared : list, optional
         List of datasets to share, if `deep_copy_dsets` is True. Otherwise, no effect.
+        Shared datasets just point to the existing object in g1 storage, and override
+        any other behaviour
 
     Returns
     -------
@@ -2547,6 +2554,7 @@ def deep_group_copy(
     # only the case if zarr is not installed
     if file_format.module is None:
         raise RuntimeError("Can't deep_group_copy zarr file. Please install zarr.")
+
     to_file = isinstance(g2, file_format.module.Group)
 
     # Prepare a dataset for writing out, applying selections and transforming any
@@ -2603,6 +2611,10 @@ def _prepare_dataset(dset):
             # needed until fixed: https://github.com/mpi4py/mpi4py/issues/177
             data = ensure_native_byteorder(data)
 
+        if deep_copy_dsets:
+            # Make sure that we get a deep copy of this dataset
+            data = deep_copy_dataset(data)
+
         dset_args = {"dtype": data.dtype, "shape": data.shape, "data": data}
         # If we're copying memory to memory we can allow distributed datasets
         if not to_file and isinstance(dset, MemDatasetDistributed):
@@ -2644,9 +2656,13 @@ def _prepare_compression_args(dset):
 
         return compression_kwargs
 
+    # Make sure shared dataset names are properly formatted
+    shared = {"/" + k if k[0] != "/" else k for k in shared}
+
     # Do a non-recursive traversal of the tree, recreating the structure and attributes,
     # and copying over any non-distributed datasets
     stack = [g1]
+
     while stack:
         entry = stack.pop()
         key = entry.name
@@ -2655,20 +2671,25 @@ def _prepare_compression_args(dset):
             if key != g1.name:
                 # Only create group if we are above the starting level
                 g2.create_group(key)
+
             stack += [entry[k] for k in sorted(entry, reverse=True)]
-        else:  # Is a dataset
-            dset_args = _prepare_dataset(entry)
-            compression_kwargs = _prepare_compression_args(entry)
-
-            if deep_copy_dsets and key not in shared:
-                # Make a deep copy of the dataset
-                dset_args["data"] = deep_copy_dataset(dset_args["data"])
-
-            g2.create_dataset(
-                key,
-                **dset_args,
-                **compression_kwargs,
-            )
+        else:
+            # This is a dataset
+            if key in shared:
+                # Just point to the existing dataset
+                parent_name, name = posixpath.split(posixpath.join(g2.name, key))
+                parent_name = format_abs_path(parent_name)
+                # Get the proper storage location for this dataset
+                g2[parent_name]._get_storage()[name] = g1._get_storage()[key]
+            else:
+                dset_args = _prepare_dataset(entry)
+                compression_kwargs = _prepare_compression_args(entry)
+
+                g2.create_dataset(
+                    key,
+                    **dset_args,
+                    **compression_kwargs,
+                )
 
         target = g2[key]
         copyattrs(entry.attrs, target.attrs, convert_strings=convert_attribute_strings)
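
Usage note (not part of the diff): the sketch below illustrates the copy modes described in the updated docstring for a memory-to-memory copy. Only deep_group_copy, deep_copy_dsets and shared come from the change above; the MemGroup setup and the dataset names "data" and "mask" are illustrative assumptions.

    # Minimal sketch, assuming in-memory MemGroup containers; dataset names are
    # made up for illustration and are not taken from the diff.
    import numpy as np

    from caput import memh5

    g1 = memh5.MemGroup()
    g1.create_dataset("data", data=np.arange(10))
    g1.create_dataset("mask", data=np.zeros(10, dtype=bool))

    g2 = memh5.MemGroup()

    # Deep copy every dataset except "mask": entries listed in `shared` keep
    # pointing at the storage object held by g1 and override the deep copy.
    memh5.deep_group_copy(g1, g2, deep_copy_dsets=True, shared=["mask"])

    # Changes made through g1 should then be visible in g2 for the shared
    # dataset only; the deep-copied "data" dataset owns its own memory.
    g1["mask"][0] = True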