From d85122388707fc187de8cf095a34ae2e77cc7210 Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Tue, 16 Apr 2024 16:01:17 -0400
Subject: [PATCH] explicitly log number of images per contrast per split to datalist json

---
 monai/create_msd_data.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py
index a723be7..ad7fa36 100644
--- a/monai/create_msd_data.py
+++ b/monai/create_msd_data.py
@@ -313,6 +313,20 @@ def main():
     logger.info(f"Number of validation images (not subjects): {params['numValidationImagesTotal']}")
     logger.info(f"Number of testing images (not subjects): {params['numTestImagesTotal']}")
 
+    # update the dataframe to remove subjects whose labels don't exist
+    df = df[~df['subjectID'].isin(subjects_to_remove)]
+
+    # log the number of images per contrasts
+    params["numImagesPerContrast"] = {
+        "train": {},
+        "validation": {},
+        "test": {},
+    }
+    for contrast in params["contrasts"]:
+        params["numImagesPerContrast"]["train"][contrast] = len(df[(df['subjectID'].isin(train_subs_all)) & (df['contrastID'] == contrast)])
+        params["numImagesPerContrast"]["validation"][contrast] = len(df[(df['subjectID'].isin(val_subs_all)) & (df['contrastID'] == contrast)])
+        params["numImagesPerContrast"]["test"][contrast] = len(df[(df['contrastID'] == contrast) & (df['subjectID'].isin(test_subs_all))])
+
     # dump train/val/test splits into a yaml file
     with open(f"datasplits/datasplit_{dataset_name}_seed{args.seed}.yaml", 'w') as file:
         yaml.dump({'train': sorted(train_subs_all), 'val': sorted(val_subs_all), 'test': sorted(test_subs_all)}, file, indent=2, sort_keys=True)