Skip to content

Commit

Permalink
fix(memh5): fix a bug where shared dataset names weren't correct
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Apr 25, 2023
1 parent 05af2fa commit b914bc3
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions caput/memh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2526,7 +2526,12 @@ def deep_group_copy(
):
"""Copy full data tree from one group to another.
Copies from g1 to g2. An axis downselection can be specified by supplying the
Copies from g1 to g2:
- The default behaviour creates a shallow copy of each dataset
- When deep_copy_dsets is True, datasets are fully deep copied
- Any datasets in `shared` will point to the object in g1 storage
An axis downselection can be specified by supplying the
parameter 'selections'. For example to select the first two indexes in
g1["foo"]["bar"], do
Expand Down Expand Up @@ -2568,6 +2573,8 @@ def deep_group_copy(
enabled could break legacy code, so be very careful
shared : list, optional
List of datasets to share, if `deep_copy_dsets` is True. Otherwise, no effect.
Shared datasets just point to the existing object in g1 storage, and override
any other behaviour
Returns
-------
Expand All @@ -2580,6 +2587,7 @@ def deep_group_copy(
# only the case if zarr is not installed
if file_format.module is None:
raise RuntimeError("Can't deep_group_copy zarr file. Please install zarr.")

to_file = isinstance(g2, file_format.module.Group)

# Prepare a dataset for writing out, applying selections and transforming any
Expand Down Expand Up @@ -2636,6 +2644,10 @@ def _prepare_dataset(dset):
# needed until fixed: https://github.com/mpi4py/mpi4py/issues/177
data = ensure_native_byteorder(data)

if deep_copy_dsets:
# Make sure that we get a deep copy of this dataset
data = deep_copy_dataset(data)

dset_args = {"dtype": data.dtype, "shape": data.shape, "data": data}
# If we're copying memory to memory we can allow distributed datasets
if not to_file and isinstance(dset, MemDatasetDistributed):
Expand Down Expand Up @@ -2677,9 +2689,13 @@ def _prepare_compression_args(dset):

return compression_kwargs

# Make sure shared dataset names are properly formatted
shared = {"/" + k if k[0] != "/" else k for k in shared}

# Do a non-recursive traversal of the tree, recreating the structure and attributes,
# and copying over any non-distributed datasets
stack = [g1]

while stack:
entry = stack.pop()
key = entry.name
Expand All @@ -2688,15 +2704,21 @@ def _prepare_compression_args(dset):
if key != g1.name:
# Only create group if we are above the starting level
g2.create_group(key)

stack += [entry[k] for k in sorted(entry, reverse=True)]
else: # Is a dataset

elif key in shared:
# Just point to the existing dataset
parent_name, name = posixpath.split(posixpath.join(g2.name, key))
parent_name = format_abs_path(parent_name)
# Get the proper storage location for this dataset
g2[parent_name]._get_storage()[name] = g1._get_storage()[key]

else:
# Copy over this dataset
dset_args = _prepare_dataset(entry)
compression_kwargs = _prepare_compression_args(entry)

if deep_copy_dsets and key not in shared:
# Make a deep copy of the dataset
dset_args["data"] = deep_copy_dataset(dset_args["data"])

g2.create_dataset(
key,
**dset_args,
Expand Down

0 comments on commit b914bc3

Please sign in to comment.