remove old generation files and add script for concatenating hdf5 files
Commit 3e674e6 (1 parent: 31fdbd2)
Showing 6 changed files with 132 additions and 129 deletions.
use-cases/virgo/synthetic_data_gen/concat_hdf5_dataset_files.py (82 additions, 0 deletions)
import argparse
from pathlib import Path

import h5py
import numpy as np


def append_to_hdf5_dataset(
    file_path: Path,
    dataset_name: str,
    array: np.ndarray,
    expected_datapoint_shape: tuple,
):
    """Append `array` along axis 0 of an existing resizable HDF5 dataset."""
    if tuple(array.shape[1:]) != expected_datapoint_shape:
        actual_shape_str = ", ".join(str(s) for s in array.shape)
        expected_shape_str = ", ".join(str(s) for s in expected_datapoint_shape)
        raise ValueError(
            f"'array' has an incorrect shape: ({actual_shape_str}). "
            f"Should have been (x, {expected_shape_str})."
        )

    print(f"Appending to file: '{file_path.resolve()}'.")
    with h5py.File(file_path, "a") as f:
        dset = f[dataset_name]
        # Grow the dataset along axis 0, then write the new rows at the end
        dset.resize(dset.shape[0] + array.shape[0], axis=0)
        dset[-array.shape[0]:] = array


def main():
    parser = argparse.ArgumentParser(description="Virgo Dataset Generation")
    parser.add_argument(
        "--dir",
        type=str,
        required=True,
        help="Directory containing the HDF5 files to concatenate",
    )
    parser.add_argument(
        "--save-location",
        type=str,
        help="Location to save the resulting HDF5 file.",
        default="total_virgo_data.hdf5",
    )
    args = parser.parse_args()
    input_dir = Path(args.dir)
    save_location = Path(args.save_location)
    num_aux_channels = 3
    square_size = 64
    dataset_name = "virgo_dataset"

    # Create an empty, resizable HDF5 dataset to append into
    datapoint_shape = (num_aux_channels + 1, square_size, square_size)
    save_location.parent.mkdir(parents=True, exist_ok=True)

    print(f"Creating/overwriting file: '{save_location.resolve()}'.")
    with h5py.File(save_location, "w") as f:
        dataset = f.create_dataset(
            dataset_name,
            shape=(0, *datapoint_shape),
            maxshape=(None, *datapoint_shape),
            dtype=np.float32,
        )
        dataset.attrs["Description"] = (
            "Synthetic time series data for the Virgo use case"
        )
        dataset.attrs["main_channel_idx"] = 0

    # NOTE: This will not necessarily iterate in the same order as the
    # suffixes of the file names
    for entry in input_dir.iterdir():
        if entry.suffix != ".hdf5":
            continue

        print(f"Reading entry: {entry}")
        with h5py.File(entry, "r") as f:
            data = f[dataset_name][:]

        append_to_hdf5_dataset(
            file_path=save_location,
            dataset_name=dataset_name,
            array=data,
            expected_datapoint_shape=datapoint_shape,
        )


if __name__ == "__main__":
    main()
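
As a quick check of the result, the concatenated file can be opened and inspected. The sketch below is not part of the commit; it assumes the script's defaults (output file total_virgo_data.hdf5, dataset name virgo_dataset). With num_aux_channels = 3 and square_size = 64, each datapoint should have shape (4, 64, 64).

# Sanity check: a minimal sketch assuming the defaults above; not part of the commit.
import h5py

with h5py.File("total_virgo_data.hdf5", "r") as f:
    dset = f["virgo_dataset"]
    print(f"Total datapoints: {dset.shape[0]}")
    print(f"Datapoint shape: {dset.shape[1:]}")  # expected (4, 64, 64)
    print(f"Description: {dset.attrs['Description']}")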
This file was deleted (presumably one of the old generation files removed by this commit; its contents are not shown).
use-cases/virgo/synthetic_data_gen/data_generation_hdf5.sh (24 additions, 0 deletions)
#!/bin/bash

#SBATCH --account=intertwin
#SBATCH --output=array_job.out
#SBATCH --error=array_job.err
#SBATCH --time=01:00:00
#SBATCH --mem-per-cpu=1G
#SBATCH --partition=develbooster
#SBATCH --array=1-250
#SBATCH --job-name=generate_virgo_data

# Load required modules
ml Stages/2024 GCC OpenMPI CUDA/12 MPI-settings/CUDA Python HDF5 PnetCDF libaio mpi4py

# Activate Python virtual environment
source ../../envAI_hdfml/bin/activate

# File in which this array task's dataset will be stored
target_file="/p/scratch/intertwin/datasets/virgo_hdf5/virgo_data_${SLURM_ARRAY_TASK_ID}.hdf5"

srun python synthetic_data_gen/file_gen_hdf5.py \
    --num-datapoints 3000 \
    --save-frequency 20 \
    --save-location "$target_file"
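
Each of the 250 array tasks writes one HDF5 file into the scratch directory above; once all tasks have finished, the concatenation script can be pointed at that directory. The snippet below is a hypothetical pre-flight check (path and naming pattern taken from the job script) that reports any missing task outputs before concatenating.

# Hypothetical pre-flight check before concatenation; path and file pattern
# come from the SLURM script above, one file per array task 1-250.
from pathlib import Path

data_dir = Path("/p/scratch/intertwin/datasets/virgo_hdf5")
expected = {f"virgo_data_{i}.hdf5" for i in range(1, 251)}
present = {p.name for p in data_dir.glob("virgo_data_*.hdf5")}

missing = sorted(expected - present)
print(f"{len(present)}/{len(expected)} expected files present")
if missing:
    print("Missing:", ", ".join(missing))
# When nothing is missing, concatenate with:
#   python concat_hdf5_dataset_files.py --dir /p/scratch/intertwin/datasets/virgo_hdf5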