fix(memh5): fix a bug where shared dataset names weren't correct
ljgray committed Mar 21, 2023
1 parent 4320fe3 commit f6021e3
Showing 1 changed file with 35 additions and 14 deletions.
caput/memh5.py: 49 changes (35 additions, 14 deletions)
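The underlying issue: during traversal, `deep_group_copy` compares each dataset's absolute name (e.g. "/vis") against the user-supplied `shared` names, which are typically given without the leading slash. A minimal sketch of the mismatch and of the normalization this commit adds (the dataset name "vis" is hypothetical):

# Before the fix: absolute traversal keys never match bare user-supplied names
shared = ["vis"]
key = "/vis"  # memh5 dataset names are absolute paths
print(key in shared)  # False, so the dataset was copied instead of shared

# After the fix: user-supplied names are normalized to absolute form first
shared = {"/" + k if k[0] != "/" else k for k in shared}
print(key in shared)  # True, so the dataset is shared as requested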
@@ -2492,7 +2492,12 @@ def deep_group_copy(
"""
Copy full data tree from one group to another.
Copies from g1 to g2. An axis downselection can be specified by supplying the
Copies from g1 to g2:
- The default behaviour creates a shallow copy of each dataset
- When deep_copy_dsets is True, datasets are fully deep copied
- Any datasets in `shared` will point to the object in g1 storage
An axis downselection can be specified by supplying the
parameter 'selections'. For example to select the first two indexes in
g1["foo"]["bar"], do
@@ -2534,6 +2539,8 @@
        enabled could break legacy code, so be very careful
    shared : list, optional
        List of datasets to share, if `deep_copy_dsets` is True. Otherwise, no effect.
        Shared datasets just point to the existing object in g1 storage, and override
        any other behaviour.

    Returns
    -------
@@ -2547,6 +2554,7 @@
    # only the case if zarr is not installed
    if file_format.module is None:
        raise RuntimeError("Can't deep_group_copy zarr file. Please install zarr.")

    to_file = isinstance(g2, file_format.module.Group)

    # Prepare a dataset for writing out, applying selections and transforming any
@@ -2603,6 +2611,10 @@ def _prepare_dataset(dset):
        # needed until fixed: https://github.com/mpi4py/mpi4py/issues/177
        data = ensure_native_byteorder(data)

        if deep_copy_dsets:
            # Make sure that we get a deep copy of this dataset
            data = deep_copy_dataset(data)

        dset_args = {"dtype": data.dtype, "shape": data.shape, "data": data}
        # If we're copying memory to memory we can allow distributed datasets
        if not to_file and isinstance(dset, MemDatasetDistributed):
@@ -2644,9 +2656,13 @@ def _prepare_compression_args(dset):

        return compression_kwargs

    # Make sure shared dataset names are properly formatted
    shared = {"/" + k if k[0] != "/" else k for k in shared}

    # Do a non-recursive traversal of the tree, recreating the structure and attributes,
    # and copying over any non-distributed datasets
    stack = [g1]

    while stack:
        entry = stack.pop()
        key = entry.name
@@ -2655,20 +2671,25 @@ def _prepare_compression_args(dset):
            if key != g1.name:
                # Only create group if we are above the starting level
                g2.create_group(key)

            stack += [entry[k] for k in sorted(entry, reverse=True)]
-        else:  # Is a dataset
-            dset_args = _prepare_dataset(entry)
-            compression_kwargs = _prepare_compression_args(entry)
-
-            if deep_copy_dsets and key not in shared:
-                # Make a deep copy of the dataset
-                dset_args["data"] = deep_copy_dataset(dset_args["data"])
-
-            g2.create_dataset(
-                key,
-                **dset_args,
-                **compression_kwargs,
-            )
+        else:
+            # This is a dataset
+            if key in shared:
+                # Just point to the existing dataset
+                parent_name, name = posixpath.split(posixpath.join(g2.name, key))
+                parent_name = format_abs_path(parent_name)
+                # Get the proper storage location for this dataset
+                g2[parent_name]._get_storage()[name] = g1._get_storage()[key]
+            else:
+                dset_args = _prepare_dataset(entry)
+                compression_kwargs = _prepare_compression_args(entry)
+
+                g2.create_dataset(
+                    key,
+                    **dset_args,
+                    **compression_kwargs,
+                )

        target = g2[key]
        copyattrs(entry.attrs, target.attrs, convert_strings=convert_attribute_strings)
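A minimal usage sketch of the fixed behaviour, assuming an in-memory copy with `memh5.MemGroup` and the default keyword arguments of `deep_group_copy`; the dataset name "vis" and the exact `MemGroup`/`create_dataset` calls are illustrative assumptions, not taken from this commit:

import numpy as np

from caput import memh5

g1 = memh5.MemGroup()
g1.create_dataset("vis", data=np.arange(10))
g2 = memh5.MemGroup()

# "vis" is normalized to "/vis" internally, so it now matches the traversal key
# and g2's dataset points at the same storage object as g1's, even though
# deep_copy_dsets=True would otherwise force a full copy.
memh5.deep_group_copy(g1, g2, deep_copy_dsets=True, shared=["vis"])

g1["vis"][:] = 0
assert (g2["vis"][:] == 0).all()  # the shared dataset reflects changes made via g1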