Skip to content

Commit

Permalink
bank combiner: Add ability to copy all subgroups within a specified group (#64)
Browse files Browse the repository at this point in the history

* Add ability to copy all subgroups within a specified group. Useful for compressed waveforms in banks

* Fix dtype issue
  • Loading branch information
GarethCabournDavies authored Dec 20, 2024
1 parent 1e98013 commit d249212
Showing 1 changed file with 50 additions and 7 deletions.
57 changes: 50 additions & 7 deletions bin/sbank_hdf5_bankcombiner
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ single template bank.
import argparse
import numpy
import h5py
import logging

__author__ = "Ian Harry <[email protected]>"
__program__ = "sbank_hdf5_bankcombiner"
Expand All @@ -36,27 +37,40 @@ parser.add_argument("--output-file", type=str,
# Command-line interface.
# NOTE(review): this span is a rendered diff, so BOTH the old "--verbose"
# definition (store_true) and its replacement (action="count") appear below;
# only one exists in the actual file. argparse would reject the duplicate.
parser.add_argument("--input-filenames", nargs='*', default=None,
                    action="store",
                    help="List of input hdf bank files.")
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--copy-subgroups-directly", nargs="+",
                    help="Directly copy subgroup(s) of this group to the "
                    "file rather than appending. Each subgroup must be "
                    "unique across the banks to combine")
# "count" lets repeated -v flags raise verbosity incrementally (None if absent).
parser.add_argument("--verbose", action="count")

args = parser.parse_args()

# Treat an omitted --copy-subgroups-directly as "copy nothing directly",
# so later membership tests and the == [] check work uniformly.
if args.copy_subgroups_directly is None:
    args.copy_subgroups_directly = []

# Each repetition of --verbose lowers the logging threshold by one level:
# WARNING (30) -> INFO (20) -> DEBUG (10).
if args.verbose is not None:
    logging.basicConfig(level=logging.WARNING - int(args.verbose) * 10)

# Accumulators filled from the first non-empty input bank:
# attrs_dict: file-level HDF5 attributes; items_dict: per-dataset arrays.
attrs_dict = None
items_dict = None
# NOTE(review): approx_map_dict is initialised but never used in the visible
# part of the script — presumably consumed in context hidden by the diff view.
approx_map_dict = {}
approx_map_dict['counter'] = 1
logging.info("Copying bank values")

# Pass 1: accumulate attributes and dataset contents across all input banks.
# NOTE(review): rendered diff — the bare "for file_name ..." line below is the
# OLD loop header, replaced by the enumerate form that follows it.
for file_name in args.input_filenames:
n_banks = len(args.input_filenames)
for i, file_name in enumerate(args.input_filenames):
    hdf_fp = h5py.File(file_name, 'r')
    # Placeholder banks are flagged with an 'empty_file' attribute; skip them.
    if 'empty_file' in hdf_fp.attrs:
        continue
    logging.debug("Bank %s; %d / %d", file_name, i, n_banks)
    # First usable bank: snapshot its file-level attributes verbatim.
    if attrs_dict is None:
        attrs_dict = {}
        for key, item in hdf_fp.attrs.items():
            attrs_dict[key] = item

    # First usable bank also defines the set of dataset names. NOTE(review):
    # the direct-copy line below is the OLD code; the new code seeds an empty
    # array with the matching dtype so every bank (this one included) is
    # appended uniformly in the later loop.
    if items_dict is None:
        items_dict = {}
        for item, entries in hdf_fp.items():
            items_dict[item] = entries[:]
            dt = entries.dtype if hasattr(entries, "dtype") else None
            items_dict[item] = numpy.array([], dtype=dt)
    else:
        # Subsequent banks must expose exactly the same dataset names.
        curr_items = set(items_dict.keys())
        new_items = set(hdf_fp.keys())
Expand All @@ -68,16 +82,45 @@ for file_name in args.input_filenames:
            # Mismatched field sets across banks is fatal: the combined file
            # could not represent both layouts.
            err_msg += "contains fields {} ".format(new_items)
            err_msg += "other files contain {}.".format(curr_items)
            raise ValueError(err_msg)
    # NOTE(review): rendered diff — the first loop below is the OLD version;
    # the new one additionally skips non-Dataset members (i.e. groups, which
    # are handled by --copy-subgroups-directly instead of being appended).
    for item, entries in hdf_fp.items():
        items_dict[item] = numpy.append(items_dict[item], entries[:])
    for item, entries in hdf_fp.items():
        if not isinstance(entries, h5py.Dataset):
            continue
        items_dict[item] = numpy.append(items_dict[item], entries[:])
    hdf_fp.close()


# Write the combined bank. If every input was an empty placeholder, emit an
# empty placeholder ourselves so downstream tools recognise it.
out_fp = h5py.File(args.output_file, 'w')
if attrs_dict is None:
    out_fp.attrs['empty_file'] = True
else:
    # Datasets destined for direct subgroup copying are handled in pass 2,
    # not written as concatenated arrays here.
    for item, value in items_dict.items():
        if item in args.copy_subgroups_directly:
            continue
        out_fp[item] = value
    for item, value in attrs_dict.items():
        out_fp.attrs[item] = value

# Nothing to copy directly: finish now.
# NOTE(review): bare exit() relies on the site module; sys.exit(0) would be
# the conventional choice in a script.
if args.copy_subgroups_directly == []:
    out_fp.close()
    logging.info("Done!")
    exit(0)


# Pass 2: for each requested group, copy every subgroup from every input bank
# straight into the output file (used e.g. for compressed waveforms, where
# per-template subgroups are unique across banks and need no concatenation).
for grp_to_copy in args.copy_subgroups_directly:
    logging.info("Directly copying groups under %s", grp_to_copy)
    out_copied_group = out_fp.create_group(grp_to_copy)
    for i, file_name in enumerate(args.input_filenames):
        logging.debug("Bank %s; %d / %d", file_name, i, n_banks)
        with h5py.File(file_name, 'r') as hdf_fp:
            # Skip placeholder banks, mirroring pass 1: they carry the
            # 'empty_file' attribute and do not contain grp_to_copy, so
            # indexing them would raise KeyError.
            if 'empty_file' in hdf_fp.attrs:
                continue
            for key_to_copy in hdf_fp[grp_to_copy].keys():
                # h5py's copy preserves datasets, attributes and nested
                # structure; subgroup names are assumed unique across banks
                # (documented requirement of --copy-subgroups-directly).
                hdf_fp.copy(
                    hdf_fp[grp_to_copy][key_to_copy],
                    out_copied_group,
                    name=key_to_copy
                )

out_fp.close()

logging.info("Done!")

0 comments on commit d249212

Please sign in to comment.