fix(memh5): fix a bug where shared datasets weren't actually shared #239

Merged: 3 commits, Jul 30, 2024
caput/memh5.py: 123 changes (90 additions, 33 deletions)
@@ -2051,9 +2051,7 @@ def copy(self, shared: list = [], shallow: bool = False) -> MemDiskGroup:
"""
cls = self.__class__.__new__(self.__class__)
MemDiskGroup.__init__(cls, distributed=self.distributed, comm=self.comm)
deep_group_copy(
self._data, cls._data, deep_copy_dsets=not shallow, shared=shared
)
deep_group_copy(self._data, cls._data, shallow=shallow, shared=shared)

return cls

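For orientation, here is a minimal sketch of what the corrected call is meant to do. This example is illustrative and not part of the diff; it assumes a plain in-memory MemGroup holding a single dataset named "foo".

import numpy as np
from caput import memh5

# Build a small source group entirely in memory.
src = memh5.MemGroup()
src.create_dataset("foo", data=np.arange(4))

# Copy it, asking for "foo" to be shared rather than deep copied.
dst = memh5.MemGroup()
memh5.deep_group_copy(src, dst, shared=["foo"])

# With the fix, the copy's "foo" points at the same underlying storage,
# so an in-place write through one group is visible through the other.
src["foo"][:] = -1
print(np.all(dst["foo"][:] == -1))  # expected: True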
@@ -2519,12 +2517,19 @@ def deep_group_copy(
file_format=fileformats.HDF5,
skip_distributed=False,
postprocess=None,
deep_copy_dsets=False,
shallow=False,
shared=[],
):
"""Copy full data tree from one group to another.

Copies from g1 to g2. An axis downselection can be specified by supplying the
Copies from g1 to g2:
- The default behaviour creates a deep copy of each dataset.
- If `g2` is on disk, the behaviour is the same as making a deep copy. In this
case, both `shallow` and `shared` are ignored.
- Otherwise, when shallow is False, datasets not listed in `shared` are fully
deep copied and any datasets in `shared` will point to the object in `g1` storage.

An axis downselection can be specified by supplying the
parameter 'selections'. For example to select the first two indexes in
g1["foo"]["bar"], do

@@ -2536,16 +2541,18 @@
>>> list(g2["foo"]["bar"])
[0, 1]

Axis downselections cannot be applied to shared datasets.
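For example (a sketch that is not part of this diff; the dataset name and selection are illustrative), requesting a selection on a dataset that is also shared is expected to raise a ValueError:

import numpy as np
from caput import memh5

src = memh5.MemGroup()
src.create_dataset("bar", data=np.arange(10))
dst = memh5.MemGroup()

# A shared dataset must point at the full original storage, so combining
# it with an axis selection is rejected.
try:
    memh5.deep_group_copy(
        src, dst, selections={"/bar": slice(0, 2)}, shared=["bar"]
    )
except ValueError as err:
    print(err)  # expected: cannot apply a selection to a shared dataset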

Parameters
----------
g1 : h5py.Group or zarr.Group
g1 : h5py.Group or zarr.Group or MemGroup
Deep copy from this group.
g2 : h5py.Group or zarr.Group
g2 : h5py.Group or zarr.Group or MemGroup
Deep copy to this group.
selections : dict
If this is not None, it should have a subset of the same hierarchical structure
as g1, but ultimately describe axis selections for group entries as valid
numpy indexes.
numpy indexes. Selections cannot be applied to shared datasets.
convert_attribute_strings : bool, optional
Convert string attributes (or lists/arrays of them) to ensure that they are
unicode.
@@ -2558,14 +2565,17 @@
names of all datasets that were skipped. If `False` (default) throw a
`ValueError` if any distributed datasets are encountered.
postprocess : function, optional
A function that takes is called on each node, with the source and destination
A function which is called on each node, with the source and destination
entries, and can modify either.
deep_copy_dsets : bool, optional
Explicitly deep copy all datasets. This will only alter behaviour when copying
from memory to memory. XXX: enabling this in places where it is not currently
enabled could break legacy code, so be very careful
shared : list, optional
List of datasets to share, if `deep_copy_dsets` is True. Otherwise, no effect.
shallow : bool, optional
Explicitly share all datasets. This will only alter behaviour when copying
from memory to memory. If False, any dataset listed in `shared` will NOT be copied.
Default is False.
shared : iterable, optional
Iterable (list, set, generator) of datasets to share, if `shallow` is False.
Shared datasets just point to the existing object in g1 storage. Axis selections
cannot be applied to shared datasets. Ignored if `shallow` is True, since, in that
case, _all_ datasets are shared.

Returns
-------
@@ -2578,12 +2588,15 @@
# only the case if zarr is not installed
if file_format.module is None:
raise RuntimeError("Can't deep_group_copy zarr file. Please install zarr.")

to_file = isinstance(g2, file_format.module.Group)

# Prepare a dataset for writing out, applying selections and transforming any
# datatypes
# Returns: dict(dtype, shape, data_to_write)
def _prepare_dataset(dset):
# Define functions applied to each dataset
def _get_selection(dset):
"""Get the selections associated with this dataset.

Returns: slice
"""
# Look for a selection for this dataset (also try without the leading "/")
try:
selection = selections.get(
@@ -2592,6 +2605,15 @@ def _prepare_dataset(dset):
except AttributeError:
selection = slice(None)

return selection

def _prepare_dataset(dset):
"""Prepare a dataset for writing, applying selections and transforming datatypes.

Returns: dict(dtype, shape, data_to_write)
"""
selection = _get_selection(dset)

# Check if this is a distributed dataset and figure out if we can make this work
# out
if to_file and isinstance(dset, MemDatasetDistributed):
@@ -2606,12 +2628,17 @@ def _prepare_dataset(dset):
f"({dset.name}) via this method."
)

# If we get here, we should create the dataset, but not write out any data into it (i.e. return None)
# If we get here, we should create the dataset, but not write out
# any data into it (i.e. return None)
distributed_dset_names.append(dset.name)
return {"dtype": dset.dtype, "shape": dset.shape, "data": None}

# Extract the data for the selection
data = dset[selection]
# If copying memory to memory, make a deep copy of this dataset
# We don't need to make a copy if writing to disk
if not to_file:
data = deep_copy_dataset(data)

if convert_dataset_strings:
# Convert unicode strings back into ascii byte strings. This will break
@@ -2628,12 +2655,6 @@ def _prepare_dataset(dset):
# Unicode characters before writing
data = check_unicode(entry)

if not to_file:
# reading from h5py can result in arrays with explicit endian set
# which mpi4py cannot handle when Bcasting memh5.Group
# needed until fixed: https://github.com/mpi4py/mpi4py/issues/177
data = ensure_native_byteorder(data)

dset_args = {"dtype": data.dtype, "shape": data.shape, "data": data}
# If we're copying memory to memory we can allow distributed datasets
if not to_file and isinstance(dset, MemDatasetDistributed):
@@ -2643,9 +2664,11 @@ def _prepare_compression_args(dset):

return dset_args

# get compression options/chunking for this dataset
# Returns dict of compression and chunking arguments for create_dataset
def _prepare_compression_args(dset):
"""Get compression options and chunking for this dataset.

Returns: dict(compression, compression_opts, chunks)
"""
compression = getattr(dset, "compression", None)
compression_opts = getattr(dset, "compression_opts", None)

@@ -2675,9 +2698,31 @@ def _prepare_compression_args(dset):

return compression_kwargs

# If copying to file, datasets are not shared, so ensure that these
# datasets are properly processed
if to_file:
if shared:
warnings.warn(
f"Attempted to share datasets {(*shared,)}, but target group "
f"{g2} is on disk. Datasets cannot be shared."
)
shared = {}

if shallow:
warnings.warn(
f"Attempted to make a shallow copy of group {g1}, but target "
f"group {g2} is on disk. Datasets cannot be shared."
)
shallow = False

elif not shallow:
# Make sure shared dataset names are properly formatted
shared = {"/" + k if k[0] != "/" else k for k in shared}

# Do a non-recursive traversal of the tree, recreating the structure and attributes,
# and copying over any non-distributed datasets
stack = [g1]

while stack:
entry = stack.pop()
key = entry.name
@@ -2686,15 +2731,27 @@ def _prepare_compression_args(dset):
if key != g1.name:
# Only create group if we are above the starting level
g2.create_group(key)

stack += [entry[k] for k in sorted(entry, reverse=True)]
else: # Is a dataset

elif shallow or (key in shared):
# Make sure that we aren't trying to apply a selection to this dataset
if _get_selection(entry) != slice(None):
raise ValueError(
f"Cannot apply a selection to a shared dataset ({entry.name})"
)
# Just point to the existing dataset
parent_name, name = posixpath.split(posixpath.join(g2.name, key))
parent_name = format_abs_path(parent_name)
# Get the proper storage location for this dataset
g2[parent_name]._get_storage()[name] = g1._get_storage()[key]

else:
# Copy over this dataset. `_prepare_dataset` will make
# a deep copy of the dataset
dset_args = _prepare_dataset(entry)
compression_kwargs = _prepare_compression_args(entry)

if deep_copy_dsets and key not in shared:
# Make a deep copy of the dataset
dset_args["data"] = deep_copy_dataset(dset_args["data"])

g2.create_dataset(
key,
**dset_args,
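Finally, a hedged end-to-end sketch of the shallow path introduced by this change. Names and the expected output are illustrative and not taken from the PR's test suite.

import numpy as np
from caput import memh5

src = memh5.MemGroup()
src.create_dataset("foo", data=np.arange(4))
src.create_dataset("bar", data=np.ones(3))

# shallow=True shares every dataset: the destination points at src's storage.
dst = memh5.MemGroup()
memh5.deep_group_copy(src, dst, shallow=True)

src["foo"][:] = 7
src["bar"][:] = 0
print(np.all(dst["foo"][:] == 7))  # expected: True, "foo" is shared
print(np.all(dst["bar"][:] == 0))  # expected: True, "bar" is shared too

# If dst were an on-disk h5py.Group instead, both shallow and shared would be
# ignored (with a warning) and every dataset would be written out in full.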