diff --git a/configs/train_all.yaml b/configs/train_all.yaml index 3cfb050..28f4f0a 100644 --- a/configs/train_all.yaml +++ b/configs/train_all.yaml @@ -1,5 +1,5 @@ seed: 50 -save_test_preds: True +save_test_preds: False directories: # Path to the saved models directory @@ -12,14 +12,9 @@ directories: dataset: # Dataset name (will be used as "group_name" for wandb logging) - name: lifelong-contrast-agnostic + name: contrast-agnostic-v3 # Path to the dataset directory containing all datalists (.json files) - # root_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/lifelong-contrast-agnostic - root_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/aggregation-20240517 - # Type of contrast to be used for training. "all" corresponds to training on all contrasts - contrast: all # choices: ["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"] - # Type of label to be used for training. - label_type: soft_bin # choices: ["hard", "soft", "soft_bin"] + root_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/v2-final-aggregation-20241017 preprocessing: # Online resampling of images to the specified spacing. @@ -34,7 +29,7 @@ opt: max_epochs: 250 batch_size: 2 # Interval between validation checks in epochs - check_val_every_n_epochs: 1 + check_val_every_n_epochs: 50 # Early stopping patience (this is until patience * check_val_every_n_epochs) early_stopping_patience: 20 @@ -72,19 +67,3 @@ model: [1, 2, 2], ] enable_deep_supervision: True - - mednext: - num_input_channels: 1 - base_num_features: 32 - num_classes: 1 - kernel_size: 3 # 3x3x3 and 5x5x5 were tested in publication - block_counts: [2,2,2,2,1,1,1,1,1] # number of blocks in each layer - norm_type: 'layer' - enable_deep_supervision: True - - swinunetr: - spatial_dims: 3 - depths: [2, 2, 2, 2] - num_heads: [3, 6, 12, 24] # number of heads in multi-head Attention - feature_size: 36 - use_pretrained: False \ No newline at end of file diff --git a/csa_generate_figures/analyse_csa_across.py b/csa_generate_figures/analyse_csa_across.py index e88ce0b..61235a4 100644 --- a/csa_generate_figures/analyse_csa_across.py +++ b/csa_generate_figures/analyse_csa_across.py @@ -13,16 +13,16 @@ import matplotlib.pyplot as plt # Setting the hue order as specified -# HUE_ORDER = ["softseg_bin", "deepseg_2d", "nnunet", "monai", "mednext", "swinunetr", "swinpretrained", "ensemble"] +HUE_ORDER = ["softseg_bin", "deepseg", "plain_320", "plain_384", "resencM"] # HUE_ORDER = ["softseg_bin", "deepseg_2d", "monai_single", "monai_7datasets", "swinunetr_7datasets"] -HUE_ORDER = ["softseg_bin", "monai_v21", "monai_v23", "monai_v2x"] +# HUE_ORDER = ["softseg_bin", "monai_v21", "monai_v23", "monai_v2x"] HUE_ORDER_THR = ["GT", "15", "1", "05", "01", "005"] HUE_ORDER_RES = ["1mm", "05mm", "15mm", "3mm", "2mm"] CONTRAST_ORDER = ["DWI", "MTon", "MToff", "T1w", "T2star", "T2w"] FONTSIZE = 12 -# XTICKS = ["GT", "DeepSeg2D", "C-A\nv2.1", "C-A\nv2.3", "C-A\nv2.x"] -XTICKS = ["GT", "contrast-agnostic\nv2.1", "contrast-agnostic\nv2.3", "contrast-agnostic\nv2.x"] +XTICKS = ["GT", "DeepSeg2D", "C-A\nplain_320", "C-A\nplain_384", "C-A\nresencM"] +# XTICKS = ["GT", "contrast-agnostic\nv2.1", "contrast-agnostic\nv2.3", "contrast-agnostic\nv2.x"] def save_figure(file_path, save_fname): @@ -61,8 +61,8 @@ def extract_contrast_and_details(filename, across="Method"): if across == "Method": # pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|deepseg_2d|soft_input|bin_input).*' # pattern = 
r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|deepseg_2d|nnunet|monai|mednext|swinunetr|swinpretrained|ensemble).*' - # pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|deepseg_2d|monai_single|monai_7datasets|swinunetr_7datasets).*' - pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|monai_v21|monai_v23|monai_v2x).*' + pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|deepseg|plain_320|plain_384|resencM).*' + # pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|monai_v21|monai_v23|monai_v2x).*' match = re.search(pattern, filename) if match: return match.group(1), match.group(2) diff --git a/csa_generate_figures/analyse_perslice_csa_across.py b/csa_generate_figures/analyse_perslice_csa_across.py index f1cda4d..bf9cde6 100644 --- a/csa_generate_figures/analyse_perslice_csa_across.py +++ b/csa_generate_figures/analyse_perslice_csa_across.py @@ -13,15 +13,15 @@ import matplotlib.pyplot as plt # Setting the hue order as specified -# HUE_ORDER = ["softseg_bin", "deepseg_2d", "nnunet", "monai", "mednext", "swinunetr", "swinpretrained", "ensemble"] +HUE_ORDER = ["softseg_bin", "deepseg", "plain_320", "plain_384", "resencM"] # HUE_ORDER = ["softseg_bin", "deepseg_2d", "monai_single", "monai_7datasets", "swinunetr_7datasets"] -HUE_ORDER = ["softseg_bin", "monai_v21", "monai_v23", "monai_v2x"] +# HUE_ORDER = ["softseg_bin", "monai_v21", "monai_v23", "monai_v2x"] HUE_ORDER_RES = ["1mm", "05mm", "15mm", "3mm", "2mm"] CONTRAST_ORDER = ["DWI", "MTon", "MToff", "T1w", "T2star", "T2w"] FONTSIZE = 12 -# XTICKS = ["GT", "DeepSeg2D", "C-A\nOriginal", "C-A\n7-datasets", "SwinUNETR\n7-datasets"] -XTICKS = ["GT", "contrast-agnostic\nv2.1", "contrast-agnostic\nv2.3", "contrast-agnostic\nv2.x"] +XTICKS = ["GT", "DeepSeg2D", "C-A\nplain_320", "C-A\nplain_384", "C-A\nresencM"] +# XTICKS = ["GT", "contrast-agnostic\nv2.1", "contrast-agnostic\nv2.3", "contrast-agnostic\nv2.x"] def save_figure(file_path, save_fname): @@ -59,9 +59,9 @@ def extract_contrast_and_details(filename, across="Method"): # pattern = r'.*iso-(\d+mm).*_(propseg|deepseg_2d|nnunet_3d_fullres|monai).*' if across == "Method": # pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|deepseg_2d|soft_input|bin_input).*' - # pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|deepseg_2d|nnunet|monai|mednext|swinunetr|swinpretrained|ensemble).*' + pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|deepseg|plain_320|plain_384|resencM).*' # pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|deepseg_2d|monai_single|monai_7datasets|swinunetr_7datasets).*' - pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|monai_v21|monai_v23|monai_v2x).*' + # pattern = r'.*_(DWI|MTon|MToff|T1w|T2star|T2w).*_(softseg_bin|monai_v21|monai_v23|monai_v2x).*' match = re.search(pattern, filename) if match: return match.group(1), match.group(2) diff --git a/csa_qc_evaluation_spine_generic/comparison_across_models.sh b/csa_qc_evaluation_spine_generic/comparison_across_models.sh index 2cec479..9f3c4a5 100644 --- a/csa_qc_evaluation_spine_generic/comparison_across_models.sh +++ b/csa_qc_evaluation_spine_generic/comparison_across_models.sh @@ -77,7 +77,8 @@ label_vertebrae(){ FILELABEL="${file}_discs" # Label vertebral levels - sct_label_utils -i ${file}.nii.gz -disc ${FILELABEL}.nii.gz -o ${FILESEG}_labeled.nii.gz + # sct_label_utils -i ${file}.nii.gz -disc ${FILELABEL}.nii.gz -o ${FILESEG}_labeled.nii.gz + sct_label_utils -i ${FILESEG}.nii.gz -disc ${FILELABEL}.nii.gz -o 
${FILESEG}_labeled.nii.gz

     # # Run QC
     # sct_qc -i ${file}.nii.gz -s ${file_seg}_labeled.nii.gz -p sct_label_vertebrae -qc ${PATH_QC} -qc-subject ${SUBJECT}
@@ -245,22 +246,28 @@ segment_sc_MONAI(){
     local contrast="$4"   # used only for saving output file name
     local csv_fname="$5"  # used for saving output file name

-    if [[ $model == 'monai_single' ]]; then
+    if [[ $model == 'plain_320' ]]; then
         # FILESEG="${file%%_*}_${contrast}_seg_monai_orig"
         FILESEG="${file%%_*}_${contrast}_seg_${model}"
         PATH_MODEL=${PATH_MONAI_MODEL_1}
         model_name='monai'
+        max_feat=320
+        pad='edge'

-    elif [[ $model == 'monai_v23' ]]; then
+    elif [[ $model == 'plain_384' ]]; then
         # FILESEG="${file%%_*}_${contrast}_seg_monai_ll"
         FILESEG="${file%%_*}_${contrast}_seg_${model}"
         PATH_MODEL=${PATH_MONAI_MODEL_2}
         model_name='monai'
+        max_feat=384
+        pad='edge'

-    elif [[ $model == 'monai_v2x' ]]; then
+    elif [[ $model == 'resencM' ]]; then
         FILESEG="${file%%_*}_${contrast}_seg_${model}"
         PATH_MODEL=${PATH_MONAI_MODEL_3}
-        model_name='monai'
+        model_name='monai-resencM'
+        max_feat=384
+        pad='edge'

     elif [[ $model == 'swinunetr' ]]; then
         FILESEG="${file%%_*}_${contrast}_seg_swinunetr"
@@ -275,7 +282,7 @@ segment_sc_MONAI(){
     echo "Running inference from model at ${PATH_MODEL}"
     # NOTE: surprisingly, the `edge` padding is resulting in higher abs. csa error compared to `constant` (zero) padded inputs.
     # Run SC segmentation
-    python ${PATH_MONAI_SCRIPT} --path-img ${file}.nii.gz --path-out . --chkp-path ${PATH_MODEL} --device gpu --model ${model_name} --pred-type soft --pad-mode constant
+    python ${PATH_MONAI_SCRIPT} --path-img ${file}.nii.gz --path-out . --chkp-path ${PATH_MODEL} --device gpu --model ${model_name} --pred-type soft --pad-mode ${pad} --max-feat ${max_feat}
    # python ${PATH_MONAI_SCRIPT} --path-img ${file}.nii.gz --path-out . --chkp-path ${PATH_MODEL} --device gpu --model monai --pred-type soft --pad-mode constant
     # Rename MONAI output
     mv ${file}_pred.nii.gz ${FILESEG}.nii.gz
@@ -293,8 +300,8 @@ segment_sc_MONAI(){
     echo "${FILESEG},${slicewise_dice}"
     echo "${FILESEG},${slicewise_dice}" >> ${PATH_RESULTS}/slicewise_dice.csv

-    # # Generate QC report with soft prediction
-    # sct_qc -i ${file}.nii.gz -s ${FILESEG}.nii.gz -p sct_deepseg_sc -qc ${PATH_QC} -qc-subject ${SUBJECT}
+    # Generate QC report
+    sct_qc -i ${file}.nii.gz -s ${FILESEG}.nii.gz -p sct_deepseg_sc -qc ${PATH_QC} -qc-subject ${SUBJECT}

     # Compute CSA averaged across all slices C2-C3 vertebral levels for plotting the STD across contrasts
     # NOTE: this is per-level because not all contrasts have the same FoV (C2-C3 is what all contrasts have in common)
@@ -364,7 +371,7 @@ contrasts="space-other_T1w space-other_T2w space-other_T2star flip-1_mt-on_space
 # contrasts="space-other_T1w rec-average_dwi"

 # output csv filename
-csv_fname="csa_softIn_CL_deploy"  # "csa_label_inputs"
+csv_fname="csa_model_sizes"  # "csa_label_inputs"

 # Loop across contrasts
 for contrast in ${contrasts}; do
@@ -435,12 +442,12 @@ for contrast in ${contrasts}; do
     sct_process_segmentation -i ${FILEBIN}.nii.gz -perslice 1 -vertfile ${file}_seg-manual_labeled.nii.gz -o $PATH_RESULTS/${csv_fname}_softseg_bin_perslice.csv -append 1

     # 3.
Segment SC using different methods, binarize at 0.5 and compute CSA - CUDA_VISIBLE_DEVICES=1 segment_sc_MONAI ${file} "${file}_seg-manual" 'monai_v21' ${contrast} ${csv_fname} - CUDA_VISIBLE_DEVICES=1 segment_sc_MONAI ${file} "${file}_seg-manual" 'monai_v23' ${contrast} ${csv_fname} - CUDA_VISIBLE_DEVICES=2 segment_sc_MONAI ${file} "${file}_seg-manual" 'monai_v2x' ${contrast} ${csv_fname} + CUDA_VISIBLE_DEVICES=1 segment_sc_MONAI ${file} "${file}_seg-manual" 'plain_320' ${contrast} ${csv_fname} + CUDA_VISIBLE_DEVICES=2 segment_sc_MONAI ${file} "${file}_seg-manual" 'plain_384' ${contrast} ${csv_fname} + CUDA_VISIBLE_DEVICES=3 segment_sc_MONAI ${file} "${file}_seg-manual" 'resencM' ${contrast} ${csv_fname} # CUDA_VISIBLE_DEVICES=2 segment_sc_MONAI ${file} "${file}_seg-manual" 'swinunetr' ${contrast} ${csv_fname} # CUDA_VISIBLE_DEVICES=3 segment_sc_nnUNet ${file} "${file}_seg-manual" '3d_fullres' ${contrast} ${csv_fname} - # segment_sc ${file} "${file}_seg-manual" 'deepseg' ${deepseg_input_c} ${contrast} ${csv_fname} + segment_sc ${file} "${file}_seg-manual" 'deepseg' ${deepseg_input_c} ${contrast} ${csv_fname} # TODO: run on deep/progseg after fixing the contrasts for those # segment_sc ${file_res} 't2' 'propseg' '' "${file}_seg-manual" ${native_res} diff --git a/datasplits/datasplit_dcm-brno_seed50.yaml b/datasplits/datasplit_dcm-brno_seed50.yaml new file mode 100644 index 0000000..ca256f4 --- /dev/null +++ b/datasplits/datasplit_dcm-brno_seed50.yaml @@ -0,0 +1,156 @@ +test: +- sub-1860B6472B +- sub-2295B4676B +- sub-2336B +- sub-2347B6454B +- sub-2353B +- sub-2450B6177B +- sub-2694B +- sub-2779B4786B +- sub-2804B4632B +- sub-3066B +- sub-3121B +- sub-3148B +- sub-3206B +- sub-3261B4658B +- sub-3758B6378B +- sub-4403B5238B +train: +- sub-1838B6577B +- sub-2060B4733B +- sub-2247B6492B +- sub-2249B4725B +- sub-2255B6372B +- sub-2271B +- sub-2284B4723B +- sub-2287B +- sub-2296B4806B +- sub-2319B4634B +- sub-2320B6468B +- sub-2321B6243B +- sub-2322B +- sub-2330B6137B +- sub-2333B6135B +- sub-2334B4627B +- sub-2340B6562B +- sub-2346B +- sub-2348B4595B +- sub-2349B6345B +- sub-2358B4598B +- sub-2371B4687B +- sub-2383B6196B +- sub-2384B +- sub-2386B +- sub-2389B +- sub-2390B4949B +- sub-2391B6466B +- sub-2407B5757B +- sub-2411B4591B +- sub-2412B +- sub-2413B4629B +- sub-2414B +- sub-2416B6027B +- sub-2417B4965B +- sub-2419B +- sub-2420B +- sub-2421B +- sub-2446B6192B +- sub-2451B +- sub-2472B +- sub-2479B6195B +- sub-2481B6079B +- sub-2599B4623B +- sub-2600B4616B +- sub-2606B6390B +- sub-2608B +- sub-2610B6476B +- sub-2614B +- sub-2644B5869B +- sub-2648B +- sub-2649B +- sub-2654B4633B +- sub-2667B6188B +- sub-2683B +- sub-2698B +- sub-2699B +- sub-2702B6403B +- sub-2717B +- sub-2737B +- sub-2741B4963B +- sub-2742B +- sub-2750B4646B +- sub-2751B +- sub-2752B +- sub-2756B +- sub-2757B +- sub-2773B4661B +- sub-2774B4763B +- sub-2778B6482B +- sub-2825B4881B +- sub-2847B6453B +- sub-2873B +- sub-2884B4714B +- sub-2886B5253B +- sub-2887B4699B +- sub-2902B4626B +- sub-2988B6186B +- sub-2999B +- sub-3005B +- sub-3006B +- sub-3007B +- sub-3013B6405B +- sub-3024B +- sub-3056B6483B +- sub-3060B6384B +- sub-3075B6206B +- sub-3077B +- sub-3094B +- sub-3111B +- sub-3120B6558B +- sub-3132B4599B +- sub-3140B +- sub-3146B +- sub-3167B +- sub-3170B +- sub-3194B +- sub-3207B4664B +- sub-3209B +- sub-3214B +- sub-3215B +- sub-3230B +- sub-3231B +- sub-3249B +- sub-3262B +- sub-3281B4590B +- sub-3289B4615B +- sub-3427B4966B +- sub-3608B6555B +- sub-3636B +- sub-3640B +- sub-3646B +- sub-3651B +- sub-3782B6193B +- 
sub-3793B4681B +- sub-3998B6406B +- sub-4088B6310B +- sub-4196B5202B +- sub-4793B6025B +- sub-4891B6572B +- sub-4960B6571B +val: +- sub-1836B6029B +- sub-2259B4883B +- sub-2315B4686B +- sub-2316B5038B +- sub-2355B +- sub-2372B2863B +- sub-2418B4628B +- sub-2590B4647B +- sub-2613B +- sub-2723B4648B +- sub-3179B +- sub-3227B6487B +- sub-3232B5049B +- sub-3237B +- sub-3641B +- sub-3759B6530B diff --git a/datasplits/datasplit_dcm-zurich-lesions-20231115_seed50.yaml b/datasplits/datasplit_dcm-zurich-lesions-20231115_seed50.yaml new file mode 100644 index 0000000..f5cc1f5 --- /dev/null +++ b/datasplits/datasplit_dcm-zurich-lesions-20231115_seed50.yaml @@ -0,0 +1,45 @@ +test: +- sub-11 +- sub-12 +- sub-24 +- sub-40 +- sub-41 +train: +- sub-03 +- sub-04 +- sub-06 +- sub-07 +- sub-08 +- sub-09 +- sub-10 +- sub-13 +- sub-14 +- sub-15 +- sub-16 +- sub-17 +- sub-18 +- sub-19 +- sub-20 +- sub-21 +- sub-22 +- sub-23 +- sub-25 +- sub-26 +- sub-27 +- sub-28 +- sub-29 +- sub-30 +- sub-32 +- sub-33 +- sub-34 +- sub-35 +- sub-36 +- sub-37 +- sub-39 +- sub-42 +val: +- sub-01 +- sub-02 +- sub-05 +- sub-31 +- sub-38 diff --git a/datasplits/datasplit_dcm-zurich-lesions_seed50.yaml b/datasplits/datasplit_dcm-zurich-lesions_seed50.yaml new file mode 100644 index 0000000..0bf5ddb --- /dev/null +++ b/datasplits/datasplit_dcm-zurich-lesions_seed50.yaml @@ -0,0 +1,17 @@ +test: +- sub-09 +- sub-16 +train: +- sub-01 +- sub-02 +- sub-05 +- sub-06 +- sub-07 +- sub-11 +- sub-12 +- sub-13 +- sub-14 +- sub-15 +val: +- sub-08 +- sub-10 diff --git a/datasplits/datasplit_lumbar-epfl_seed50.yaml b/datasplits/datasplit_lumbar-epfl_seed50.yaml index cdd18e2..6c92f1e 100644 --- a/datasplits/datasplit_lumbar-epfl_seed50.yaml +++ b/datasplits/datasplit_lumbar-epfl_seed50.yaml @@ -1,14 +1,24 @@ test: -- sub-04 -- sub-09 +- sub-05 +- sub-11 +- sub-20 train: - sub-01 - sub-02 - sub-03 -- sub-05 -- sub-06 +- sub-04 +- sub-07 - sub-08 +- sub-09 - sub-10 -- sub-11 +- sub-12 +- sub-14 +- sub-15 +- sub-16 +- sub-17 +- sub-18 +- sub-19 +- sub-21 val: -- sub-07 +- sub-06 +- sub-13 diff --git a/datasplits/datasplit_lumbar-vanderbilt_seed50.yaml b/datasplits/datasplit_lumbar-vanderbilt_seed50.yaml index b67ed22..b74f621 100644 --- a/datasplits/datasplit_lumbar-vanderbilt_seed50.yaml +++ b/datasplits/datasplit_lumbar-vanderbilt_seed50.yaml @@ -1,37 +1,33 @@ test: - sub-140549 -- sub-140803 +- sub-242142 +- sub-242327 - sub-242655 -- sub-246835 -- sub-247581 +- sub-245558 - sub-247770 train: - sub-140488 - sub-140624 - sub-140653 +- sub-140803 - sub-141011 - sub-141314 - sub-141763 - sub-142487 - sub-241968 - sub-241981 -- sub-242142 - sub-242174 -- sub-242327 +- sub-242236 - sub-242436 - sub-242474 - sub-242549 - sub-242582 -- sub-242732 +- sub-242714 - sub-242986 - sub-243011 -- sub-243114 - sub-243417 - sub-243445 -- sub-243479 -- sub-243488 -- sub-243637 -- sub-245558 +- sub-243777 - sub-245609 - sub-245664 - sub-245756 @@ -39,17 +35,22 @@ train: - sub-245971 - sub-245980 - sub-245995 +- sub-246626 - sub-246638 - sub-246829 - sub-246830 - sub-246831 +- sub-246835 +- sub-247090 - sub-247195 - sub-247285 +- sub-247581 - sub-247694 +- sub-247981 val: - sub-242186 -- sub-242236 -- sub-242714 -- sub-243777 -- sub-246626 -- sub-247981 +- sub-242732 +- sub-243114 +- sub-243479 +- sub-243488 +- sub-243637 diff --git a/datasplits/datasplit_sct-testing-large_seed50.yaml b/datasplits/datasplit_sct-testing-large_seed50.yaml index 35d809f..a97bb4a 100644 --- a/datasplits/datasplit_sct-testing-large_seed50.yaml +++ 
b/datasplits/datasplit_sct-testing-large_seed50.yaml @@ -1,65 +1,67 @@ test: -- sub-bwh007 -- sub-bwh019 -- sub-bwh030 -- sub-bwh032 +- sub-bwh010 - sub-bwh040 -- sub-bwh045 -- sub-bwh046 -- sub-bwh054 -- sub-bwh057 -- sub-bwh066 +- sub-bwh060 +- sub-bwh065 +- sub-bwh080 +- sub-karoTobiasMS008 +- sub-karoTobiasMS010 - sub-karoTobiasMS013 - sub-karoTobiasMS034 -- sub-karoTobiasMS051 -- sub-mgh3Tconnectome004 -- sub-mghCaterina015 +- sub-karoTobiasMS035 +- sub-koreajisun002 - sub-mghCaterina021 - sub-mghCaterina024 -- sub-mghCaterina025 -- sub-milanFilippi004 -- sub-milanFilippi010 -- sub-milanFilippi018 -- sub-milanFilippi040 -- sub-milanFilippi050 -- sub-milanFilippi054 -- sub-milanFilippi068 -- sub-milanFilippi069 -- sub-milanFilippi071 +- sub-milanFilippi002 +- sub-milanFilippi009 +- sub-milanFilippi021 +- sub-milanFilippi046 +- sub-milanFilippi052 +- sub-milanFilippi062 +- sub-milanFilippi070 +- sub-milanFilippi076 - sub-milanFilippi082 -- sub-milanFilippi086 -- sub-milanFilippi092 -- sub-milanFilippi100 -- sub-milanFilippi108 +- sub-milanFilippi093 +- sub-milanFilippi097 +- sub-milanFilippi106 +- sub-montpellierLesion002 +- sub-montpellierLesion003 +- sub-montpellierLesion007 - sub-nihReich035 -- sub-parisPradat047 -- sub-rennesMS004 -- sub-rennesMS035 +- sub-parisPradat044 +- sub-rennesMS009 +- sub-rennesMS016 +- sub-rennesMS036 +- sub-rennesMS050 - sub-rennesMS056 -- sub-rennesMS061 -- sub-rennesMS065 -- sub-rennesMS069 -- sub-twh034 -- sub-twh065 -- sub-twh074 -- sub-twh075 -- sub-uclCiccarelli017 -- sub-uclCiccarelli018 -- sub-uclCiccarelli023 +- sub-twh008 +- sub-twh014 +- sub-twh045 +- sub-twh063 +- sub-twh066 +- sub-twh072 +- sub-uclCiccarelli005 +- sub-uclCiccarelli009 +- sub-uclCiccarelli022 - sub-uclCiccarelli024 -- sub-uclCiccarelli030 -- sub-uclCiccarelli031 +- sub-uclCiccarelli025 +- sub-uclCiccarelli026 +- sub-uclCiccarelli035 +- sub-ucsfTalbott001 +- sub-ucsfTalbott002 - sub-ucsfTalbott003 -- sub-ucsfTalbott005 - sub-ucsfTalbott009 -- sub-ucsfTalbott025 -- sub-ucsfTalbott029 -- sub-ucsfTalbott032 -- sub-xuanwuYaou015 -- sub-xuanwuYaou034 -- sub-xuanwuYaou040 +- sub-ucsfTalbott016 +- sub-unfErssm006 +- sub-vanderbiltSeth002 +- sub-vanderbiltSeth005 +- sub-vanderbiltSeth008 +- sub-vanderbiltSeth021 +- sub-xuanwuYaou037 +- sub-xuanwuYaou042 train: - sub-amuVirginie001 +- sub-amuVirginie002 - sub-amuVirginie003 - sub-amuVirginie004 - sub-amuVirginie005 @@ -72,94 +74,110 @@ train: - sub-amuVirginie012 - sub-amuVirginie013 - sub-amuVirginie014 +- sub-amuVirginie016 - sub-amuVirginie017 - sub-amuVirginie018 -- sub-amuVirginie019 -- sub-amuVirginie020 +- sub-bwh001 - sub-bwh002 - sub-bwh003 - sub-bwh004 +- sub-bwh005 - sub-bwh006 +- sub-bwh007 - sub-bwh008 - sub-bwh009 -- sub-bwh010 - sub-bwh011 - sub-bwh012 - sub-bwh013 - sub-bwh014 - sub-bwh015 - sub-bwh016 +- sub-bwh017 - sub-bwh018 +- sub-bwh019 +- sub-bwh020 - sub-bwh021 - sub-bwh022 - sub-bwh023 - sub-bwh024 - sub-bwh025 - sub-bwh026 +- sub-bwh027 - sub-bwh028 -- sub-bwh029 +- sub-bwh030 - sub-bwh031 +- sub-bwh032 - sub-bwh033 - sub-bwh034 - sub-bwh035 - sub-bwh036 - sub-bwh037 -- sub-bwh038 - sub-bwh039 +- sub-bwh041 - sub-bwh042 - sub-bwh043 - sub-bwh044 +- sub-bwh045 +- sub-bwh046 - sub-bwh047 - sub-bwh048 +- sub-bwh049 - sub-bwh050 - sub-bwh051 - sub-bwh052 +- sub-bwh053 +- sub-bwh054 - sub-bwh055 - sub-bwh056 +- sub-bwh057 - sub-bwh058 - sub-bwh059 - sub-bwh061 +- sub-bwh062 - sub-bwh063 -- sub-bwh065 +- sub-bwh064 +- sub-bwh066 - sub-bwh067 - sub-bwh068 -- sub-bwh069 - sub-bwh070 +- sub-bwh071 +- sub-bwh072 - 
sub-bwh073 - sub-bwh074 +- sub-bwh076 - sub-bwh077 - sub-bwh078 - sub-bwh079 -- sub-bwh080 - sub-karoTobiasMS001 +- sub-karoTobiasMS002 - sub-karoTobiasMS003 - sub-karoTobiasMS004 - sub-karoTobiasMS005 - sub-karoTobiasMS006 - sub-karoTobiasMS007 -- sub-karoTobiasMS008 - sub-karoTobiasMS009 -- sub-karoTobiasMS010 - sub-karoTobiasMS011 - sub-karoTobiasMS012 - sub-karoTobiasMS014 - sub-karoTobiasMS015 +- sub-karoTobiasMS016 - sub-karoTobiasMS017 - sub-karoTobiasMS018 - sub-karoTobiasMS019 - sub-karoTobiasMS020 - sub-karoTobiasMS021 - sub-karoTobiasMS022 +- sub-karoTobiasMS023 - sub-karoTobiasMS024 - sub-karoTobiasMS025 - sub-karoTobiasMS026 - sub-karoTobiasMS027 +- sub-karoTobiasMS028 - sub-karoTobiasMS029 - sub-karoTobiasMS030 - sub-karoTobiasMS031 - sub-karoTobiasMS032 - sub-karoTobiasMS033 -- sub-karoTobiasMS035 - sub-karoTobiasMS036 - sub-karoTobiasMS037 - sub-karoTobiasMS038 @@ -171,46 +189,45 @@ train: - sub-karoTobiasMS044 - sub-karoTobiasMS045 - sub-karoTobiasMS046 -- sub-karoTobiasMS047 -- sub-karoTobiasMS048 - sub-karoTobiasMS049 - sub-karoTobiasMS050 -- sub-karoTobiasMS052 +- sub-karoTobiasMS051 - sub-karoTobiasMS053 - sub-koreajisun001 -- sub-koreajisun002 - sub-koreajisun003 - sub-koreajisun004 - sub-koreajisun005 -- sub-koreajisun006 - sub-koreajisun007 +- sub-koreajisun008 - sub-koreajisun009 - sub-koreajisun010 - sub-mgh3Tconnectome001 +- sub-mgh3Tconnectome002 - sub-mgh3Tconnectome003 +- sub-mgh3Tconnectome004 - sub-mgh3Tconnectome005 - sub-mghCaterina012 -- sub-mghCaterina013 - sub-mghCaterina014 +- sub-mghCaterina015 - sub-mghCaterina016 - sub-mghCaterina017 -- sub-mghCaterina018 - sub-mghCaterina019 - sub-mghCaterina020 - sub-mghCaterina022 - sub-mghCaterina023 +- sub-mghCaterina025 - sub-mghCaterina026 - sub-mghCaterina027 - sub-mghCaterina028 - sub-mghCaterina029 - sub-milanFilippi001 -- sub-milanFilippi002 - sub-milanFilippi003 +- sub-milanFilippi004 - sub-milanFilippi005 - sub-milanFilippi006 - sub-milanFilippi007 - sub-milanFilippi008 -- sub-milanFilippi009 +- sub-milanFilippi010 - sub-milanFilippi011 - sub-milanFilippi012 - sub-milanFilippi013 @@ -218,8 +235,9 @@ train: - sub-milanFilippi015 - sub-milanFilippi016 - sub-milanFilippi017 +- sub-milanFilippi018 +- sub-milanFilippi019 - sub-milanFilippi020 -- sub-milanFilippi021 - sub-milanFilippi022 - sub-milanFilippi023 - sub-milanFilippi024 @@ -232,25 +250,19 @@ train: - sub-milanFilippi031 - sub-milanFilippi032 - sub-milanFilippi033 -- sub-milanFilippi034 - sub-milanFilippi035 - sub-milanFilippi036 -- sub-milanFilippi037 +- sub-milanFilippi038 - sub-milanFilippi039 -- sub-milanFilippi041 +- sub-milanFilippi040 - sub-milanFilippi042 -- sub-milanFilippi044 +- sub-milanFilippi043 - sub-milanFilippi045 -- sub-milanFilippi046 -- sub-milanFilippi047 - sub-milanFilippi048 - sub-milanFilippi049 -- sub-milanFilippi051 -- sub-milanFilippi052 - sub-milanFilippi053 +- sub-milanFilippi054 - sub-milanFilippi055 -- sub-milanFilippi056 -- sub-milanFilippi057 - sub-milanFilippi058 - sub-milanFilippi059 - sub-milanFilippi060 @@ -259,70 +271,63 @@ train: - sub-milanFilippi064 - sub-milanFilippi065 - sub-milanFilippi066 -- sub-milanFilippi070 -- sub-milanFilippi072 +- sub-milanFilippi067 +- sub-milanFilippi068 +- sub-milanFilippi069 +- sub-milanFilippi071 - sub-milanFilippi073 - sub-milanFilippi074 -- sub-milanFilippi075 -- sub-milanFilippi076 - sub-milanFilippi077 - sub-milanFilippi078 - sub-milanFilippi079 +- sub-milanFilippi080 - sub-milanFilippi081 - sub-milanFilippi083 - sub-milanFilippi084 - sub-milanFilippi085 +- 
sub-milanFilippi086 - sub-milanFilippi087 - sub-milanFilippi088 - sub-milanFilippi089 -- sub-milanFilippi090 - sub-milanFilippi091 +- sub-milanFilippi092 - sub-milanFilippi094 - sub-milanFilippi095 - sub-milanFilippi096 -- sub-milanFilippi097 - sub-milanFilippi098 - sub-milanFilippi099 +- sub-milanFilippi100 - sub-milanFilippi101 +- sub-milanFilippi102 - sub-milanFilippi104 -- sub-milanFilippi105 -- sub-milanFilippi106 - sub-milanFilippi107 - sub-milanFilippi109 -- sub-milanFilippi110 - sub-milanFilippi111 -- sub-milanFilippi112 - sub-milanFilippi113 - sub-milanFilippi114 -- sub-milanFilippi115 - sub-milanFilippi116 - sub-milanFilippi117 - sub-montpellierLesion001 -- sub-montpellierLesion003 -- sub-montpellierLesion004 - sub-montpellierLesion005 - sub-montpellierLesion006 -- sub-montpellierLesion007 - sub-montpellierLesion008 - sub-montpellierLesion009 -- sub-montpellierLesion010 -- sub-montpellierLesion011 -- sub-montpellierLesion012 +- sub-montpellierLesion013 - sub-montpellierLesion014 +- sub-nihReich031 +- sub-nihReich032 - sub-nihReich033 - sub-nihReich034 - sub-parisPradat003 +- sub-parisPradat004 - sub-parisPradat010 -- sub-parisPradat011 -- sub-parisPradat044 - sub-parisPradat045 -- sub-parisPradat046 +- sub-parisPradat047 - sub-parisPradat048 - sub-parisPradat049 - sub-parisPradat050 - sub-parisPradat053 - sub-parisPradat057 -- sub-parisPradat058 - sub-parisPradat060 - sub-parisPradat061 - sub-parisPradat062 @@ -332,56 +337,55 @@ train: - sub-parisPradat066 - sub-parisPradat067 - sub-parisPradat068 +- sub-parisPradat069 - sub-parisPradat070 - sub-rennesElise001 - sub-rennesMS001 - sub-rennesMS002 - sub-rennesMS003 - sub-rennesMS005 -- sub-rennesMS008 -- sub-rennesMS009 -- sub-rennesMS013 - sub-rennesMS015 -- sub-rennesMS016 - sub-rennesMS017 +- sub-rennesMS018 +- sub-rennesMS019 - sub-rennesMS020 - sub-rennesMS021 - sub-rennesMS023 +- sub-rennesMS025 - sub-rennesMS026 - sub-rennesMS027 -- sub-rennesMS028 +- sub-rennesMS029 - sub-rennesMS031 -- sub-rennesMS032 - sub-rennesMS033 - sub-rennesMS034 -- sub-rennesMS036 +- sub-rennesMS035 - sub-rennesMS037 - sub-rennesMS038 -- sub-rennesMS039 +- sub-rennesMS041 - sub-rennesMS043 - sub-rennesMS044 - sub-rennesMS046 - sub-rennesMS047 - sub-rennesMS048 - sub-rennesMS049 -- sub-rennesMS050 - sub-rennesMS053 - sub-rennesMS055 - sub-rennesMS057 - sub-rennesMS058 - sub-rennesMS059 - sub-rennesMS060 +- sub-rennesMS061 - sub-rennesMS062 - sub-rennesMS063 - sub-rennesMS064 +- sub-rennesMS065 +- sub-rennesMS066 - sub-rennesMS067 -- sub-rennesMS068 +- sub-rennesMS069 - sub-rennesMS070 - sub-twh006 - sub-twh007 -- sub-twh008 - sub-twh013 -- sub-twh014 - sub-twh015 - sub-twh016 - sub-twh018 @@ -394,76 +398,73 @@ train: - sub-twh029 - sub-twh031 - sub-twh032 +- sub-twh034 - sub-twh036 - sub-twh039 - sub-twh041 - sub-twh043 - sub-twh044 -- sub-twh045 - sub-twh047 - sub-twh049 - sub-twh055 - sub-twh057 - sub-twh059 -- sub-twh063 -- sub-twh066 +- sub-twh065 - sub-twh067 - sub-twh068 - sub-twh069 - sub-twh070 - sub-twh071 -- sub-twh072 - sub-twh073 +- sub-twh074 - sub-twh076 - sub-twh077 - sub-twh079 - sub-twh080 -- sub-twh081 - sub-twh082 - sub-twh083 - sub-twh084 - sub-twh085 - sub-twh086 -- sub-twh087 +- sub-uclCiccarelli001 +- sub-uclCiccarelli002 - sub-uclCiccarelli003 - sub-uclCiccarelli004 -- sub-uclCiccarelli005 - sub-uclCiccarelli006 - sub-uclCiccarelli007 - sub-uclCiccarelli008 -- sub-uclCiccarelli009 +- sub-uclCiccarelli010 - sub-uclCiccarelli011 - sub-uclCiccarelli012 - sub-uclCiccarelli013 - sub-uclCiccarelli014 - sub-uclCiccarelli015 - 
sub-uclCiccarelli016 +- sub-uclCiccarelli017 - sub-uclCiccarelli019 -- sub-uclCiccarelli020 - sub-uclCiccarelli021 -- sub-uclCiccarelli022 -- sub-uclCiccarelli025 -- sub-uclCiccarelli026 +- sub-uclCiccarelli023 +- sub-uclCiccarelli027 - sub-uclCiccarelli028 +- sub-uclCiccarelli029 +- sub-uclCiccarelli030 - sub-uclCiccarelli032 - sub-uclCiccarelli033 - sub-uclCiccarelli034 -- sub-uclCiccarelli035 -- sub-uclCiccarelli036 - sub-uclCiccarelli037 - sub-uclCiccarelli038 - sub-uclCiccarelli039 -- sub-ucsfTalbott001 -- sub-ucsfTalbott002 - sub-ucsfTalbott004 +- sub-ucsfTalbott005 - sub-ucsfTalbott006 - sub-ucsfTalbott007 - sub-ucsfTalbott008 - sub-ucsfTalbott010 +- sub-ucsfTalbott011 - sub-ucsfTalbott012 +- sub-ucsfTalbott013 - sub-ucsfTalbott014 - sub-ucsfTalbott015 -- sub-ucsfTalbott016 - sub-ucsfTalbott017 - sub-ucsfTalbott018 - sub-ucsfTalbott019 @@ -471,33 +472,37 @@ train: - sub-ucsfTalbott021 - sub-ucsfTalbott022 - sub-ucsfTalbott023 +- sub-ucsfTalbott024 +- sub-ucsfTalbott025 - sub-ucsfTalbott026 - sub-ucsfTalbott027 - sub-ucsfTalbott028 -- sub-ucsfTalbott031 +- sub-ucsfTalbott029 +- sub-ucsfTalbott030 +- sub-ucsfTalbott032 +- sub-unfErssm007 +- sub-unfErssm014 +- sub-unfErssm018 +- sub-unfErssm025 +- sub-unfbiospective009 - sub-unfbiospective010 - sub-unfbiospective011 - sub-unfbiospective012 - sub-unfbiospective013 - sub-vanderbiltSeth001 -- sub-vanderbiltSeth002 - sub-vanderbiltSeth003 - sub-vanderbiltSeth004 -- sub-vanderbiltSeth005 - sub-vanderbiltSeth006 - sub-vanderbiltSeth007 -- sub-vanderbiltSeth008 - sub-vanderbiltSeth009 - sub-vanderbiltSeth010 -- sub-vanderbiltSeth011 - sub-vanderbiltSeth012 - sub-vanderbiltSeth013 -- sub-vanderbiltSeth014 - sub-vanderbiltSeth015 - sub-vanderbiltSeth016 -- sub-vanderbiltSeth017 - sub-vanderbiltSeth018 - sub-vanderbiltSeth019 +- sub-vanderbiltSeth020 - sub-vanderbiltSeth022 - sub-vanderbiltSeth023 - sub-xuanwuChenxi001 @@ -512,81 +517,83 @@ train: - sub-xuanwuYaou010 - sub-xuanwuYaou011 - sub-xuanwuYaou012 +- sub-xuanwuYaou013 - sub-xuanwuYaou014 +- sub-xuanwuYaou015 - sub-xuanwuYaou016 -- sub-xuanwuYaou033 +- sub-xuanwuYaou034 - sub-xuanwuYaou035 -- sub-xuanwuYaou036 -- sub-xuanwuYaou037 - sub-xuanwuYaou038 - sub-xuanwuYaou039 +- sub-xuanwuYaou040 - sub-xuanwuYaou041 -- sub-xuanwuYaou042 - sub-xuanwuYaou043 - sub-xuanwuYaou044 - sub-xuanwuYaou045 - sub-xuanwuYaou046 - sub-xuanwuYaou047 +- sub-xuanwuYaou048 - sub-xuanwuYaou049 - sub-xuanwuYaou050 - sub-xuanwuYaou051 val: -- sub-amuVirginie016 -- sub-bwh001 -- sub-bwh005 -- sub-bwh017 -- sub-bwh020 -- sub-bwh027 -- sub-bwh041 -- sub-bwh049 -- sub-bwh053 -- sub-bwh060 -- sub-bwh062 -- sub-bwh064 -- sub-bwh071 -- sub-bwh072 +- sub-amuVirginie019 +- sub-amuVirginie020 +- sub-bwh029 +- sub-bwh038 +- sub-bwh069 - sub-bwh075 -- sub-bwh076 -- sub-karoTobiasMS002 -- sub-karoTobiasMS016 -- sub-karoTobiasMS023 -- sub-karoTobiasMS028 -- sub-koreajisun008 -- sub-mgh3Tconnectome002 -- sub-milanFilippi019 -- sub-milanFilippi038 -- sub-milanFilippi043 -- sub-milanFilippi062 -- sub-milanFilippi067 -- sub-milanFilippi080 -- sub-milanFilippi093 -- sub-milanFilippi102 +- sub-karoTobiasMS047 +- sub-karoTobiasMS048 +- sub-karoTobiasMS052 +- sub-koreajisun006 +- sub-mghCaterina013 +- sub-mghCaterina018 +- sub-milanFilippi034 +- sub-milanFilippi037 +- sub-milanFilippi041 +- sub-milanFilippi044 +- sub-milanFilippi047 +- sub-milanFilippi050 +- sub-milanFilippi051 +- sub-milanFilippi056 +- sub-milanFilippi057 +- sub-milanFilippi072 +- sub-milanFilippi075 +- sub-milanFilippi090 - sub-milanFilippi103 -- 
sub-montpellierLesion002 -- sub-montpellierLesion013 -- sub-nihReich031 -- sub-nihReich032 -- sub-parisPradat004 +- sub-milanFilippi105 +- sub-milanFilippi108 +- sub-milanFilippi110 +- sub-milanFilippi112 +- sub-milanFilippi115 +- sub-montpellierLesion004 +- sub-montpellierLesion010 +- sub-montpellierLesion011 +- sub-montpellierLesion012 - sub-parisPradat009 +- sub-parisPradat011 +- sub-parisPradat046 - sub-parisPradat051 -- sub-parisPradat069 -- sub-rennesMS018 -- sub-rennesMS019 -- sub-rennesMS025 -- sub-rennesMS029 -- sub-rennesMS041 -- sub-rennesMS066 -- sub-uclCiccarelli001 -- sub-uclCiccarelli002 -- sub-uclCiccarelli010 -- sub-uclCiccarelli027 -- sub-uclCiccarelli029 -- sub-ucsfTalbott011 -- sub-ucsfTalbott013 -- sub-ucsfTalbott024 -- sub-ucsfTalbott030 -- sub-unfbiospective009 -- sub-vanderbiltSeth020 -- sub-vanderbiltSeth021 -- sub-xuanwuYaou013 -- sub-xuanwuYaou048 +- sub-parisPradat058 +- sub-rennesMS004 +- sub-rennesMS008 +- sub-rennesMS013 +- sub-rennesMS028 +- sub-rennesMS032 +- sub-rennesMS039 +- sub-rennesMS068 +- sub-twh075 +- sub-twh081 +- sub-twh087 +- sub-uclCiccarelli018 +- sub-uclCiccarelli020 +- sub-uclCiccarelli031 +- sub-uclCiccarelli036 +- sub-ucsfTalbott031 +- sub-unfPain005 +- sub-vanderbiltSeth011 +- sub-vanderbiltSeth014 +- sub-vanderbiltSeth017 +- sub-xuanwuYaou033 +- sub-xuanwuYaou036 diff --git a/datasplits/datasplit_spider-challenge-2023_seed50.yaml b/datasplits/datasplit_spider-challenge-2023_seed50.yaml new file mode 100644 index 0000000..ea4159b --- /dev/null +++ b/datasplits/datasplit_spider-challenge-2023_seed50.yaml @@ -0,0 +1,221 @@ +test: +- sub-005 +- sub-006 +- sub-028 +- sub-035 +- sub-038 +- sub-039 +- sub-051 +- sub-065 +- sub-081 +- sub-090 +- sub-109 +- sub-134 +- sub-136 +- sub-137 +- sub-144 +- sub-146 +- sub-155 +- sub-166 +- sub-174 +- sub-213 +- sub-215 +- sub-243 +train: +- sub-001 +- sub-002 +- sub-003 +- sub-004 +- sub-007 +- sub-008 +- sub-009 +- sub-011 +- sub-012 +- sub-013 +- sub-015 +- sub-017 +- sub-018 +- sub-019 +- sub-020 +- sub-021 +- sub-022 +- sub-023 +- sub-024 +- sub-025 +- sub-029 +- sub-030 +- sub-031 +- sub-032 +- sub-033 +- sub-034 +- sub-036 +- sub-040 +- sub-041 +- sub-042 +- sub-044 +- sub-045 +- sub-048 +- sub-050 +- sub-055 +- sub-056 +- sub-057 +- sub-058 +- sub-059 +- sub-060 +- sub-061 +- sub-062 +- sub-063 +- sub-064 +- sub-066 +- sub-067 +- sub-068 +- sub-069 +- sub-071 +- sub-072 +- sub-073 +- sub-074 +- sub-075 +- sub-077 +- sub-078 +- sub-080 +- sub-082 +- sub-083 +- sub-085 +- sub-086 +- sub-088 +- sub-089 +- sub-091 +- sub-093 +- sub-094 +- sub-095 +- sub-096 +- sub-097 +- sub-098 +- sub-099 +- sub-100 +- sub-101 +- sub-104 +- sub-105 +- sub-106 +- sub-108 +- sub-110 +- sub-112 +- sub-115 +- sub-116 +- sub-117 +- sub-118 +- sub-120 +- sub-121 +- sub-122 +- sub-123 +- sub-124 +- sub-126 +- sub-127 +- sub-129 +- sub-130 +- sub-131 +- sub-132 +- sub-133 +- sub-138 +- sub-140 +- sub-141 +- sub-142 +- sub-143 +- sub-145 +- sub-147 +- sub-149 +- sub-151 +- sub-152 +- sub-154 +- sub-156 +- sub-159 +- sub-160 +- sub-162 +- sub-163 +- sub-165 +- sub-167 +- sub-168 +- sub-169 +- sub-170 +- sub-171 +- sub-172 +- sub-173 +- sub-175 +- sub-179 +- sub-180 +- sub-181 +- sub-182 +- sub-183 +- sub-184 +- sub-185 +- sub-186 +- sub-188 +- sub-189 +- sub-190 +- sub-192 +- sub-193 +- sub-195 +- sub-196 +- sub-198 +- sub-200 +- sub-201 +- sub-202 +- sub-204 +- sub-205 +- sub-208 +- sub-209 +- sub-210 +- sub-212 +- sub-214 +- sub-217 +- sub-218 +- sub-219 +- sub-220 +- sub-221 +- sub-222 +- sub-224 +- sub-225 +- sub-226 +- 
sub-227 +- sub-228 +- sub-229 +- sub-231 +- sub-232 +- sub-233 +- sub-236 +- sub-239 +- sub-242 +- sub-244 +- sub-245 +- sub-246 +- sub-249 +- sub-250 +- sub-251 +- sub-252 +- sub-254 +- sub-255 +- sub-256 +- sub-257 +val: +- sub-010 +- sub-016 +- sub-037 +- sub-047 +- sub-052 +- sub-053 +- sub-087 +- sub-107 +- sub-113 +- sub-125 +- sub-161 +- sub-177 +- sub-187 +- sub-191 +- sub-197 +- sub-203 +- sub-207 +- sub-223 +- sub-234 +- sub-237 +- sub-241 +- sub-253 diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py index 2de2e8b..56309df 100644 --- a/monai/create_msd_data.py +++ b/monai/create_msd_data.py @@ -12,7 +12,7 @@ import subprocess from datetime import datetime -from utils import get_git_branch_and_commit +from utils import get_git_branch_and_commit, get_image_stats import pandas as pd pd.set_option('display.max_colwidth', None) @@ -27,7 +27,10 @@ "basel-mp2rage": ["labels_softseg_bin", "desc-softseg_label-SC_seg"], "canproco": ["labels", "seg-manual"], "data-multi-subject": ["labels_softseg_bin", "desc-softseg_label-SC_seg"], + "dcm-brno": ["labels", "seg"], "dcm-zurich": ["labels", "label-SC_mask-manual"], + "dcm-zurich-lesions": ["labels", "label-SC_mask-manual"], + "dcm-zurich-lesions-20231115": ["labels", "label-SC_mask-manual"], "lumbar-epfl": ["labels", "seg-manual"], "lumbar-vanderbilt": ["labels", "label-SC_seg"], "nih-ms-mp2rage": ["labels", "label-SC_seg"], @@ -35,10 +38,11 @@ "sci-paris": ["labels", "seg-manual"], "sci-zurich": ["labels", "seg-manual"], "sct-testing-large": ["labels", "seg-manual"], + "spider-challenge-2023": ["labels", "label-SC_seg"] } # add abbreviations of pathologies in sct-testing-large and other datasets to be included in the aggregated dataset -PATHOLOGIES = ["ALS", "DCM", "NMO", "MS", "SCI"] +PATHOLOGIES = ["ALS", "DCM", "NMO", "MS", "SYR", "SCI", "LBP"] def get_parser(): @@ -106,7 +110,7 @@ def fetch_subject_nifti_details(filename_path): else: # TODO: add more contrasts as needed # contrast_pattern = r'.*_(T1w|T2w|T2star|PSIR|STIR|UNIT1|acq-MTon_MTR|acq-dwiMean_dwi|acq-b0Mean_dwi|acq-T1w_MTR).*' - contrast_pattern = r'.*_(T1w|T2w|T2star|PSIR|STIR|UNIT1|T1map|inv-1_part-mag_MP2RAGE|inv-2_part-mag_MP2RAGE|acq-MTon_MTR|acq-dwiMean_dwi|acq-T1w_MTR).*' + contrast_pattern = r'.*_(T1w|acq-lowresSag_T1w|T2w|acq-lowresSag_T2w|acq-highresSag_T2w|T2star|PSIR|STIR|UNIT1|T1map|inv-1_part-mag_MP2RAGE|inv-2_part-mag_MP2RAGE|acq-MTon_MTR|acq-dwiMean_dwi|acq-T1w_MTR).*' contrast = re.search(contrast_pattern, filename_path) contrastID = contrast.group(1) if contrast else "" @@ -171,9 +175,6 @@ def create_df(dataset_path): # get subjectID, sessionID and orientationID df['subjectID'], df['sessionID'], df['orientationID'], df['contrastID'] = zip(*df['filename'].map(fetch_subject_nifti_details)) - # sub_files = [ df[df['subjectID'] == 'sub-sherbrookeBiospective006']['filename'].values[idx] for idx in range(len(df[df['subjectID'] == 'sub-sherbrookeBiospective006']))] - # print(len(sub_files)) - if dataset_name == 'basel-mp2rage': # set the type of pathologyID as str @@ -214,11 +215,15 @@ def create_df(dataset_path): # load the participants.tsv file df_participants = pd.read_csv(os.path.join(dataset_path, 'participants.tsv'), sep='\t') - # store the pathology info by merging the "pathology" colume from df_participants to the df dataframe - df = pd.merge(df, df_participants[['participant_id', 'pathology']], left_on='subjectID', right_on='participant_id', how='left') + # NOTE: participant_id are like sub-NIH001, sub-NIH002, etc. 
+    #       but the subjectIDs are like sub-nih001, sub-nih002, etc.
+    # convert participant_id to lower case
+    df_participants['participant_id'] = df_participants['participant_id'].str.lower()

-        # rename the column to 'pathologyID'
-        df.rename(columns={'pathology': 'pathologyID'}, inplace=True)
+        # store the phenotype info by merging the "phenotype" column from df_participants to the df dataframe
+        df = pd.merge(df, df_participants[['participant_id', 'phenotype']], left_on='subjectID', right_on='participant_id', how='left')
+
+        # rename phenotype to pathologyID
+        df.rename(columns={'phenotype': 'pathologyID'}, inplace=True)

     elif dataset_name == 'sct-testing-large':

@@ -267,11 +272,15 @@ def create_df(dataset_path):
         # load the participants.tsv file
         df_participants = pd.read_csv(os.path.join(dataset_path, 'participants.tsv'), sep='\t')

-        # store the pathology info by merging the "pathology_M0" colume from df_participants to the df dataframe
-        df = pd.merge(df, df_participants[['participant_id', 'pathology_M0']], left_on='subjectID', right_on='participant_id', how='left')
+        # NOTE: taking the phenotype directly and using it as pathology because pathology is MS for all phenotypes
+        # (easier to report different phenotypes)
+        # store the pathology info by merging the "phenotype_M0" column from df_participants to the df dataframe
+        df = pd.merge(df, df_participants[['participant_id', 'phenotype_M0']], left_on='subjectID', right_on='participant_id', how='left')
+        # replace nan with HC
+        df['phenotype_M0'].fillna('HC', inplace=True)

         # rename the column to 'pathologyID'
-        df.rename(columns={'pathology_M0': 'pathologyID'}, inplace=True)
+        df.rename(columns={'phenotype_M0': 'pathologyID'}, inplace=True)

         for file in df['filename']:
@@ -289,6 +298,9 @@ def create_df(dataset_path):
             except subprocess.CalledProcessError as e:
                 logger.error(f"Error in downloading {file} from git-annex: {e}")

+    elif dataset_name == 'spider-challenge-2023':
+        df['pathologyID'] = 'LBP'
+
     else:
         # load the participants.tsv file
         df_participants = pd.read_csv(os.path.join(dataset_path, 'participants.tsv'), sep='\t')
@@ -299,12 +311,24 @@ def create_df(dataset_path):
             # rename the column to 'pathologyID'
             df.rename(columns={'pathology': 'pathologyID'}, inplace=True)
+
+        elif 'sci' in dataset_name:
+            # sci-zurich and sci-colorado do not have a 'pathology' column in their participants.tsv file
+            df['pathologyID'] = 'SCI'
+
+        elif dataset_name == 'lumbar-epfl':
+            # lumbar-epfl does not have a 'pathology' column in their participants.tsv file
+            df['pathologyID'] = 'HC'

-        else:
+        else:
             df['pathologyID'] = 'n/a'
-
+
+    # get image stats
+    df['shape'], df['imgOrientation'], df['spacing'] = zip(*df['filename'].map(get_image_stats))
+
     # refactor to move filename and filesegname to the end of the dataframe
-    df = df[['datasetName', 'subjectID', 'sessionID', 'orientationID', 'contrastID', 'pathologyID', 'filename']] #, 'filesegname']]
+    df = df[['datasetName', 'subjectID', 'sessionID', 'orientationID', 'contrastID', 'pathologyID',
+             'shape', 'imgOrientation', 'spacing', 'filename']]

     return df
@@ -483,7 +507,10 @@ def main():
             df['filename'] = df['filename'].replace(file, os.path.basename(file))

     # reorder the columns
-    df = df[['datasetName', 'subjectID', 'sessionID', 'orientationID', 'contrastID', 'pathologyID', 'split', 'filename']] #, 'filesegname']]
+    df = df[['datasetName', 'subjectID', 'sessionID', 'orientationID', 'contrastID', 'pathologyID', 'shape', 'imgOrientation', 'spacing', 'split', 'filename']]
+    # sort the dataframe based on subjectID
+    df = df.sort_values(by=['subjectID'], ascending=True)
+
     # save the dataframe to a csv file
     df.to_csv(os.path.join(args.path_out, f"df_{dataset_name}_seed{args.seed}.csv"), index=False)

     final_json = json.dumps(params, indent=4, sort_keys=True)
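Note: `get_image_stats` is imported from `utils` above, but its implementation is not part of this diff. A minimal sketch of what such a helper might compute for the new `shape`, `imgOrientation`, and `spacing` columns, assuming a nibabel-based implementation (the function body below is hypothetical):

```python
import nibabel as nib

def get_image_stats(image_path):
    # Hypothetical sketch of utils.get_image_stats (not shown in this diff):
    # returns the per-image stats unpacked into the shape/imgOrientation/spacing columns.
    img = nib.load(image_path)
    shape = tuple(img.header.get_data_shape())                       # e.g. (320, 320, 16)
    orientation = "".join(nib.orientations.aff2axcodes(img.affine))  # e.g. "RAS" (nibabel convention)
    spacing = tuple(float(z) for z in img.header.get_zooms())        # voxel size in mm
    return shape, orientation, spacing
```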
diff --git a/monai/image.py b/monai/image.py
new file mode 100644
index 0000000..03e670c
--- /dev/null
+++ b/monai/image.py
@@ -0,0 +1,685 @@
+import os
+import numpy as np
+import nibabel as nib
+import logging
+from copy import deepcopy
+
+logger = logging.getLogger(__name__)
+
+class Image(object):
+    """
+    Compact version of SCT's Image Class (https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/image.py#L245)
+    Create an object that behaves similarly to nibabel's image object. Useful additions include: dim, change_orientation and getNonZeroCoordinates.
+    """
+
+    def __init__(self, param=None, hdr=None, orientation=None, absolutepath=None, dim=None):
+        """
+        :param param: string indicating a path to an image file or an `Image` object.
+        """
+
+        # initialization of all parameters
+        self.affine = None
+        self.data = None
+        self._path = None
+        self.ext = ""
+
+        if absolutepath is not None:
+            self._path = os.path.abspath(absolutepath)
+
+        # Case 1: load an image from file
+        if isinstance(param, str):
+            self.loadFromPath(param)
+        # Case 2: create a copy of an existing `Image` object
+        elif isinstance(param, type(self)):
+            self.copy(param)
+        # Case 3: create a blank image from a list of dimensions
+        elif isinstance(param, list):
+            self.data = np.zeros(param)
+            self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header()
+            self.hdr.set_data_shape(self.data.shape)
+        # Case 4: create an image from an existing data array
+        elif isinstance(param, (np.ndarray, np.generic)):
+            self.data = param
+            self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header()
+            self.hdr.set_data_shape(self.data.shape)
+        else:
+            raise TypeError('Image constructor takes at least one argument.')
+
+        # Fix any mismatch between the array's datatype and the header datatype
+        self.fix_header_dtype()
+
+    @property
+    def dim(self):
+        return get_dimension(self)
+
+    @property
+    def orientation(self):
+        return get_orientation(self)
+
+    @property
+    def absolutepath(self):
+        """
+        Storage path (either actual or potential)
+
+        Notes:
+
+        - As several tools perform chdir() it's very important to have absolute paths
+        - When set, if relative:
+
+          - If it already existed, it becomes a new basename in the old dirname
+          - Else, it becomes absolute (shortcut)
+
+        Usually not directly touched (use `Image.save`), but in some cases it's
+        the best way to set it.
+ """ + return self._path + + @absolutepath.setter + def absolutepath(self, value): + if value is None: + self._path = None + return + elif not os.path.isabs(value) and self._path is not None: + value = os.path.join(os.path.dirname(self._path), value) + elif not os.path.isabs(value): + value = os.path.abspath(value) + self._path = value + + @property + def header(self): + return self.hdr + + @header.setter + def header(self, value): + self.hdr = value + + def __deepcopy__(self, memo): + return type(self)(deepcopy(self.data, memo), deepcopy(self.hdr, memo), deepcopy(self.orientation, memo), deepcopy(self.absolutepath, memo), deepcopy(self.dim, memo)) + + def copy(self, image=None): + if image is not None: + self.affine = deepcopy(image.affine) + self.data = deepcopy(image.data) + self.hdr = deepcopy(image.hdr) + self._path = deepcopy(image._path) + else: + return deepcopy(self) + + def loadFromPath(self, path): + """ + This function load an image from an absolute path using nibabel library + + :param path: path of the file from which the image will be loaded + :return: + """ + + self.absolutepath = os.path.abspath(path) + im_file = nib.load(self.absolutepath, mmap=True) + self.affine = im_file.affine.copy() + self.data = np.asanyarray(im_file.dataobj) + self.hdr = im_file.header.copy() + if path != self.absolutepath: + logger.debug("Loaded %s (%s) orientation %s shape %s", path, self.absolutepath, self.orientation, self.data.shape) + else: + logger.debug("Loaded %s orientation %s shape %s", path, self.orientation, self.data.shape) + + def change_orientation(self, orientation, inverse=False): + """ + Change orientation on image (in-place). + + :param orientation: orientation string (SCT "from" convention) + + :param inverse: if you think backwards, use this to specify that you actually\ + want to transform *from* the specified orientation, not *to*\ + it. + + """ + change_orientation(self, orientation, self, inverse=inverse) + return self + + def getNonZeroCoordinates(self, sorting=None, reverse_coord=False): + """ + This function return all the non-zero coordinates that the image contains. + Coordinate list can also be sorted by x, y, z, or the value with the parameter sorting='x', sorting='y', sorting='z' or sorting='value' + If reverse_coord is True, coordinate are sorted from larger to smaller. 
+
+        Removed Coordinate object
+        """
+        n_dim = 1
+        if self.dim[3] == 1:
+            n_dim = 3
+        else:
+            n_dim = 4
+        if self.dim[2] == 1:
+            n_dim = 2
+
+        if n_dim == 3:
+            X, Y, Z = (self.data > 0).nonzero()
+            list_coordinates = [[X[i], Y[i], Z[i], self.data[X[i], Y[i], Z[i]]] for i in range(0, len(X))]
+        elif n_dim == 2:
+            try:
+                X, Y = (self.data > 0).nonzero()
+                list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i]]] for i in range(0, len(X))]
+            except ValueError:
+                X, Y, Z = (self.data > 0).nonzero()
+                list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i], 0]] for i in range(0, len(X))]
+
+        if sorting is not None:
+            if reverse_coord not in [True, False]:
+                raise ValueError('reverse_coord parameter must be a boolean')
+
+            if sorting == 'x':
+                list_coordinates = sorted(list_coordinates, key=lambda el: el[0], reverse=reverse_coord)
+            elif sorting == 'y':
+                list_coordinates = sorted(list_coordinates, key=lambda el: el[1], reverse=reverse_coord)
+            elif sorting == 'z':
+                list_coordinates = sorted(list_coordinates, key=lambda el: el[2], reverse=reverse_coord)
+            elif sorting == 'value':
+                list_coordinates = sorted(list_coordinates, key=lambda el: el[3], reverse=reverse_coord)
+            else:
+                raise ValueError("sorting parameter must be either 'x', 'y', 'z' or 'value'")
+
+        return list_coordinates
+
+    def change_type(self, dtype):
+        """
+        Change data type on image.
+
+        Note: the image path is voided.
+        """
+        change_type(self, dtype, self)
+        return self
+
+    def fix_header_dtype(self):
+        """
+        Change the header dtype to match the datatype of the array.
+        """
+        # Using bool for nibabel headers is unsupported, so use uint8 instead:
+        # `nibabel.spatialimages.HeaderDataError: data dtype "bool" not supported`
+        dtype_data = self.data.dtype
+        if dtype_data == bool:
+            dtype_data = np.uint8
+
+        dtype_header = self.hdr.get_data_dtype()
+        if dtype_header != dtype_data:
+            logger.warning(f"Image header specifies datatype '{dtype_header}', but array is of type "
+                           f"'{dtype_data}'. Header metadata will be overwritten to use '{dtype_data}'.")
+            self.hdr.set_data_dtype(dtype_data)
+
+    def save(self, path=None, dtype=None, verbose=1, mutable=False):
+        """
+        Write an image to a NIfTI file
+
+        :param path: Where to save the data, if None it will be taken from the\
+                     absolutepath member.\
+                     If path is a directory, will save to a file under this directory\
+                     with the basename from the absolutepath member.
+
+        :param dtype: if not set, the image is saved in the same type as input data\
+                      if 'minimize', image storage space is minimized\
+                      (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\
+                      (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\
+                      (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\
+                      (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\
+                      (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\
+                      (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\
+                      (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\
+                      (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\
+                      (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\
+                      (1024, 'int64', np.int64, "NIFTI_TYPE_INT64"),\
+                      (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\
+                      (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\
+                      (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\
+                      (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"),
+
+        :param mutable: whether to update members with newly created path or dtype
+        """
+        if mutable:  # do all modifications in-place
+            # Case 1: `path` not specified
+            if path is None:
+                if self.absolutepath:  # Fallback to the original filepath
+                    path = self.absolutepath
+                else:
+                    raise ValueError("Don't know where to save the image (no absolutepath or path parameter)")
+            # Case 2: `path` points to an existing directory
+            elif os.path.isdir(path):
+                if self.absolutepath:  # Use the original filename, but save to the directory specified by `path`
+                    path = os.path.join(os.path.abspath(path), os.path.basename(self.absolutepath))
+                else:
+                    raise ValueError("Don't know where to save the image (path parameter is dir, but absolutepath is "
+                                     "missing)")
+            # Case 3: `path` points to a file (or a *nonexistent* directory) so use its value as-is
+            #    (We're okay with letting nonexistent directories slip through, because it's difficult to distinguish
+            #     between nonexistent directories and nonexistent files. Plus, `nibabel` will catch any further errors.)
+            else:
+                pass
+
+            if os.path.isfile(path) and verbose:
+                logger.warning("File %s already exists. Will overwrite it.", path)
+            if os.path.isabs(path):
+                logger.debug("Saving image to %s orientation %s shape %s",
+                             path, self.orientation, self.data.shape)
+            else:
+                logger.debug("Saving image to %s (%s) orientation %s shape %s",
+                             path, os.path.abspath(path), self.orientation, self.data.shape)
+
+            # Now that `path` has been set and log messages have been written, we can assign it to the image itself
+            self.absolutepath = os.path.abspath(path)
+
+            if dtype is not None:
+                self.change_type(dtype)
+
+            if self.hdr is not None:
+                self.hdr.set_data_shape(self.data.shape)
+                self.fix_header_dtype()
+
+            # nb. that copy() is important because if it were a memory map, save() would corrupt it
+            dataobj = self.data.copy()
+            affine = None
+            header = self.hdr.copy() if self.hdr is not None else None
+            nib.save(nib.nifti1.Nifti1Image(dataobj, affine, header), self.absolutepath)
+            if not os.path.isfile(self.absolutepath):
+                raise RuntimeError(f"Couldn't save image to {self.absolutepath}")
+        else:
+            # if we're not operating in-place, then make any required modifications on a throw-away copy
+            self.copy().save(path, dtype, verbose, mutable=True)
+        return self
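A minimal usage sketch of the `Image` class above (illustrative only; the filename is hypothetical):

```python
im = Image("sub-001_T2w.nii.gz")         # load from disk (hypothetical file)
nx, ny, nz, nt, px, py, pz, pt = im.dim  # voxel counts and voxel sizes (mm)
print(im.orientation)                    # e.g. "LPI" (SCT "from" convention)
im.change_orientation("RPI")             # reorient in-place
im.save("sub-001_T2w_RPI.nii.gz")        # write the reoriented image
```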
+
+
+class SlicerOneAxis(object):
+    """
+    Image slicer to use when you don't care about the 2D slice orientation,
+    and don't want to specify it.
+    The slicer will just iterate through the right axis that corresponds to
+    its specification.
+
+    Can help with getting ranges and slice indices.
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py
+    """
+
+    def __init__(self, im, axis="IS"):
+        opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}
+        axis_labels = "LRPAIS"
+        if len(axis) != 2:
+            raise ValueError()
+        if axis[0] not in axis_labels:
+            raise ValueError()
+        if axis[1] not in axis_labels:
+            raise ValueError()
+        if axis[0] != opposite_character[axis[1]]:
+            raise ValueError()
+
+        for idx_axis in range(2):
+            dim_nr = im.orientation.find(axis[idx_axis])
+            if dim_nr != -1:
+                break
+        if dim_nr == -1:
+            raise ValueError()
+
+        # SCT convention
+        from_dir = im.orientation[dim_nr]
+        self.direction = +1 if axis[0] == from_dir else -1
+        self.nb_slices = im.dim[dim_nr]
+        self.im = im
+        self.axis = axis
+        self._slice = lambda idx: tuple([(idx if x in axis else slice(None)) for x in im.orientation])
+
+    def __len__(self):
+        return self.nb_slices
+
+    def __getitem__(self, idx):
+        """
+
+        :return: an image slice, at slicing index idx
+        :param idx: slicing index (according to the slicing direction)
+        """
+        if isinstance(idx, slice):
+            raise NotImplementedError()
+
+        if idx >= self.nb_slices:
+            raise IndexError("I just have {} slices!".format(self.nb_slices))
+
+        if self.direction == -1:
+            idx = self.nb_slices - 1 - idx
+
+        return self.im.data[self._slice(idx)]
+
+
+def get_dimension(im_file, verbose=1):
+    """
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/
+
+    Get dimension from Image or nibabel object. Manages 2D, 3D or 4D images.
+
+    :param: im_file: Image or nibabel object
+    :return: nx, ny, nz, nt, px, py, pz, pt
+    """
+    if not isinstance(im_file, (nib.nifti1.Nifti1Image, Image)):
+        raise TypeError("The provided image file is neither a nibabel.nifti1.Nifti1Image instance nor an Image instance")
+    # initializing ndims [nx, ny, nz, nt] and pdims [px, py, pz, pt]
+    ndims = [1, 1, 1, 1]
+    pdims = [1, 1, 1, 1]
+    data_shape = im_file.header.get_data_shape()
+    zooms = im_file.header.get_zooms()
+    for i in range(min(len(data_shape), 4)):
+        ndims[i] = data_shape[i]
+        pdims[i] = zooms[i]
+    return *ndims, *pdims
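And a short sketch of iterating 2D slices with `SlicerOneAxis` along the inferior-superior axis (illustrative only; the filename is hypothetical):

```python
im = Image("sub-001_T2w.nii.gz")  # hypothetical input file
slicer = SlicerOneAxis(im, axis="IS")
for k in range(len(slicer)):
    sl = slicer[k]                # 2D numpy array at slice index k, ordered I->S
    print(k, sl.shape)
```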
+
+
+def change_orientation(im_src, orientation, im_dst=None, inverse=False):
+    """
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/
+
+    :param im_src: source image
+    :param orientation: orientation string (SCT "from" convention)
+    :param im_dst: destination image (can be the source image for in-place
+                   operation, can be unset to generate one)
+    :param inverse: if you think backwards, use this to specify that you actually
+                    want to transform *from* the specified orientation, not *to* it.
+    :return: an image with changed orientation
+
+    .. note::
+        - the resulting image has no path member set
+        - if the source image is < 3D, it is reshaped to 3D and the destination is 3D
+    """
+
+    if len(im_src.data.shape) < 3:
+        pass  # Will reshape to 3D
+    elif len(im_src.data.shape) == 3:
+        pass  # OK, standard 3D volume
+    elif len(im_src.data.shape) == 4:
+        pass  # OK, standard 4D volume
+    elif len(im_src.data.shape) == 5 and im_src.header.get_intent()[0] == "vector":
+        pass  # OK, physical displacement field
+    else:
+        raise NotImplementedError("Don't know how to change orientation for this image")
+
+    im_src_orientation = im_src.orientation
+    im_dst_orientation = orientation
+    if inverse:
+        im_src_orientation, im_dst_orientation = im_dst_orientation, im_src_orientation
+
+    perm, inversion = _get_permutations(im_src_orientation, im_dst_orientation)
+
+    if im_dst is None:
+        im_dst = im_src.copy()
+        im_dst._path = None
+
+    im_src_data = im_src.data
+    if len(im_src_data.shape) < 3:
+        im_src_data = im_src_data.reshape(tuple(list(im_src_data.shape) + ([1] * (3 - len(im_src_data.shape)))))
+
+    # Update data by performing inversions and swaps
+
+    # axes inversion (flip)
+    data = im_src_data[::inversion[0], ::inversion[1], ::inversion[2]]
+
+    # axes manipulations (transpose)
+    if perm == [1, 0, 2]:
+        data = np.swapaxes(data, 0, 1)
+    elif perm == [2, 1, 0]:
+        data = np.swapaxes(data, 0, 2)
+    elif perm == [0, 2, 1]:
+        data = np.swapaxes(data, 1, 2)
+    elif perm == [2, 0, 1]:
+        data = np.swapaxes(data, 0, 2)  # transform [2, 0, 1] to [1, 0, 2]
+        data = np.swapaxes(data, 0, 1)  # transform [1, 0, 2] to [0, 1, 2]
+    elif perm == [1, 2, 0]:
+        data = np.swapaxes(data, 0, 2)  # transform [1, 2, 0] to [0, 2, 1]
+        data = np.swapaxes(data, 1, 2)  # transform [0, 2, 1] to [0, 1, 2]
+    elif perm == [0, 1, 2]:
+        # do nothing
+        pass
+    else:
+        raise NotImplementedError()
+
+    # Update header
+
+    im_src_aff = im_src.hdr.get_best_affine()
+    aff = nib.orientations.inv_ornt_aff(
+        np.array((perm, inversion)).T,
+        im_src_data.shape)
+    im_dst_aff = np.matmul(im_src_aff, aff)
+
+    im_dst.header.set_qform(im_dst_aff)
+    im_dst.header.set_sform(im_dst_aff)
+    im_dst.header.set_data_shape(data.shape)
+    im_dst.data = data
+
+    return im_dst
+
+
+def _get_permutations(im_src_orientation, im_dst_orientation):
+    """
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/
+
+    :param im_src_orientation str: Orientation of source image. Example: 'RPI'
+    :param im_dst_orientation str: Orientation of destination image. Example: 'SAL'
+    :return: list of axes permutations and list of inversions to achieve an orientation change
+    """
+
+    opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}
+
+    perm = [0, 1, 2]
+    inversion = [1, 1, 1]
+    for i, character in enumerate(im_src_orientation):
+        try:
+            perm[i] = im_dst_orientation.index(character)
+        except ValueError:
+            perm[i] = im_dst_orientation.index(opposite_character[character])
+            inversion[i] = -1
+
+    return perm, inversion
what's in Image.orientation)
+    """
+    res = "".join(nib.orientations.aff2axcodes(im.hdr.get_best_affine()))
+    return orientation_string_nib2sct(res)
+
+
+def orientation_string_nib2sct(s):
+    """
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/
+
+    :return: SCT reference space code from nibabel one
+    """
+    opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}
+    return "".join([opposite_character[x] for x in s])
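To make the two conventions concrete: nibabel axis codes name the direction each axis points towards (e.g. 'RAS'), whereas the SCT string names the side each axis starts from, hence the character-wise flip above. A minimal illustrative sketch (not part of the patch; `nib2sct` is a stand-in name for the function above):

```python
# Illustrative only: mirrors orientation_string_nib2sct().
opposite = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'}

def nib2sct(axcodes: str) -> str:
    # nibabel reports where each axis points to; SCT labels where it
    # comes from, so every character is inverted.
    return "".join(opposite[c] for c in axcodes)

assert nib2sct("RAS") == "LPI"
assert nib2sct("LPI") == "RAS"
```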
+def change_type(im_src, dtype, im_dst=None):
+    """
+    Change the voxel type of the image
+
+    :param dtype:    if not set, the image is saved in standard type\
+                    if 'minimize', the image space is minimized\
+                    if 'minimize_int', the image space is minimized and values are approximated to integers\
+                    (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\
+                    (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\
+                    (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\
+                    (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\
+                    (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\
+                    (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\
+                    (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\
+                    (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\
+                    (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\
+                    (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\
+                    (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\
+                    (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\
+                    (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\
+                    (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"),
+    :return:
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/
+    """
+
+    if im_dst is None:
+        im_dst = im_src.copy()
+        im_dst._path = None
+
+    if dtype is None:
+        return im_dst
+
+    # get min/max from input image
+    min_in = np.nanmin(im_src.data)
+    max_in = np.nanmax(im_src.data)
+
+    # find optimum type for the input image
+    if dtype in ('minimize', 'minimize_int'):
+        # warning: does not take intensity resolution into account, neither complex voxels
+
+        # check if voxel values are real or integer
+        isInteger = True
+        if dtype == 'minimize':
+            for vox in im_src.data.flatten():
+                if int(vox) != vox:
+                    isInteger = False
+                    break
+
+        if isInteger:
+            if min_in >= 0:  # unsigned
+                if max_in <= np.iinfo(np.uint8).max:
+                    dtype = np.uint8
+                elif max_in <= np.iinfo(np.uint16).max:
+                    dtype = np.uint16
+                elif max_in <= np.iinfo(np.uint32).max:
+                    dtype = np.uint32
+                elif max_in <= np.iinfo(np.uint64).max:
+                    dtype = np.uint64
+                else:
+                    raise ValueError("Maximum value of the image is too big to be represented.")
+            else:
+                if max_in <= np.iinfo(np.int8).max and min_in >= np.iinfo(np.int8).min:
+                    dtype = np.int8
+                elif max_in <= np.iinfo(np.int16).max and min_in >= np.iinfo(np.int16).min:
+                    dtype = np.int16
+                elif max_in <= np.iinfo(np.int32).max and min_in >= np.iinfo(np.int32).min:
+                    dtype = np.int32
+                elif max_in <= np.iinfo(np.int64).max and min_in >= np.iinfo(np.int64).min:
+                    dtype = np.int64
+                else:
+                    raise ValueError("Maximum value of the image is too big to be represented.")
+        else:
+            # if max_in <= np.finfo(np.float16).max and min_in >= np.finfo(np.float16).min:
+            #    type = 'np.float16' # not supported by nibabel
+            if max_in <= np.finfo(np.float32).max and min_in >= np.finfo(np.float32).min:
+                dtype = np.float32
+            elif max_in <= np.finfo(np.float64).max and min_in >= np.finfo(np.float64).min:
+                dtype = np.float64
+
+        dtype = to_dtype(dtype)
+    else:
+        dtype = to_dtype(dtype)
+
+    # if output type is int, check if it needs intensity rescaling
+    if "int" in dtype.name:
+        # get min/max from output type
+        min_out = np.iinfo(dtype).min
+        max_out = np.iinfo(dtype).max
+        # before rescaling, check if there would be an intensity overflow
+
+        if (min_in < min_out) or (max_in > max_out):
+            # This condition is important for binary images since we do not want to scale them
+            logger.warning(f"To avoid intensity overflow due to conversion to '{dtype.name}', intensity will be rescaled to the maximum quantization scale")
+            # rescale intensity
+            data_rescaled = im_src.data * (max_out - min_out) / (max_in - min_in)
+            im_dst.data = data_rescaled - (data_rescaled.min() - min_out)
+
+    # change type of data in both numpy array and nifti header
+    im_dst.data = getattr(np, dtype.name)(im_dst.data)
+    im_dst.hdr.set_data_dtype(dtype)
+    return im_dst
+
+
+def to_dtype(dtype):
+    """
+    Take a dtype specification and return an np.dtype
+
+    :param dtype: dtype specification (string or np.dtype or None are supported for now)
+    :return: dtype or None
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/
+    """
+    # TODO add more or filter on things supported by nibabel
+
+    if dtype is None:
+        return None
+    if isinstance(dtype, type):
+        if isinstance(dtype(0).dtype, np.dtype):
+            return dtype(0).dtype
+    if isinstance(dtype, np.dtype):
+        return dtype
+    if isinstance(dtype, str):
+        return np.dtype(dtype)
+
+    raise TypeError("data type {}: {} not understood".format(dtype.__class__, dtype))
+
+
+def zeros_like(img, dtype=None):
+    """
+
+    :param img: reference image
+    :param dtype: desired data type (optional)
+    :return: an Image with the same shape and header, filled with zeros
+
+    Similar to numpy.zeros_like(), the goal of the function is to show the developer's
+    intent and avoid doing a copy, which is slower than initialization with a constant.
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py
+    """
+    zimg = Image(np.zeros_like(img.data), hdr=img.hdr.copy())
+    if dtype is not None:
+        zimg.change_type(dtype)
+    return zimg
+
+
+def empty_like(img, dtype=None):
+    """
+    :param img: reference image
+    :param dtype: desired data type (optional)
+    :return: an Image with the same shape and header, whose data is uninitialized
+
+    Similar to numpy.empty_like(), the goal of the function is to show the developer's
+    intent and avoid touching the allocated memory, because it will be written to
+    afterwards.
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py
+    """
+    dst = change_type(img, dtype)
+    return dst
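The 'minimize'/'minimize_int' branch above walks the integer types from narrowest to widest until the image's value range fits. A minimal sketch of the unsigned case (the helper name `smallest_uint` is ours, for illustration only):

```python
import numpy as np

def smallest_uint(max_in: float) -> np.dtype:
    # Mirrors the unsigned branch of change_type(..., dtype='minimize_int'):
    # pick the first unsigned type whose maximum can hold the image maximum.
    for candidate in (np.uint8, np.uint16, np.uint32, np.uint64):
        if max_in <= np.iinfo(candidate).max:
            return np.dtype(candidate)
    raise ValueError("Maximum value of the image is too big to be represented.")

assert smallest_uint(200) == np.uint8
assert smallest_uint(70000) == np.uint32   # 70000 > 65535, so uint16 is skipped
```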
+def find_zmin_zmax(im, threshold=0.1):
+    """
+    Find the min (and max) z-slice index below which (and above which) slices only have voxels below a given threshold.
+
+    :param im: Image object
+    :param threshold: threshold to apply before looking for zmin/zmax, typically corresponding to noise level.
+    :return: [zmin, zmax]
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py
+    """
+    slicer = SlicerOneAxis(im, axis="IS")
+
+    # Make sure image is not empty
+    if not np.any(slicer):
+        logger.error('Input image is empty')
+
+    # Iterate from bottom to top until we find data
+    for zmin in range(0, len(slicer)):
+        if np.any(slicer[zmin] > threshold):
+            break
+
+    # Conversely from top to bottom
+    for zmax in range(len(slicer) - 1, zmin, -1):
+        if np.any(slicer[zmax] > threshold):
+            break
+
+    return zmin, zmax
\ No newline at end of file
diff --git a/monai/main.py b/monai/main.py
index b5cca5e..62478d5 100644
--- a/monai/main.py
+++ b/monai/main.py
@@ -4,6 +4,7 @@
 from loguru import logger
 import yaml
 import json
+import time
 import numpy as np
 import wandb
@@ -24,15 +25,12 @@ from monai.data import (ThreadDataLoader, CacheDataset, load_decathlon_datalist,
                         decollate_batch, set_track_meta)
 from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage)
 
-# mednext
-from nnunet_mednext import MedNeXt
-
 # list of contrasts and their possible various names in the datasets
 CONTRASTS = {
     "t1map": ["T1map"],
     "mp2rage": ["inv-1_part-mag_MP2RAGE", "inv-2_part-mag_MP2RAGE"],
-    "t1w": ["T1w", "space-other_T1w"],
-    "t2w": ["T2w", "space-other_T2w"],
+    "t1w": ["T1w", "space-other_T1w", "acq-lowresSag_T1w"],
+    "t2w": ["T2w", "space-other_T2w", "acq-lowresSag_T2w", "acq-highresSag_T2w"],
     "t2star": ["T2star", "space-other_T2star"],
     "dwi": ["rec-average_dwi", "acq-dwiMean_dwi"],
     "mt-on": ["flip-1_mt-on_space-other_MTS", "acq-MTon_MTR"],
@@ -46,9 +44,9 @@ def get_args():
     parser = argparse.ArgumentParser(description='Script for training contrast-agnositc SC segmentation model.')
 
     # arguments for model
-    parser.add_argument('-m', '--model', choices=['nnunet-plain', 'nnunet-resencM', 'mednext', 'swinunetr'],
-                        default='nnunet', type=str,
-                        help='Model type to be used. Options: nnunet, mednext, swinunetr.')
+    parser.add_argument('-m', '--model', choices=['nnunet-plain', 'nnunet-resencM', 'swinunetr'],
+                        default='nnunet-plain', type=str,
+                        help='Model type to be used. Options: nnunet-plain, nnunet-resencM, swinunetr.')
     # path to the config file
     parser.add_argument("--config", type=str, default="./config.json",
                         help="Path to the config file containing all training details.")
@@ -164,9 +162,9 @@ def prepare_data(self):
             test_files = test_files[:6]
 
         train_cache_rate = 0.5  # 0.25 if args.model == 'swinunetr' else 0.5
-        self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=train_cache_rate, num_workers=4,
+        self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=train_cache_rate, num_workers=12,
                                      copy_cache=False)
-        self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4,
+        self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=12,
                                    copy_cache=False)
 
         # define test transforms
@@ -181,7 +179,7 @@ def prepare_data(self):
                     meta_keys=["pred_meta_dict", "label_meta_dict"],
                     nearest_interp=False, to_tensor=True),
             ])
-        self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4,
+        self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=8,
                                     copy_cache=False)
 
         # # avoid the computation of meta information in random transforms
@@ -191,15 +189,15 @@
     # DATA LOADERS
     # --------------------------------
     def train_dataloader(self):
-        return ThreadDataLoader(self.train_ds, batch_size=self.cfg["opt"]["batch_size"], shuffle=True, num_workers=16,
+        return ThreadDataLoader(self.train_ds, batch_size=self.cfg["opt"]["batch_size"], shuffle=True, num_workers=8,
                                 pin_memory=True, persistent_workers=True)
 
     def val_dataloader(self):
-        return ThreadDataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True,
+        return ThreadDataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True,
                                 persistent_workers=True)
 
     def test_dataloader(self):
-        return ThreadDataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
+        return ThreadDataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
 
     # --------------------------------
@@ -236,7 +234,7 @@ def training_step(self, batch, batch_idx):
         output = self.forward(inputs)   # logits
         # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}")
 
-        if args.model in ["nnunet-plain", "nnunet-resencM", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]:
+        if args.model in ["nnunet-plain", "nnunet-resencM"] and self.cfg['model'][args.model]["enable_deep_supervision"]:
 
             # calculate dice loss for each output
             loss, train_soft_dice = 0.0, 0.0
@@ -335,7 +333,7 @@ def validation_step(self, batch, batch_idx):
         outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian",
                                            sw_batch_size=4, predictor=self.forward, overlap=0.5,)
         # outputs shape: (B, C, )
-        if args.model in ["nnunet-plain", "nnunet-resencM", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]:
+        if args.model in ["nnunet-plain", "nnunet-resencM"] and self.cfg['model'][args.model]["enable_deep_supervision"]:
             # we only need the output with the highest resolution
             outputs = outputs[0]
@@ -432,7 +430,7 @@ def test_step(self, batch, batch_idx):
         batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size,
                                                  sw_batch_size=4, predictor=self.forward, overlap=0.5)
 
-        if args.model in ["nnunet-plain", "nnunet-resencM", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]:
+        if args.model in 
["nnunet-plain", "nnunet-resencM"] and self.cfg['model'][args.model]["enable_deep_supervision"]: # we only need the output with the highest resolution batch["pred"] = batch["pred"][0] @@ -631,49 +629,13 @@ def main(args): # save experiment id save_exp_id = f"{args.model}_seed={config['seed']}_" \ f"ndata={n_datasets}_ncont={n_contrasts}_" \ - f"nf={config['model']['nnunet-plain']['features_per_stage'][0]}_" \ + f"nf={config['model']['nnunet-plain']['features_per_stage'][-1]}_" \ f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \ f"bs={config['opt']['batch_size']}" \ if args.debug: save_exp_id = f"DEBUG_{save_exp_id}" - elif args.model == "mednext": - # NOTE: the S, B models in the paper don't fit as-is for this data, gpu - # hence tweaking the models - logger.info(f"Using MedNext model tweaked ...") - net = MedNeXt( - in_channels=config["model"]["mednext"]["num_input_channels"], - n_channels=config["model"]["mednext"]["base_num_features"], - n_classes=config["model"]["mednext"]["num_classes"], - exp_r=[2,3,4,4,4,4,4,3,2], - kernel_size=config["model"]["mednext"]["kernel_size"], - deep_supervision=config["model"]["mednext"]["enable_deep_supervision"], - do_res=True, - do_res_up_down=True, - checkpoint_style="outside_block", - block_counts=config["model"]["mednext"]["block_counts"], - norm_type='layer', - ) - - # variable for saving patch size in the experiment id (same as crop_pad_size) - patch_size = f"{config['preprocessing']['crop_pad_size'][0]}x" \ - f"{config['preprocessing']['crop_pad_size'][1]}x" \ - f"{config['preprocessing']['crop_pad_size'][2]}" - # count number of 2s in the block_counts list - num_two_blocks = config["model"]["mednext"]["block_counts"].count(2) - norm_type = 'LN' if config["model"]["mednext"]["norm_type"] == 'layer' else 'GN' - # save experiment id - save_exp_id = f"{args.model}_seed={config['seed']}_" \ - f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \ - f"nf={config['model']['mednext']['base_num_features']}_" \ - f"expR=base_bcs={num_two_blocks}_{norm_type}_" \ - f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \ - f"bs={config['opt']['batch_size']}_{patch_size}" \ - - if args.debug: - save_exp_id = f"DEBUG_{save_exp_id}" - timestamp = datetime.now().strftime(f"%Y%m%d-%H%M") # prints in YYYYMMDD-HHMMSS format save_exp_id = f"{save_exp_id}_{timestamp}" @@ -737,6 +699,7 @@ def main(args): num_model_params = count_parameters(model=net) logger.info(f"Number of Trainable model parameters: {(num_model_params / 1e6):.3f}M") + start_time = time.time() logger.info(f"Starting training from scratch ...") # wandb logger exp_logger = pl.loggers.WandbLogger( @@ -746,11 +709,11 @@ def main(args): log_model=True, # save best model using checkpoint callback project='contrast-agnostic', entity='naga-karthik', - config=config) + config=config, + mode="disabled") # Saving training script to wandb wandb.save("main.py") - wandb.save("transforms.py") # Enable TF32 on matmul and on cuDNN # torch._dynamo.config.verbose = True @@ -768,13 +731,17 @@ def main(args): # NOTE: Each epoch takes a looot of time with the aggregated dataset, so limiting the number of training batches # per epoch. Turns out that we don't need to go through all the training samples within an epoch for good performance. 
 # nnunet hardcodes 250 training steps per epoch and we all know how it performs :)
-                        limit_train_batches=0.5,    # use 1.0 for full training
+                        # limit_train_batches=0.5,    # use 1.0 for full training
                         enable_progress_bar=True)
                         # profiler="simple",)     # to profile the training time taken for each step
 
         # Train!
         trainer.fit(pl_model)
         logger.info(f" Training Done!")
+        end_time = time.time()
+
+        duration = (end_time - start_time)
+        logger.info(f"Total training time: {int(duration // 3600)}hrs {int((duration // 60) % 60)}mins {int(duration % 60)}secs")
 
     else:
         logger.info(f" Resuming training from the latest checkpoint! ")
diff --git a/monai/models.py b/monai/models.py
index eed4684..d2269d5 100644
--- a/monai/models.py
+++ b/monai/models.py
@@ -15,44 +15,25 @@
 # Define plans json taken from nnUNet
 # ======================================================================================================
 nnunet_plans = {
-    "UNet_class_name": "PlainConvUNet",
-    "UNet_base_num_features": 32,
-    "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2],
-    "n_conv_per_stage_decoder": [2, 2, 2, 2, 2],
-    "pool_op_kernel_sizes": [
-        [1, 1, 1],
-        [2, 2, 2],
-        [2, 2, 2],
-        [2, 2, 2],
-        [2, 2, 2],
-        [1, 2, 2]
-    ],
-    "conv_kernel_sizes": [
-        [3, 3, 3],
-        [3, 3, 3],
-        [3, 3, 3],
-        [3, 3, 3],
-        [3, 3, 3],
-        [3, 3, 3]
-    ],
-    "unet_max_num_features": 320,
+    "arch_class_name": "dynamic_network_architectures.architectures.unet.PlainConvUNet",
+    "arch_kwargs": {
+        "n_stages": 6,
+        "features_per_stage": [32, 64, 128, 256, 384, 384],
+        "strides": [
+            [1, 1, 1],
+            [2, 2, 2],
+            [2, 2, 2],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 2, 2]
+        ],
+        "n_conv_per_stage": [2, 2, 2, 2, 2, 2],
+        "n_conv_per_stage_decoder": [2, 2, 2, 2, 2]
+    },
+    "arch_kwargs_requires_import": ["conv_op", "norm_op", "dropout_op", "nonlin"],
 }
 
-# ======================================================================================================
-#                           Utils for nnUNet's Model
-# ====================================================================================================
-class InitWeights_He(object):
-    def __init__(self, neg_slope=1e-2):
-        self.neg_slope = neg_slope
-
-    def __call__(self, module):
-        if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d):
-            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
-            if module.bias is not None:
-                module.bias = nn.init.constant_(module.bias, 0)
-
-
 # ======================================================================================================
 # Define the network based on plans json
 # ====================================================================================================
@@ -180,7 +161,7 @@
 if __name__ == "__main__":
     enable_deep_supervision = True
-    model = create_nnunet_from_plans(nnunet_plans, 1, 1, enable_deep_supervision)
+    model = create_nnunet_from_plans(nnunet_plans, 1, 1, deep_supervision=enable_deep_supervision)
     input = torch.randn(1, 1, 160, 224, 96)
     output = model(input)
     if enable_deep_supervision:
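The __main__ block above already smoke-tests the plans dict; for reference, the same check as a standalone sketch (assuming monai/models.py is importable as `models`, which is our assumption about the working directory, not something the patch guarantees):

```python
import torch
from models import create_nnunet_from_plans, nnunet_plans  # assumed import path

net = create_nnunet_from_plans(nnunet_plans, 1, 1, deep_supervision=True)
x = torch.randn(1, 1, 160, 224, 96)  # same dummy patch size as the __main__ block
outputs = net(x)
# With deep supervision enabled the network returns one tensor per decoder stage,
# highest resolution first, which is why training/inference code takes outputs[0].
print([o.shape for o in outputs])
```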
diff --git a/monai/run_inference_single_image.py b/monai/run_inference_single_image.py
index ebf928a..05f3966 100644
--- a/monai/run_inference_single_image.py
+++ b/monai/run_inference_single_image.py
@@ -8,6 +8,8 @@
 import os
 import argparse
 import numpy as np
+import pydoc
+import warnings
 from loguru import logger
 import torch.nn.functional as F
 import torch
@@ -20,43 +22,43 @@
 from monai.inferers import sliding_window_inference
 from monai.data import (DataLoader, Dataset, decollate_batch)
 from monai.networks.nets import SwinUNETR
-from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage, Spacingd,
-                              LoadImaged, NormalizeIntensityd, EnsureChannelFirstd,
-                              DivisiblePadd, Orientationd, ResizeWithPadOrCropd)
-from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet
-from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op
-from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
+import monai.transforms as transforms
 
-from nnunet_mednext import MedNeXt
+# ---------------------------- Imports for nnUNet's Model -----------------------------
+from batchgenerators.utilities.file_and_folder_operations import join
+from utils import recursive_find_python_class
 
-# NNUNET global params
-INIT_FILTERS=32
-ENABLE_DS = True
 
 nnunet_plans = {
-    "UNet_class_name": "PlainConvUNet",
-    "UNet_base_num_features": INIT_FILTERS,
-    "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2],
-    "n_conv_per_stage_decoder": [2, 2, 2, 2, 2],
-    "pool_op_kernel_sizes": [
-        [1, 1, 1],
-        [2, 2, 2],
-        [2, 2, 2],
-        [2, 2, 2],
-        [2, 2, 2],
-        [1, 2, 2]
-    ],
-    "conv_kernel_sizes": [
-        [3, 3, 3],
-        [3, 3, 3],
-        [3, 3, 3],
-        [3, 3, 3],
-        [3, 3, 3],
-        [3, 3, 3]
-    ],
-    "unet_max_num_features": 320,
+    "arch_class_name": "dynamic_network_architectures.architectures.unet.PlainConvUNet",
+    "arch_kwargs": {
+        "n_stages": 6,
+        "features_per_stage": [32, 64, 128, 256, 384, 384],
+        "strides": [
+            [1, 1, 1],
+            [2, 2, 2],
+            [2, 2, 2],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 2, 2]
+        ],
+        "n_conv_per_stage": [2, 2, 2, 2, 2, 2],
+        "n_conv_per_stage_decoder": [2, 2, 2, 2, 2]
+    },
+    "arch_kwargs_requires_import": ["conv_op", "norm_op", "dropout_op", "nonlin"],
 }
 
+nnunet_plans_resencM = {
+    "arch_class_name": "dynamic_network_architectures.architectures.unet.ResidualEncoderUNet",
+    "arch_kwargs": {
+        "n_stages": nnunet_plans["arch_kwargs"]["n_stages"],
+        "features_per_stage": nnunet_plans["arch_kwargs"]["features_per_stage"],
+        "strides": nnunet_plans["arch_kwargs"]["strides"],
+        "n_blocks_per_stage": [1, 3, 4, 6, 6, 6],
+        "n_conv_per_stage_decoder": [1, 1, 1, 1, 1]
+    },
+    "arch_kwargs_requires_import": ["conv_op", "norm_op", "dropout_op", "nonlin"],
+}
 
 def get_parser():
@@ -76,13 +78,15 @@
                         ' Default: 64x192x-1')
     parser.add_argument('--device', default="gpu", type=str, choices=["gpu", "cpu"],
                         help='Device to run inference on. Default: cpu')
-    parser.add_argument('--model', default="monai", type=str, choices=["monai", "swinunetr", "mednext", "swinpretrained"],
+    parser.add_argument('--model', default="monai", type=str, choices=["monai", "monai-resencM", "swinunetr", "swinpretrained"],
                         help='Model to use for inference. Default: monai')
     parser.add_argument('--pred-type', default="soft", type=str, choices=["soft", "hard"],
                         help='Type of prediction to output/save. `soft` outputs soft segmentation masks with a threshold of 0.1'
                         '`hard` outputs binarized masks thresholded at 0.5 Default: hard')
-    parser.add_argument('--pad-mode', default="constant", type=str, choices=["constant", "edge", "reflect"],
-                        help='Padding mode for the input image. Default: constant')
+    parser.add_argument('--pad-mode', default="edge", type=str, choices=["constant", "edge", "reflect"],
+                        help='Padding mode for the input image. Default: edge')
+    parser.add_argument('--max-feat', default=384, type=int,
+                        help='Maximum number of features in the network. Default: 384')
 
     return parser
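Note that nnunet_plans_resencM reuses the stage count, feature widths and strides of the plain plans, swapping only the architecture class and the per-stage block counts. The --max-feat flag exists because the v2x checkpoints were trained with 320 maximum features whereas the newer models use 384; main() further below patches the plans accordingly before building the network, roughly as follows (a sketch of that step, with `max_feat` standing in for the parsed args.max_feat):

```python
# Sketch of the --max-feat fix-up applied in main() before model creation.
max_feat = 320  # e.g. args.max_feat when loading a v2x checkpoint
nnunet_plans["arch_kwargs"]["features_per_stage"] = [32, 64, 128, 256, max_feat, max_feat]
# Without this, loading a 320-feature checkpoint into a 384-feature network
# fails with state-dict shape mismatches in the last two encoder stages.
```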
@@ -90,99 +94,87 @@
 # ===========================================================================
 #                   Test-time Transforms
 # ===========================================================================
 def inference_transforms_single_image(crop_size, pad_mode="constant"):
-    return Compose([
-        LoadImaged(keys=["image"], image_only=False),
-        EnsureChannelFirstd(keys=["image"]),
-        Orientationd(keys=["image"], axcodes="RPI"),
-        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)),
-        ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size,),
-        # pad inputs to ensure divisibility by no. of layers nnUNet has (5)
-        DivisiblePadd(keys=["image"], k=2**5, mode=pad_mode),
-        NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
+    return transforms.Compose([
+        transforms.LoadImaged(keys=["image"], image_only=False),
+        transforms.EnsureChannelFirstd(keys=["image"]),
+        transforms.Orientationd(keys=["image"], axcodes="RPI"),
+        transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=(2)),
+        transforms.ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size, mode=pad_mode),
+        # pad inputs to ensure divisibility by no. of layers nnUNet has (5)
+        transforms.DivisiblePadd(keys=["image"], k=2**5, mode=pad_mode),
+        transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
     ])
 
-# ===========================================================================
-#                           Model utils
-# ===========================================================================
-class InitWeights_He(object):
-    def __init__(self, neg_slope=1e-2):
-        self.neg_slope = neg_slope
-
-    def __call__(self, module):
-        if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d):
-            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
-            if module.bias is not None:
-                module.bias = nn.init.constant_(module.bias, 0)
-
-
 # ============================================================================
 # Define the network based on nnunet_plans dict
 # ============================================================================
-def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, deep_supervision: bool = True):
+def create_nnunet_from_plans(plans, input_channels, output_channels, allow_init=True,
+                             deep_supervision: bool = True):
     """
-    Adapted from nnUNet's source code:
+    Adapted from nnUNet's source code:
     https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9
     """
-    num_stages = len(plans["conv_kernel_sizes"])
-
-    dim = len(plans["conv_kernel_sizes"][0])
-    conv_op = convert_dim_to_conv_op(dim)
-
-    segmentation_network_class_name = plans["UNet_class_name"]
-    mapping = {
-        'PlainConvUNet': PlainConvUNet,
-        'ResidualEncoderUNet': ResidualEncoderUNet
-    }
-    kwargs = {
-        'PlainConvUNet': {
-            'conv_bias': True,
-            'norm_op': get_matching_instancenorm(conv_op),
-            'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
-            'dropout_op': None, 'dropout_op_kwargs': None,
-            'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True},
+
+    network_class = plans["arch_class_name"]
+    # only the keys that "could" depend on the dataset are defined in main.py
+    architecture_kwargs = dict(**plans["arch_kwargs"])
+    # rest of the default keys are defined here
+    architecture_kwargs.update({
+        "kernel_sizes": [
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+        ],
+        "conv_op": "torch.nn.modules.conv.Conv3d",
+        "conv_bias": True,
+        "norm_op": "torch.nn.modules.instancenorm.InstanceNorm3d",
+        "norm_op_kwargs": {
+            "eps": 1e-05,
+            "affine": 
True }, - 'ResidualEncoderUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - } - } - assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ - 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ - 'into either this ' \ - 'function (get_network_from_plans) or ' \ - 'the init of your nnUNetModule to accomodate that.' - network_class = mapping[segmentation_network_class_name] - - conv_or_blocks_per_stage = { - 'n_conv_per_stage' - if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], - 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] - } - - # network class name!! - model = network_class( - input_channels=num_input_channels, - n_stages=num_stages, - features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, - plans["unet_max_num_features"]) for i in range(num_stages)], - conv_op=conv_op, - kernel_sizes=plans["conv_kernel_sizes"], - strides=plans["pool_op_kernel_sizes"], - num_classes=num_classes, - deep_supervision=deep_supervision, - **conv_or_blocks_per_stage, - **kwargs[segmentation_network_class_name] + "dropout_op": None, + "dropout_op_kwargs": None, + "nonlin": "torch.nn.LeakyReLU", + "nonlin_kwargs": {"inplace": True}, + }) + + for ri in plans["arch_kwargs_requires_import"]: + if architecture_kwargs[ri] is not None: + architecture_kwargs[ri] = pydoc.locate(architecture_kwargs[ri]) + + nw_class = pydoc.locate(network_class) + # sometimes things move around, this makes it so that we can at least recover some of that + if nw_class is None: + warnings.warn(f'Network class {network_class} not found. 
Attempting to locate it within ' + f'dynamic_network_architectures.architectures...') + + import dynamic_network_architectures + + nw_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"), + network_class.split(".")[-1], + 'dynamic_network_architectures.architectures') + if nw_class is not None: + print(f'FOUND IT: {nw_class}') + else: + raise ImportError('Network class could not be found, please check/correct your plans file') + + if deep_supervision is not None and 'deep_supervision' not in architecture_kwargs.keys(): + architecture_kwargs['deep_supervision'] = deep_supervision + + network = nw_class( + input_channels=input_channels, + num_classes=output_channels, + **architecture_kwargs ) - model.apply(InitWeights_He(1e-2)) - if network_class == ResidualEncoderUNet: - model.apply(init_last_bn_before_add_to_0) - - return model + + if hasattr(network, 'initialize') and allow_init: + network.apply(network.initialize) + + return network # =========================================================================== @@ -197,9 +189,9 @@ def prepare_data(path_image, crop_size=(64, 160, 320), pad_mode="edge"): # define post-processing transforms for testing; taken (with explanations) from # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 - test_post_pred = Compose([ - EnsureTyped(keys=["pred"]), - Invertd(keys=["pred"], transform=transforms_test, + test_post_pred = transforms.Compose([ + transforms.EnsureTyped(keys=["pred"]), + transforms.Invertd(keys=["pred"], transform=transforms_test, orig_keys=["image"], meta_keys=["pred_meta_dict"], nearest_interp=False, to_tensor=True), @@ -251,7 +243,7 @@ def main(): # define root path for finding datalists path_image = args.path_img results_path = args.path_out - chkp_path = os.path.join(args.chkp_path, "model", "best_model.ckpt") + chkp_path = os.path.join(args.chkp_path, "best_model.ckpt") # save terminal outputs to a file logger.add(os.path.join(results_path, "logs.txt"), rotation="10 MB", level="INFO") @@ -268,10 +260,17 @@ def main(): test_ds, test_post_pred = prepare_data(path_image, crop_size=crop_size, pad_mode=args.pad_mode) test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) + # temporary fix for the nnUNet model because v2x was trained with 320 max features but the newer + # models have 384 max features + nnunet_plans["arch_kwargs"]["features_per_stage"] = [32, 64, 128, 256, args.max_feat, args.max_feat] + # define model if args.model == "monai": - net = create_nnunet_from_plans(plans=nnunet_plans, - num_input_channels=1, num_classes=1, deep_supervision=ENABLE_DS) + net = create_nnunet_from_plans(plans=nnunet_plans, input_channels=1, + output_channels=1, deep_supervision=True) + elif args.model == "monai-resencM": + net = create_nnunet_from_plans(plans=nnunet_plans_resencM, input_channels=1, + output_channels=1, deep_supervision=True) elif args.model in ["swinunetr", "swinpretrained"]: # load config file @@ -286,26 +285,9 @@ def main(): depths=config["model"]["swinunetr"]["depths"], feature_size=config["model"]["swinunetr"]["feature_size"], num_heads=config["model"]["swinunetr"]["num_heads"]) - - elif args.model == "mednext": - config_path = os.path.join(args.chkp_path, "config.yaml") - with open(config_path, "r") as f: - config = yaml.safe_load(f) - - net = MedNeXt( - in_channels=config["model"]["mednext"]["num_input_channels"], - n_channels=config["model"]["mednext"]["base_num_features"], - 
n_classes=config["model"]["mednext"]["num_classes"],
-            exp_r=2,
-            kernel_size=config["model"]["mednext"]["kernel_size"],
-            deep_supervision=config["model"]["mednext"]["enable_deep_supervision"],
-            do_res=True,
-            do_res_up_down=True,
-            checkpoint_style="outside_block",
-            block_counts=config["model"]["mednext"]["block_counts"],)
-
+
     else:
-        raise ValueError("Model not recognized. Please choose from: nnunet, swinunetr, mednext")
+        raise ValueError("Model not recognized. Please choose from: monai, monai-resencM, swinunetr, swinpretrained")
 
     # define list to collect the test metrics
@@ -340,7 +322,7 @@
             batch["pred"] = sliding_window_inference(test_input, inference_roi_size, mode="gaussian",
                                                      sw_batch_size=4, predictor=net, overlap=0.5, progress=False)
 
-            if args.model in ["monai", "mednext"]:
+            if args.model in ["monai", "monai-resencM"]:
                 # take only the highest resolution prediction
                 # NOTE: both these models use Deep Supervision, so only the highest resolution prediction is taken
                 batch["pred"] = batch["pred"][0]
@@ -370,7 +352,7 @@
             # this takes about 0.25s on average on a CPU
 
             # image saver class
-            pred_saver = SaveImage(
+            pred_saver = transforms.SaveImage(
                 output_dir=results_path, output_postfix="pred", output_ext=".nii.gz",
                 separate_folder=False, print_log=False)
             # save the prediction
diff --git a/monai/transforms.py b/monai/transforms.py
index e332d68..4dff576 100644
--- a/monai/transforms.py
+++ b/monai/transforms.py
@@ -1,7 +1,62 @@
 import numpy as np
+from typing import Dict, Hashable, Mapping
+from scipy.ndimage import binary_erosion
+import torch
 
 import monai.transforms as transforms
-import batchgenerators.transforms.spatial_transforms as bg_spatial_transforms
+from monai.config import KeysCollection
+from monai.transforms import MapTransform
+
+
+class SpinalCordContourd(MapTransform):
+    def __init__(
+            self,
+            keys: KeysCollection,
+            allow_missing_keys: bool = False,
+    ) -> None:
+        MapTransform.__init__(self, keys, allow_missing_keys)
+
+    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
+        d = dict(data)
+
+        for key in self.keys:
+            d[key] = self.create_contour_mask_3d(d[key])
+
+        return d
+
+    def create_contour_mask_3d(self, segmentation_mask):
+        # Get the shape of the 3D mask
+        depth = segmentation_mask.shape[-1]
+
+        # Initialize the contour mask
+        contour_mask = torch.zeros_like(segmentation_mask)
+
+        # Process each slice
+        for i in range(depth):
+            # Extract the 2D slice
+            slice_2d = segmentation_mask[0, :, :, i]
+
+            # Skip the slice if it is empty (because of padding)
+            if torch.sum(slice_2d) == 0:
+                continue
+
+            # Ensure the slice is binary (tensors use .to(), not numpy's .astype())
+            binary_slice = (slice_2d > 0).to(torch.uint8)
+
+            # Perform binary erosion; scipy operates on numpy arrays, so convert and come back
+            # eroded_slice = binary_erosion(binary_slice, structure=kernel).astype(np.uint8)
+            eroded_slice = torch.from_numpy(
+                binary_erosion(binary_slice.cpu().numpy()).astype(np.uint8))
+
+            # Subtract the eroded image from the original to get the contour
+            contour_slice = binary_slice - eroded_slice.to(binary_slice.device)
+
+            # Store the contour slice in the contour mask
+            contour_mask[0, :, :, i] = contour_slice
+
+        return contour_mask
+
+    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
+        return data
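A quick sanity check of the transform on a synthetic mask (illustrative only; with the fixed implementation above, a filled square should reduce to its one-voxel-thick border):

```python
import torch

# Synthetic (1, H, W, D) binary mask: a filled 6x6 square on every slice.
mask = torch.zeros(1, 16, 16, 4)
mask[0, 5:11, 5:11, :] = 1

contour = SpinalCordContourd(keys=["label"])({"label": mask})["label"]
# Erosion removes the interior, so only the border remains:
# a 6x6 square has 36 - 16 = 20 border voxels per slice.
assert contour.sum() == 20 * 4
```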
 
 def train_transforms(crop_size, lbl_key="label", pad_mode="zero", device="cuda"):
@@ -23,8 +78,8 @@ def train_transforms(crop_size, lbl_key="label", pad_mode="zero", device="cuda"):
                                 scale_range=(-0.2, 0.2), translate_range=(-0.1, 0.1)),
         transforms.Rand3DElasticd(keys=["image", lbl_key], prob=0.5,
-                                  sigma_range=(3.5, 5.5),
-                                  magnitude_range=(25., 35.)),
+                                  sigma_range=(3.5, 5.5), magnitude_range=(25., 35.),),
+                                  # mode=(2, 1), padding_mode="border",),
         transforms.RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25),
         transforms.RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5),    # this is monai's RandomGamma
         transforms.RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3),
@@ -33,29 +88,14 @@
         transforms.RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15),   # this is nnUNet's BrightnessMultiplicativeTransform
         transforms.RandFlipd(keys=["image", lbl_key], prob=0.3,),
         transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False),
+        # # select one of: spinal cord contour transform or Identity transform (i.e. no transform)
+        # transforms.OneOf(
+        #     transforms=[SpinalCordContourd(keys=["label"]), transforms.Identityd(keys=["label"])],
+        #     weights=[0.25, 0.75]
+        # )
     ]
 
-    batchgenerators_transforms = [
-        bg_spatial_transforms.ChannelTranslation(
-            data_key="image",
-            const_channel=5,
-            max_shifts={'x': 5, 'y': 5, 'z': 5})
-    ]
-
-    # add batchgenerators transforms
-    transforms_final = monai_transforms + [
-        # add another dim as BatchGenerator expects shape [B, C, H, W, D]
-        transforms.EnsureChannelFirstd(keys=["image", lbl_key], channel_dim="no_channel"),
-        # batchgenerators transforms work on numpy arrays
-        transforms.ToNumpyd(keys=["image", lbl_key]),
-        # use adaptors to port batchgenerators transforms to monai-compatible transforms
-        transforms.adaptor(batchgenerators_transforms[0], {"image": "image", "label": f"{lbl_key}"}),
-        # convert the data back to Tensor
-        transforms.EnsureTyped(keys=["image", lbl_key], device=device, track_meta=False),
-        transforms.SqueezeDimd(keys=[f"{lbl_key}"], dim=0),
-    ]
-
-    return transforms.Compose(transforms_final)
+    return transforms.Compose(monai_transforms)
 
 def inference_transforms(crop_size, lbl_key="label"):
     return transforms.Compose([
diff --git a/monai/utils.py b/monai/utils.py
index ef5a4d5..2e0e43b 100644
--- a/monai/utils.py
+++ b/monai/utils.py
@@ -4,19 +4,19 @@
 import torch
 import subprocess
 import os
-import json
 import pandas as pd
 import re
 import importlib
 import pkgutil
 
 from batchgenerators.utilities.file_and_folder_operations import *
+from image import Image
 
 CONTRASTS = {
     "t1map": ["T1map"],
     "mp2rage": ["inv-1_part-mag_MP2RAGE", "inv-2_part-mag_MP2RAGE"],
-    "t1w": ["T1w", "space-other_T1w"],
-    "t2w": ["T2w", "space-other_T2w"],
+    "t1w": ["T1w", "space-other_T1w", "acq-lowresSag_T1w"],
+    "t2w": ["T2w", "space-other_T2w", "acq-lowresSag_T2w", "acq-highresSag_T2w"],
     "t2star": ["T2star", "space-other_T2star"],
     "dwi": ["rec-average_dwi", "acq-dwiMean_dwi"],
     "mt-on": ["flip-1_mt-on_space-other_MTS", "acq-MTon_MTR"],
@@ -26,67 +26,367 @@
     "stir": ["STIR"]
 }
 
-def get_datasets_stats(datalists_root, contrasts_dict, path_save):
-
-    datalists = [file for file in os.listdir(datalists_root) if file.endswith('_seed50.json')]
-    df = pd.DataFrame(columns=['train', 'validation', 'test'])
-    # collect all the contrasts from the datalists
-    for datalist in datalists:
-        json_path = os.path.join(datalists_root, datalist)
-        with open(json_path, 'r') as f:
-            data = json.load(f)
-        data = data['numImagesPerContrast'].items()
+def get_image_stats(image_path):
+    """
+    This function takes an image file as input and returns its shape, orientation and resolution.
 
-        for split, contrast_info in data:
-            for contrast, num_images in contrast_info.items():
-                # add the contrast and the number of images to the dataframe
-                if contrast not in df.index:
-                    df.loc[contrast] = [0, 0, 0]
-                df.loc[contrast, split] += num_images
+    Input:
+        image_path : str : Path to the image file
 
-    # reshape dataframe and add a column for contrast
-    df = df.reset_index()
-    df = df.rename(columns={'index': 'contrast'})
+    Returns:
+        shape : list : Shape of the image (in RPI)
+        orientation : str : Orientation of the image (isotropic, sagittal, coronal or axial)
+        resolution : ndarray : Voxel sizes in mm, rounded to 2 decimals
+    """
+    img = Image(str(image_path))
+    img.change_orientation('RPI')
+    shape = img.dim[:3]
+    shape = [int(s) for s in shape]
+    # Get pixdim
+    pixdim = img.dim[4:7]
+    # If all are the same, the image is isotropic
+    if np.allclose(pixdim, pixdim[0], atol=1e-3):
+        orientation = 'isotropic'
+    # Elif the largest spacing is along axis 0, the orientation is sagittal
+    elif np.argmax(pixdim) == 0:
+        orientation = 'sagittal'
+    # Elif the largest spacing is along axis 1, the orientation is coronal
+    elif np.argmax(pixdim) == 1:
+        orientation = 'coronal'
+    # Else the orientation is axial
+    else:
+        orientation = 'axial'
+    resolution = np.round(pixdim, 2)
+    return shape, orientation, resolution
 
-    contrasts_final = list(contrasts_dict.keys())
-    # rename the contrasts column as per contrasts_final
-    for c in df['contrast'].unique():
-        for cf in contrasts_final:
-            if re.search(cf, c.lower()):
-                df.loc[df['contrast'] == c, 'contrast'] = cf
-                break
 
+def get_pathology_wise_split(unified_df):
 
-    # NOTE: MTon-MTR is same as flip-1_mt-on_space-other_MTS, but the naming is not mt-on
-    # so doing the renaming manually
-    df.loc[df['contrast'] == 'acq-MTon_MTR', 'contrast'] = 'mt-on'
-
-    # sum the duplicate contrasts
-    df = df.groupby('contrast').sum().reset_index()
+    # ===========================================================================
+    #                 Subject-wise Pathology split
+    # ===========================================================================
+    pathologies = unified_df['pathologyID'].unique()
+
+    # count the number of subjects for each pathology
+    pathology_subjects = {}
+    for pathology in pathologies:
+        pathology_subjects[pathology] = len(unified_df[unified_df['pathologyID'] == pathology]['subjectID'].unique())
+
+    # merge MildCompression, DCM, MildCompression/DCM into DCM
+    pathology_subjects['DCM'] = pathology_subjects['MildCompression'] + pathology_subjects['MildCompression/DCM'] + pathology_subjects['DCM']
+    pathology_subjects.pop('MildCompression', None)
+    pathology_subjects.pop('MildCompression/DCM', None)
+
+    # ===========================================================================
+    #                 Contrast-wise Pathology split
+    # ===========================================================================
+    # for a given contrast, count the number of images for each pathology
+    pathology_contrasts = {}
+    for contrast in CONTRASTS.keys():
+        pathology_contrasts[contrast] = {}
+        # initialize the count for each pathology
+        pathology_contrasts[contrast] = {pathology: 0 for pathology in pathologies}
+        for pathology in pathologies:
+            pathology_contrasts[contrast][pathology] += len(unified_df[(unified_df['pathologyID'] == pathology) & (unified_df['contrastID'] == contrast)]['filename'])
+
+    # merge MildCompression, DCM, MildCompression/DCM into DCM
+    for contrast in pathology_contrasts.keys():
+        pathology_contrasts[contrast]['DCM'] = pathology_contrasts[contrast]['MildCompression'] + pathology_contrasts[contrast]['MildCompression/DCM'] + pathology_contrasts[contrast]['DCM']
+        pathology_contrasts[contrast].pop('MildCompression', None)
+        pathology_contrasts[contrast].pop('MildCompression/DCM', None)
+
+    return pathology_subjects, pathology_contrasts
-    # rename columns
-    df = df.rename(columns={'contrast': 'Contrast', 'train': '#train_images', 'validation': '#validation_images', 'test': '#test_images'})
-    # add a row for the total number of images
-    df.loc[len(df)] = ['TOTAL', df['#train_images'].sum(), df['#validation_images'].sum(), df['#test_images'].sum()]
-    # print(df.to_markdown(index=False))
-    df.to_markdown(os.path.join(path_save, 'dataset_split.md'), index=False)
 
-    # # total number of images for each split
-    # print(f"Total number of training images: {df['#train_images'].sum()}")
-    # print(f"Total number of validation images: {df['#validation_images'].sum()}")
-    # print(f"Total number of test images: {df['#test_images'].sum()}")
+def plot_contrast_wise_pathology(df, path_save):
+    # remove the TOTAL row
+    df = df[:-1]
+    # remove the #total_per_contrast column
+    df = df.drop(columns=['#total_per_contrast'])
+
+    color_palette = {
+        'HC': '#55A868',
+        'MS': '#2B4373',
+        'RRMS': '#6A89C8',
+        'PPMS': '#88A1E0',
+        'SPMS': '#3B5A92',
+        'RIS': '#4C72B0',
+        'DCM': '#DD8452',
+        'SCI': '#C44E52',
+        'NMO': '#937860',
+        'SYR': '#b3b3b3',
+        'ALS': '#DA8BC3',
+        'LBP': '#CCB974'
+    }
+
+    contrasts = df.index.tolist()
+
+    # plot a pie chart for each contrast and save as different file
+    # NOTE: some pathologies with fewer subjects were overlapping so this is a hacky (and bad) way to fix this
+    # issue temporarily by reordering the columns of the df
+    for contrast in contrasts:
+        df_contrast = df.loc[[contrast]].T
+        # reorder the columns to move 'ALS' to the front, before 'HC' and 'MS'
+        if contrast in ['dwi']:
+            df_contrast = df_contrast.reindex(['ALS', 'HC', 'MS', 'DCM', 'SCI', 'NMO', 'RRMS', 'PPMS', 'SPMS', 'RIS', 'LBP', 'SYR'])
+        elif contrast in ['unit1']:
+            # reorder the columns to put 'PPMS' between 'MS' and 'RRMS'
+            df_contrast = df_contrast.reindex(['HC', 'MS', 'PPMS', 'RRMS', 'SPMS', 'RIS', 'DCM', 'SCI', 'NMO', 'ALS', 'LBP', 'SYR'])
+        elif contrast in ['t2star']:
+            df_contrast = df_contrast.reindex(['HC', 'ALS', 'MS', 'DCM', 'SCI', 'NMO', 'RRMS', 'PPMS', 'SPMS', 'RIS', 'LBP', 'SYR'])
+
+        # for the given contrast, remove columns (pathologies) with 0 images
+        df_contrast = df_contrast[df_contrast[contrast] != 0]
+
+        # adapted from https://matplotlib.org/stable/gallery/pie_and_polar_charts/pie_and_donut_labels.html
+        fig, ax = plt.subplots(figsize=(6.3, 3.5), subplot_kw=dict(aspect="equal"))   # Increased figure size
+        wedges, texts = ax.pie(
+            df_contrast[contrast],
+            wedgeprops=dict(width=0.5),
+            startangle=-40,
+            colors=[color_palette[pathology] for pathology in df_contrast.index],
+        )
+
+        # Annotation customization
+        bbox_props = dict(boxstyle="square,pad=0.5", fc="w", ec="k", lw=0.72)
+        kw = dict(arrowprops=dict(arrowstyle="-"), bbox=bbox_props, zorder=0, va="center")
+        texts_to_adjust = []    # collect all annotations for adjustment
+
+        for i, p in enumerate(wedges):
+            ang = (p.theta2 - p.theta1)/2. 
+ p.theta1 + y = np.sin(np.deg2rad(ang)) + x = np.cos(np.deg2rad(ang)) + horizontalalignment = {-1: "right", 1: "left"}[int(np.sign(x))] + connectionstyle = f"angle,angleA=0,angleB={ang}" + kw["arrowprops"].update({"connectionstyle": connectionstyle}) + # font size + kw["fontsize"] = 14.5 + # bold font + kw["fontweight"] = 'bold' + + # Skip annotation for 'SYR' + if df_contrast.index[i] == 'SYR': + continue + + # Push small labels further away from pie + distance = 1.1 + # for dwi contrast and sci pathology, plot the annotation to the left + if contrast == 'dwi' and df_contrast.index[i] == 'SCI': + distance = 1.5 + horizontalalignment = 'right' + if df_contrast.index[i] == 'ALS': + distance = 1.2 + horizontalalignment = 'left' + if contrast == 't2w' and df_contrast.index[i] in ['RIS', 'ALS', 'PPMS']: + if df_contrast.index[i] != 'PPMS': + distance = 1.4 + horizontalalignment = 'left' + else: + distance = 1 + horizontalalignment = 'right' + if contrast == 't1w' and df_contrast.index[i] == 'LBP': + distance = 1.3 + horizontalalignment = 'left' + + # Annotate with number of images per pathology + text = f"{df_contrast.index[i]} (n={df_contrast.iloc[i, 0]})" + annotation = ax.annotate(text, xy=(x, y), xytext=(distance*np.sign(x)*1.05, distance*y), + horizontalalignment=horizontalalignment, **kw) + texts_to_adjust.append(annotation) + + plt.ylabel('') + plt.tight_layout() + plt.savefig(os.path.join(path_save, f'{contrast}_pathology_split.png'), dpi=300) + plt.close() + + +def parse_spacing(spacing_str): + # Remove brackets and split by spaces + spacing_values = re.findall(r"[\d.]+", spacing_str) + # Convert to float + return [float(val) for val in spacing_values] + + +def get_datasets_stats(datalists_root, contrasts_dict, path_save): # create a unified dataframe combining all datasets - csvs = [os.path.join(datalists_root, file) for file in os.listdir(datalists_root) if file.endswith('.csv')] + csvs = [os.path.join(datalists_root, file) for file in os.listdir(datalists_root) if file.endswith('_seed50.csv')] unified_df = pd.concat([pd.read_csv(csv) for csv in csvs], ignore_index=True) # sort the dataframe by the dataset column unified_df = unified_df.sort_values(by='datasetName', ascending=True) - # save as csv + # save the originals as the csv unified_df.to_csv(os.path.join(path_save, 'dataset_contrast_agnostic.csv'), index=False) + # dropna + unified_df = unified_df.dropna(subset=['pathologyID']) + + contrasts_final = list(contrasts_dict.keys()) + # rename the contrasts column as per contrasts_final + for c in unified_df['contrastID'].unique(): + for cf in contrasts_final: + if re.search(cf, c.lower()): + unified_df.loc[unified_df['contrastID'] == c, 'contrastID'] = cf + break + + # NOTE: MTon-MTR is same as flip-1_mt-on_space-other_MTS, but the naming is not mt-on + # so doing the renaming manually + unified_df.loc[unified_df['contrastID'] == 'acq-MTon_MTR', 'contrastID'] = 'mt-on' + + # convert 'spacing' column from a string like "[1. 1. 
1.]" to a list of floats + unified_df['spacing'] = unified_df['spacing'].apply(parse_spacing) + # for contrast in contrasts_final: + # print(f"Max resolution for {contrast}:") + # print([unified_df[unified_df['contrastID'] == contrast]['spacing'].apply(lambda x: x[i]).max() for i in range(3)]) + + splits = ['train', 'validation', 'test'] + # count the number of images per contrast + df = pd.DataFrame(columns=['contrast', 'train', 'validation', 'test']) + for contrast in contrasts_final: + df.loc[len(df)] = [contrast, 0, 0, 0] + # count the number of images per split + for split in splits: + df.loc[df['contrast'] == contrast, split] = len(unified_df[(unified_df['contrastID'] == contrast) & (unified_df['split'] == split)]) + + # sort the dataframe by the contrast column + df = df.sort_values(by='contrast', ascending=True) + # add a row for the total number of images + df.loc[len(df)] = ['TOTAL', df['train'].sum(), df['validation'].sum(), df['test'].sum()] + # add a column for total number of images per contrast + df['#images_per_contrast'] = df['train'] + df['validation'] + df['test'] + + df_mega = pd.DataFrame() + for orientation in ['sagittal', 'axial', 'isotropic']: + + # get the median resolutions per contrast + df_res_median = pd.DataFrame(columns=[f'contrast', 'x', 'y', 'z']) + df_res_min = pd.DataFrame(columns=[f'contrast', 'x', 'y', 'z']) + df_res_max = pd.DataFrame(columns=[f'contrast', 'x', 'y', 'z']) + df_res_mean = pd.DataFrame(columns=[f'contrast', 'x', 'y', 'z']) + + for contrast in contrasts_final: + + if len(unified_df[(unified_df['contrastID'] == contrast) & (unified_df['imgOrientation'] == orientation)]) == 0: + # set NA values for the contrast + df_res_median.loc[len(df_res_median)] = [f'{contrast}_{orientation}'] + [np.nan, np.nan, np.nan] + df_res_min.loc[len(df_res_min)] = [f'{contrast}_{orientation}'] + [np.nan, np.nan, np.nan] + df_res_max.loc[len(df_res_max)] = [f'{contrast}_{orientation}'] + [np.nan, np.nan, np.nan] + df_res_mean.loc[len(df_res_mean)] = [f'{contrast}_{orientation}'] + [np.nan, np.nan, np.nan] + + else: + # median + df_res_median.loc[len(df_res_median)] = [f'{contrast}_{orientation}'] + [ + unified_df[(unified_df['contrastID'] == contrast) & (unified_df['imgOrientation'] == orientation) + ]['spacing'].apply(lambda x: x[i]).median() for i in range(3)] + # min + df_res_min.loc[len(df_res_min)] = [f'{contrast}_{orientation}'] + [ + unified_df[(unified_df['contrastID'] == contrast) & (unified_df['imgOrientation'] == orientation) + ]['spacing'].apply(lambda x: x[i]).min() for i in range(3)] + # max + df_res_max.loc[len(df_res_max)] = [f'{contrast}_{orientation}'] + [ + unified_df[(unified_df['contrastID'] == contrast) & (unified_df['imgOrientation'] == orientation) + ]['spacing'].apply(lambda x: x[i]).max() for i in range(3)] + # mean + df_res_mean.loc[len(df_res_mean)] = [f'{contrast}_{orientation}'] + [ + unified_df[(unified_df['contrastID'] == contrast) & (unified_df['imgOrientation'] == orientation) + ]['spacing'].apply(lambda x: x[i]).mean().round(2) for i in range(3)] + + # drop rows with NA values + df_res_median = df_res_median.dropna() + df_res_min = df_res_min.dropna() + df_res_max = df_res_max.dropna() + df_res_mean = df_res_mean.dropna() + + # combine the x,y,z columns into a single column and drop the x,y,z columns + df_res_median['median_resolution_rpi'] = df_res_median.apply(lambda x: f"{x['x']} x {x['y']} x {x['z']}", axis=1) + df_res_median = df_res_median.drop(columns=['x', 'y', 'z']) + + df_res_min['min_resolution_rpi'] = 
df_res_min.apply(lambda x: f"{x['x']} x {x['y']} x {x['z']}", axis=1) + df_res_min = df_res_min.drop(columns=['x', 'y', 'z']) + + df_res_max['max_resolution_rpi'] = df_res_max.apply(lambda x: f"{x['x']} x {x['y']} x {x['z']}", axis=1) + df_res_max = df_res_max.drop(columns=['x', 'y', 'z']) + + df_res_mean['mean_resolution_rpi'] = df_res_mean.apply(lambda x: f"{x['x']} x {x['y']} x {x['z']}", axis=1) + df_res_mean = df_res_mean.drop(columns=['x', 'y', 'z']) + + # combine the dataframes based on the contrast column + df_res = pd.merge(df_res_median, df_res_min, on='contrast') + df_res = pd.merge(df_res, df_res_max, on='contrast') + df_res = pd.merge(df_res, df_res_mean, on='contrast') + + # sort the dataframe by the contrast column + df_res = df_res.sort_values(by='contrast', ascending=True) + + # concatenate the dataframes for different orientations on columns + df_mega = pd.concat([df_mega, df_res], axis=0) + + + # get the subject-wise pathology split + pathology_subjects, pathology_contrasts = get_pathology_wise_split(unified_df) + df_pathology = pd.DataFrame.from_dict(pathology_subjects, orient='index', columns=['Number of Subjects']) + # rename index to Pathology + df_pathology.index.name = 'Pathology' + # sort the dataframe by the pathology column + df_pathology = df_pathology.sort_index() + # add a row for the total number of subjects + df_pathology.loc['TOTAL'] = df_pathology['Number of Subjects'].sum() + + + # get the contrast-wise pathology split + df_contrast_pathology = pd.DataFrame.from_dict(pathology_contrasts, orient='index') + # sort the dataframe by the contrast column + df_contrast_pathology = df_contrast_pathology.sort_index() + # add a row for the total number of images + df_contrast_pathology.loc['TOTAL'] = df_contrast_pathology.sum() + # add a column for the total number of images per contrast + df_contrast_pathology['#total_per_contrast'] = df_contrast_pathology.sum(axis=1) + # print(df_contrast_pathology) + + # plots + save_path = os.path.join(path_save, 'plots') + os.makedirs(save_path, exist_ok=True) + plot_contrast_wise_pathology(df_contrast_pathology, save_path) + # exit() + + # sort the csvs list + csvs = sorted(csvs) + + # create a txt file + with open(os.path.join(path_save, 'dataset_stats_overall.txt'), 'w') as f: + # 1. write the datalists used in a bullet list + f.write(f"DATASETS USED FOR MODEL TRAINING (n={len(csvs)}):\n") + for csv in csvs: + # only write the dataset name + f.write(f"\t- {csv.split('_')[1]}\n") + f.write("\n") + + # 2. write the subject-wise pathology split + f.write(f"\nSUBJECT-WISE PATHOLOGY SPLIT:\n\n") + f.write(df_pathology.to_markdown()) + f.write("\n\n\n") + + # 3. write the contrast-wise pathology split (a subject can have multiple contrasts) + f.write(f"CONTRAST-WISE PATHOLOGY SPLIT (a subject can have multiple contrasts):\n\n") + f.write(df_contrast_pathology.to_markdown()) + f.write("\n\n\n") + + # 4. write the train/validation/test split per contrast + f.write(f"SPLITS ACROSS DIFFERENT CONTRASTS (n={len(contrasts_final)}):\n\n") + f.write(df.to_markdown(index=False)) + f.write("\n\n\n") + + # 5. 
write the median, min, max and mean resolutions per contrast + f.write(f"RESOLUTIONS PER CONTRAST PER ORIENTATION (in mm^3):\n\n") + f.write(f"How to interpret the table: Each row corresponds to the contrast and its orientation with the median, min, max and mean resolutions in mm^3.\n") + f.write(f"For simplification, if a contrast does not have any images in a particular orientation in the dataset, then the row is not present in the table.\n") + f.write(f"For e.g. if you want to report the mean (min, max) resolution of a contrast, say, 'dwi_axial', \n") + f.write(f"then you pick the respective element in each of the columns.\n") + f.write(f"\t i.e. mean in-plane resolution: 0.89x0.89; range: (0.34x0.34, 1x1), and, likewise, Slice thickness: 5.1; range: (4, 17.5)\n\n") + f.write(df_mega.to_markdown(index=False)) + f.write("\n\n") + + # Taken from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/find_class_by_name.py def recursive_find_python_class(folder: str, class_name: str, current_module: str): @@ -167,80 +467,6 @@ def get_git_branch_and_commit(dataset_path=None): return branch, commit -def numeric_score(prediction, groundtruth): - """Computation of statistical numerical scores: - - * FP = Soft False Positives - * FN = Soft False Negatives - * TP = Soft True Positives - * TN = Soft True Negatives - - Robust to hard or soft input masks. For example:: - prediction=np.asarray([0, 0.5, 1]) - groundtruth=np.asarray([0, 1, 1]) - Leads to FP = 1.5 - - Note: It assumes input values are between 0 and 1. - - Args: - prediction (ndarray): Binary prediction. - groundtruth (ndarray): Binary groundtruth. - - Returns: - float, float, float, float: FP, FN, TP, TN - """ - FP = float(np.sum(prediction * (1.0 - groundtruth))) - FN = float(np.sum((1.0 - prediction) * groundtruth)) - TP = float(np.sum(prediction * groundtruth)) - TN = float(np.sum((1.0 - prediction) * (1.0 - groundtruth))) - return FP, FN, TP, TN - - -def precision_score(prediction, groundtruth, err_value=0.0): - """Positive predictive value (PPV). - - Precision equals the number of true positive voxels divided by the sum of true and false positive voxels. - True and false positives are computed on soft masks, see ``"numeric_score"``. - Taken from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/metrics.py - - Args: - prediction (ndarray): First array. - groundtruth (ndarray): Second array. - err_value (float): Value returned in case of error. - - Returns: - float: Precision score. - """ - FP, FN, TP, TN = numeric_score(prediction, groundtruth) - if (TP + FP) <= 0.0: - return err_value - - precision = np.divide(TP, TP + FP) - return precision - - -def recall_score(prediction, groundtruth, err_value=0.0): - """True positive rate (TPR). - - Recall equals the number of true positive voxels divided by the sum of true positive and false negative voxels. - True positive and false negative values are computed on soft masks, see ``"numeric_score"``. - Taken from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/metrics.py - - Args: - prediction (ndarray): First array. - groundtruth (ndarray): Second array. - err_value (float): Value returned in case of error. - - Returns: - float: Recall score. - """ - FP, FN, TP, TN = numeric_score(prediction, groundtruth) - if (TP + FN) <= 0.0: - return err_value - TPR = np.divide(TP, TP + FN) - return TPR - - def dice_score(prediction, groundtruth): smooth = 1. 
@@ -294,17 +520,6 @@ def plot_slices(image, gt, pred, debug=False):
     return fig
 
-def compute_average_csa(patch, spacing):
-    num_slices = patch.shape[2]
-    areas = torch.empty(num_slices)
-    for slice_idx in range(num_slices):
-        slice_mask = patch[:, :, slice_idx]
-        area = torch.count_nonzero(slice_mask) * (spacing[0] * spacing[1])
-        areas[slice_idx] = area
-
-    return torch.mean(areas)
-
-
 class PolyLRScheduler(_LRScheduler):
     """
     Polynomial learning rate scheduler. Taken from:
@@ -338,6 +553,9 @@ def step(self, current_step=None):
     # tr_ix, val_tx, te_ix, fold = names_list[0]
     # print(len(tr_ix), len(val_tx), len(te_ix))
 
-    # datalists_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/lifelong-contrast-agnostic"
-    datalists_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/aggregation-20240517"
-    get_datasets_stats(datalists_root, contrasts_dict=CONTRASTS, path_save=datalists_root)
\ No newline at end of file
+    datalists_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/v2-final-aggregation-20241022"
+    get_datasets_stats(datalists_root, contrasts_dict=CONTRASTS, path_save=datalists_root)
+    # get_pathology_wise_split(datalists_root, path_save=datalists_root)
+
+    # img_path = "/home/GRAMES.POLYMTL.CA/u114716/datasets/sci-colorado/sub-5694/anat/sub-5694_T2w.nii.gz"
+    # get_image_stats(img_path)
\ No newline at end of file
diff --git a/qc_other_datasets/generate_qc.sh b/qc_other_datasets/generate_qc.sh
index 99157ab..be5e84a 100644
--- a/qc_other_datasets/generate_qc.sh
+++ b/qc_other_datasets/generate_qc.sh
@@ -199,17 +199,20 @@ segment_sc_MONAI(){
     #     PATH_MODEL=${PATH_MONAI_MODEL_BIN}
     # fi
 
-    if [[ $model == 'soft_monai_single' ]]; then
+    if [[ $model == 'v2x' ]]; then
         FILEPRED="${file}_seg_${model}"
         PATH_MODEL=${PATH_MONAI_MODEL_1}
+        max_feat=320
 
-    elif [[ $model == 'soft_monai_2datasets' ]]; then
+    elif [[ $model == 'v2x_contour' ]]; then
        FILEPRED="${file}_seg_${model}"
         PATH_MODEL=${PATH_MONAI_MODEL_2}
+        max_feat=384
 
-    elif [[ $model == 'soft_monai_7datasets' ]]; then
+    elif [[ $model == 'v2x_contour_dcm' ]]; then
        FILEPRED="${file}_seg_${model}"
         PATH_MODEL=${PATH_MONAI_MODEL_3}
+        max_feat=384
 
     # elif [[ $model == 'swinunetr' ]]; then
     #     FILEPRED="${file}_seg_swinunetr"
@@ -221,7 +224,7 @@ segment_sc_MONAI(){
     start_time=$(date +%s)
     # Run SC segmentation
     # python ${PATH_MONAI_SCRIPT} --path-img ${file}.nii.gz --path-out . --chkp-path ${PATH_MODEL} --device gpu --model ${model}
-    python ${PATH_MONAI_SCRIPT} --path-img ${file}.nii.gz --path-out . --chkp-path ${PATH_MODEL} --device gpu --model monai --pred-type soft --pad-mode edge
+    python ${PATH_MONAI_SCRIPT} --path-img ${file}.nii.gz --path-out . --chkp-path ${PATH_MODEL} --device gpu --model monai --pred-type soft --pad-mode edge --max-feat ${max_feat}
     # Rename MONAI output
     mv ${file}_pred.nii.gz ${FILEPRED}.nii.gz
     # Get the end time
@@ -230,8 +233,8 @@
     execution_time=$(python3 -c "print($end_time - $start_time)")
     echo "${FILEPRED},${execution_time}" >> ${PATH_RESULTS}/execution_time.csv
 
-    # Generate QC report on soft predictions
-    sct_qc -i ${file}.nii.gz -s ${FILEPRED}.nii.gz -p sct_deepseg_sc -qc ${PATH_QC} -qc-subject ${SUBJECT}
+    # # Generate QC report on soft predictions
+    # sct_qc -i ${file}.nii.gz -s ${FILEPRED}.nii.gz -p sct_deepseg_sc -qc ${PATH_QC} -qc-subject ${SUBJECT}
 
     # Binarize MONAI output (which is soft by default); output is overwritten
     sct_maths -i ${FILEPRED}.nii.gz -bin 0.5 -o ${FILEPRED}.nii.gz
@@ -294,13 +297,14 @@ echo "Contrast: ${contrast}"
 
 # Copy source images
 # check if the file exists
-if [[ ! -e ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.* ]]; then
-    echo "File ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.* does not exist" >> ${PATH_LOG}/missing_files.log
-    echo "ERROR: File ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.* does not exist. Exiting."
-    exit 1
-else
-    rsync -Ravzh ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.* .
-fi
+# if [[ ! -e ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.nii.gz ]]; then
+#     echo "${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.nii.gz"
+#     echo "File ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.* does not exist" >> ${PATH_LOG}/missing_files.log
+#     echo "ERROR: File ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.* does not exist. Exiting."
+#     exit 1
+# else
+rsync -Ravzh ${PATH_DATA}/./${SUBJECT}/anat/${SUBJECT//[\/]/_}*${contrast}.* .
+# fi
 
 # Go to the folder where the data is
 cd ${PATH_DATA_PROCESSED}/${SUBJECT}/anat
@@ -322,9 +326,9 @@ copy_gt_seg "${file}" "${label_suffix}"
 sct_qc -i ${file}.nii.gz -s ${file}_seg-manual.nii.gz -p sct_deepseg_sc -qc ${PATH_QC} -qc-subject ${SUBJECT}
 
 # Segment SC using different methods, binarize at 0.5 and compute QC
-# CUDA_VISIBLE_DEVICES=1 segment_sc_MONAI ${file} 'soft_monai_single'
-# CUDA_VISIBLE_DEVICES=2 segment_sc_MONAI ${file} 'soft_monai_2datasets'
-CUDA_VISIBLE_DEVICES=0 segment_sc_MONAI ${file} 'soft_monai_7datasets'
+CUDA_VISIBLE_DEVICES=2 segment_sc_MONAI ${file} 'v2x'
+CUDA_VISIBLE_DEVICES=2 segment_sc_MONAI ${file} 'v2x_contour'
+CUDA_VISIBLE_DEVICES=3 segment_sc_MONAI ${file} 'v2x_contour_dcm'
 
 # segment_sc_MONAI ${file} 'monai'
 # segment_sc_MONAI ${file} 'swinunetr'
@@ -343,22 +347,22 @@ json_dict='{
     ]
 }'
 
-PATH_DATA_PROCESSED_CLEAN="${PATH_DATA_PROCESSED}_clean"
-# create new folder and copy only the predictions
-mkdir -p ${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat
-
-rsync -avzh ${file}_seg_soft_monai_4datasets.nii.gz ${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat/${file%%_*}_space-other_${contrast}_desc-softseg_label-SC_seg.nii.gz
-rsync -avzh ${file}_seg-manual.json ${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat/${file%%_*}_space-other_${contrast}_desc-softseg_label-SC_seg.json
-
-# create json file
-echo $json_dict > ${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat/${file%%_*}_space-other_${contrast}_desc-softseg_label-SC_seg.json
-# re-save json files with indentation
-python -c "import json;
-json_file = '${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat/${file%%_*}_space-other_${contrast}_desc-softseg_label-SC_seg.json'
-with open(json_file, 'r') as f:
-    data = json.load(f)
-    json.dump(data, open(json_file, 'w'), indent=4)
-"
+# PATH_DATA_PROCESSED_CLEAN="${PATH_DATA_PROCESSED}_clean"
+# # create new folder and copy only the predictions
+# mkdir -p ${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat
+
+# rsync -avzh ${file}_seg_soft_monai_4datasets.nii.gz ${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat/${file%%_*}_space-other_${contrast}_desc-softseg_label-SC_seg.nii.gz
+# rsync -avzh ${file}_seg-manual.json ${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat/${file%%_*}_space-other_${contrast}_desc-softseg_label-SC_seg.json
+
+# # create json file
+# echo $json_dict > ${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat/${file%%_*}_space-other_${contrast}_desc-softseg_label-SC_seg.json
+# # re-save json files with indentation
+# python -c "import json;
+# json_file = '${PATH_DATA_PROCESSED_CLEAN}/derivatives/labels_softseg_bin/${SUBJECT}/anat/${file%%_*}_space-other_${contrast}_desc-softseg_label-SC_seg.json'
+# with open(json_file, 'r') as f:
+#     data = json.load(f)
+#     json.dump(data, open(json_file, 'w'), indent=4)
+# "
 
 # ------------------------------------------------------------------------------
 # End
diff --git a/qc_other_datasets/generate_qc_datasets_gt.sh b/qc_other_datasets/generate_qc_datasets_gt.sh
index 5ba02b4..2c23dc2 100644
--- a/qc_other_datasets/generate_qc_datasets_gt.sh
+++ b/qc_other_datasets/generate_qc_datasets_gt.sh
@@ -73,6 +73,14 @@ elif [[ $QC_DATASET == "canproco" ]]; then
     contrasts="PSIR STIR T2w"
     label_suffix="seg-manual"
 
+elif [[ $QC_DATASET == "spider-challenge-2023" ]]; then
contrasts="acq-lowresSag_T1w acq-lowresSag_T2w acq-highresSag_T2w" + label_suffix="label-SC_seg" + +elif [[ $QC_DATASET == "bavaria-quebec-spine-ms-unstitched" ]]; then + contrasts="acq-ax_chunk-1_T2w" + label_suffix="seg-manual" + fi PATH_DERIVATIVES="${PATH_DATA}/derivatives/labels/./${SUBJECT}/anat" @@ -83,6 +91,7 @@ for contrast in ${contrasts}; do # NOTE: this replacement is cool because it automatically takes care of 'ses-XX' for longitudinal data file="${SUBJECT//[\/]/_}_${contrast}" + echo "Processing file: ${file}" # check if label exists in the dataset if [[ ! -f ${PATH_DERIVATIVES}/${file}_${label_suffix}.nii.gz ]]; then diff --git a/scripts/batch_create_datalists.sh b/scripts/batch_create_datalists.sh new file mode 100644 index 0000000..47157bb --- /dev/null +++ b/scripts/batch_create_datalists.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# Unified script for creating datalist jsons + +# seed +SEED=50 + +# script path +PATH_SCRIPT="/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/contrast-agnostic-softseg-spinalcord" + +# list of datasets to train on +DATASETS=("basel-mp2rage" "canproco" "data-multi-subject" \ + "dcm-brno" "dcm-zurich" "dcm-zurich-lesions" "dcm-zurich-lesions-20231115" \ + "lumbar-epfl" "lumbar-vanderbilt" "nih-ms-mp2rage" \ + "sci-colorado" "sci-paris" "sci-zurich" \ + "sct-testing-large" "spider-challenge-2023") + +# base root path for the datasets +PATH_DATA_BASE="/home/GRAMES.POLYMTL.CA/u114716/datasets" + +# output path +folder_name=v2-final-aggregation-$(date +"%Y%m%d") +PATH_OUTPUT="/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/${folder_name}" + +# create datalists! + +for dataset in ${DATASETS[@]}; do + + echo "-----------------------------------" + echo "Creating datalist for ${dataset}" + echo "-----------------------------------" + + python ${PATH_SCRIPT}/monai/create_msd_data.py \ + --seed ${SEED} \ + --path-data ${PATH_DATA_BASE}/${dataset} \ + --path-out ${PATH_OUTPUT} +done + +echo "-----------------------------------" +echo "Done!" +echo "Datalists created in ${PATH_OUTPUT}" +echo "-----------------------------------" \ No newline at end of file diff --git a/scripts/combine_datasets_stats.py b/scripts/combine_datasets_stats.py new file mode 100644 index 0000000..1abed8a --- /dev/null +++ b/scripts/combine_datasets_stats.py @@ -0,0 +1,102 @@ +""" +Loop across csv files containing sequence parameters (one for each dataset) and combine the averaged details +into a single csv file.: + +Assumes that the input folder contains the CSV file (parsed_data.csv) for datasets included in the training of +the contrast-agnostic model. + +Example usage: + python scripts/combine_dataset_stats.py -i /path/to/folder/sequence_parameters + +Author: Naga Karthik + +""" + +import os +import re +import glob +import json +import argparse +from tqdm import tqdm + +import numpy as np +import pandas as pd +import nibabel as nib + + +def get_parser(): + """ + parser function + """ + + parser = argparse.ArgumentParser( + description='Loop across CSV files in the input folder and parse from them relevant information.', + prog=os.path.basename(__file__).strip('.py') + ) + parser.add_argument( + '-i', required=True, type=str, + help='Path to the folder containing the CSV files containing the sequence parameters for the datasets. ' + 'typically the output of fetch_sequence_parameters.py script.' 
+    )
+
+    return parser
+
+
+def main():
+    # Parse the command line arguments
+    parser = get_parser()
+    args = parser.parse_args()
+
+    list_of_csvs = glob.glob(os.path.join(args.i, '*.csv'), recursive=True)
+
+    # Create a pandas DataFrame from the parsed data
+    df = pd.DataFrame()
+    for csv_file in list_of_csvs:
+        temp_dict = {}
+
+        csv_path = csv_file    # glob already returns the path including args.i
+        print(f"Processing {csv_path}")
+        df_temp = pd.read_csv(csv_path)
+
+        dataset_name = csv_path.split('/')[-1].split('_')[0]
+        # Add the dataset name to the new dataframe
+        temp_dict['Dataset'] = dataset_name
+
+        # get the unique values of the contrast
+        temp_dict['Contrast'] = list(df_temp['Contrast'].unique())
+
+        # get the unique values of "MagneticFieldStrength"
+        scanners = []
+        scanner_strengths = df_temp['MagneticFieldStrength'].unique()
+        # remove nan values
+        scanner_strengths = [x for x in scanner_strengths if str(x) != 'nan']
+        for scanner in scanner_strengths:
+            scanners.append(f"{scanner}T")
+        temp_dict['Scanners'] = scanners
+
+        # get the unique values of "Manufacturer"
+        temp_dict['Manufacturer'] = list(df_temp['Manufacturer'].unique())
+
+        # get the min-max ranges for PixDim along each dimension
+        # NOTE: 'PixDim' is read back from CSV as strings, so min/max below compare
+        # lexicographically, not numerically (see the sketch after this script)
+        temp_dict['PixDim_min'] = df_temp['PixDim'].min()
+        temp_dict['PixDim_max'] = df_temp['PixDim'].max()
+
+        # get the min-max ranges for SliceThickness
+        temp_dict['SliceThickness_min'] = df_temp['SliceThickness'].min()
+        temp_dict['SliceThickness_max'] = df_temp['SliceThickness'].max()
+
+        temp_dict['Authors'] = 'n/a'
+
+        # add the dictionary to the dataframe
+        df_dataset = pd.DataFrame.from_dict(temp_dict, orient='index').T
+
+        df = pd.concat([df, df_dataset])
+
+
+    # Save the DataFrame to a CSV file
+    out_path = os.path.join(args.i, 'combined_datasets_stats.csv')
+    df.to_csv(out_path, index=False)
+
+
+if __name__ == "__main__":
+    main()
+
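# A short sketch of the PixDim caveat noted in main() above: after the CSV round-trip the
# values are strings, so a numeric min requires parsing first (toy, unrealistic values
# chosen to make the string/number mismatch visible):
import numpy as np
import pandas as pd

pixdims = pd.Series(["[2.0, 2.0]", "[10.0, 10.0]"])   # 'PixDim' as read back by pd.read_csv
print(pixdims.min())                                   # -> '[10.0, 10.0]' (string order: '1' < '2')
parsed = pixdims.apply(lambda v: np.fromstring(v.strip('[]'), sep=','))
print(min(a.min() for a in parsed))                    # -> 2.0 (numeric minimum)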
diff --git a/scripts/fetch_sequence_parameters.py b/scripts/fetch_sequence_parameters.py
new file mode 100644
index 0000000..1bc714b
--- /dev/null
+++ b/scripts/fetch_sequence_parameters.py
@@ -0,0 +1,229 @@
+"""
+Loop across the JSON sidecar files and nii headers of all images listed in the datalist JSONs and parse from them
+the following information:
+    MagneticFieldStrength
+    Manufacturer
+    ManufacturerModelName
+    ProtocolName
+    PixDim
+    SliceThickness
+
+If a JSON sidecar is not available, fetch only PixDim and SliceThickness from the nii header.
+
+The fetched information is saved to per-dataset CSV files ({dataset}_parsed_data.csv).
+
+Example usage:
+    python scripts/fetch_sequence_parameters.py --path-datalists /path/to/datalists
+
+Original Author: Jan Valosek
+
+Adapted from:
+https://github.com/ivadomed/model_seg_sci/blob/1079d156c322f555cab38c846240ab936ba98afb/utils/fetch_sequence_parameters.py
+
+Adapted by: Naga Karthik
+"""
+
+import os
+import re
+import json
+import argparse
+from tqdm import tqdm
+
+import numpy as np
+import pandas as pd
+import nibabel as nib
+from loguru import logger
+
+LIST_OF_PARAMETERS = [
+    'MagneticFieldStrength',
+    'Manufacturer',
+    'ManufacturerModelName',
+    'ProtocolName',
+    'RepetitionTime',
+    'EchoTime',
+    'InversionTime',
+    ]
+
+
+def get_parser():
+    """
+    parser function
+    """
+
+    parser = argparse.ArgumentParser(
+        description='Loop across JSON sidecar files in the input path and parse from them relevant information.',
+        prog=os.path.basename(__file__).strip('.py')
+    )
+    parser.add_argument(
+        '--path-datalists', required=True, type=str,
+        help='Path to the folder with the datalist JSONs listing the subjects in the train/val/test splits '
+             '(i.e. the output of the create_msd_data.py script). '
+             'Output is stored in args.path_datalists/sequence_parameters.'
+    )
+
+    return parser
+
+
+def parse_json_file(file_path):
+    """
+    Read the JSON file and parse from it relevant information.
+    :param file_path:
+    :return:
+    """
+
+    file_path = file_path.replace('.nii.gz', '.json')
+
+    # Read the JSON file, return dict with n/a if the file is empty or missing
+    try:
+        with open(file_path) as f:
+            data = json.load(f)
+    except Exception:
+        print(f'WARNING: {file_path} is empty or missing.')
+        return {param: "n/a" for param in LIST_OF_PARAMETERS}
+
+    # Initialize an empty dictionary to store the parsed information
+    parsed_info = {}
+
+    if 'sci-zurich' in file_path:
+        # For sci-zurich, the JSON metadata is nested under 'acqpar' (a list of dicts); use the first entry
+        data = data['acqpar'][0]
+    elif 'sci-colorado' in file_path:
+        # sci-colorado JSONs are already flat; nothing to unpack
+        data = data
+
+    # Loop across the parameters
+    for param in LIST_OF_PARAMETERS:
+        try:
+            parsed_info[param] = data[param]
+        except Exception:
+            parsed_info[param] = "n/a"
+
+    return parsed_info
+
+
+def parse_nii_file(file_path):
+    """
+    Read the nii file header using nibabel to get PixDim and SliceThickness.
+    We are doing this because 'PixelSpacing' and 'SliceThickness' can be missing from the JSON file.
+    :param file_path:
+    :return:
+    """
+
+    _, contrast_id = fetch_participant_details(file_path)
+
+    # Read the nii file, return dict with n/a if the file is empty
+    try:
+        img = nib.load(file_path)
+        header = img.header
+    except Exception:
+        print(f'WARNING: {file_path} is empty. Did you run git-annex get .?')
+        return {param: "n/a" for param in ['PixDim', 'SliceThickness']}
+
+    # Initialize an empty dictionary to store the parsed information
+    parsed_info = {
+        'Contrast': contrast_id,
+        'PixDim': list(header['pixdim'][1:3]),
+        'SliceThickness': float(header['pixdim'][3])
+    }
+
+    return parsed_info
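# A quick, self-contained check of the header fields parse_nii_file relies on
# (synthetic in-memory image; the zooms below are made up):
import numpy as np
import nibabel as nib

img = nib.Nifti1Image(np.zeros((4, 4, 4), dtype=np.float32), affine=np.eye(4))
img.header.set_zooms((0.8, 0.8, 5.0))
print(list(img.header['pixdim'][1:3]), float(img.header['pixdim'][3]))  # [0.8, 0.8] 5.0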
+
+
+def fetch_participant_details(input_string):
+    """
+    Fetch the participant_id and contrast_id from the input string
+    :param input_string: input string or path, e.g. 'sub-5416_T2w_seg_nnunet'
+    :return participant_id: subject id, e.g. 'sub-5416'
+    """
+    participant = re.search('sub-(.*?)[_/]', input_string)    # [_/] slash or underscore
+    participant_id = participant.group(0)[:-1] if participant else ""    # [:-1] removes the last underscore or slash
+
+    if 'data-multi-subject' in input_string:
+        # NOTE: the preprocessed spine-generic dataset has an unusual BIDS naming convention (due to how it was preprocessed)
+        contrast_pattern = r'.*_(space-other_T1w|space-other_T2w|space-other_T2star|flip-1_mt-on_space-other_MTS|flip-2_mt-off_space-other_MTS|rec-average_dwi).*'
+    else:
+        # TODO: add more contrasts as needed
+        # contrast_pattern = r'.*_(T1w|T2w|T2star|PSIR|STIR|UNIT1|acq-MTon_MTR|acq-dwiMean_dwi|acq-b0Mean_dwi|acq-T1w_MTR).*'
+        contrast_pattern = r'.*_(T1w|acq-lowresSag_T1w|T2w|acq-lowresSag_T2w|acq-highresSag_T2w|T2star|PSIR|STIR|UNIT1|T1map|inv-1_part-mag_MP2RAGE|inv-2_part-mag_MP2RAGE|acq-MTon_MTR|acq-dwiMean_dwi|acq-T1w_MTR).*'
+    contrast = re.search(contrast_pattern, input_string)
+    contrast_id = contrast.group(1) if contrast else ""
+
+    return participant_id, contrast_id
+
+
+def main():
+    # Parse the command line arguments
+    parser = get_parser()
+    args = parser.parse_args()
+
+    datalists = [os.path.join(args.path_datalists, file) for file in os.listdir(args.path_datalists) if file.endswith('_seed50.json')]
+    out_path = os.path.join(args.path_datalists, 'sequence_parameters')
+    if not os.path.exists(out_path):
+        os.makedirs(out_path, exist_ok=True)
+
+    logger.add(os.path.join(out_path, 'sequence_params.log'), rotation='10 MB', level='INFO')
+    logger.info(f"Saving parsed data to {out_path}")
+
+    for datalist in datalists:
+        dataset_name = datalist.split('/')[-1].split('_')[1]
+
+        if not os.path.exists(datalist):
+            logger.error(f'{datalist} does not exist. Run the create_msd_data.py script first.')
+            continue
+
+        # load json file
+        with open(datalist, 'r') as f:
+            data = json.load(f)
+
+        list_of_files = []
+        for split in ['train', 'validation', 'test']:
+            for idx in range(len(data[split])):
+                list_of_files.append(data[split][idx]["image"])
+
+        # Initialize an empty list to store the parsed data
+        parsed_data = []
+
+        # Loop across JSON sidecar files in the input path
+        for file in tqdm(list_of_files):
+            # print(f'Parsing {file}')
+            parsed_json = parse_json_file(file)
+            parsed_header = parse_nii_file(file)
+            # NOTE: **parsed_json and **parsed_header unpack the parsed key-value pairs into a single row
+            parsed_data.append({'filename': file, **parsed_json, **parsed_header})
+
+        # Create a pandas DataFrame from the parsed data
+        df = pd.DataFrame(parsed_data)
+
+        df['filename'] = df['filename'].apply(lambda x: x.replace('/home/GRAMES.POLYMTL.CA/u114716/datasets/', ''))
+
+        # Save the DataFrame to a CSV file
+        df.to_csv(os.path.join(out_path, f'{dataset_name}_parsed_data.csv'), index=False)
+        logger.info(f"Parsed data saved to {os.path.join(out_path, f'{dataset_name}_parsed_data.csv')}")
+
+        # # For sci-paris, we do not have JSON sidecars --> we can fetch only PixDim and SliceThickness from nii header
+        # if 'sci-paris' in dir_path:
+        #     # Print the min and max values of the PixDim, and SliceThickness
+        #     print(df[['PixDim', 'SliceThickness']].agg([np.min, np.max]))
+        # else:
+
+        logger.info(f"Dataset: {dataset_name}")
+
+        # Remove rows with n/a values for MagneticFieldStrength
+        df = df[df['MagneticFieldStrength'] != 'n/a']
+
+        # Convert MagneticFieldStrength to float
+        df['MagneticFieldStrength'] = df['MagneticFieldStrength'].astype(float)
+
+        # Print the min and max values of the MagneticFieldStrength, PixDim, and SliceThickness
+        logger.info(f"\n{df[['MagneticFieldStrength', 'PixDim', 'SliceThickness']].agg(['min', 'max'])}")
+
+        # Print unique values of the Manufacturer and ManufacturerModelName
+        logger.info(f"\n{df[['Manufacturer', 'ManufacturerModelName']].drop_duplicates()}")
+        # # Print number of filenames for unique values of the Manufacturer
+        # print(df.groupby('Manufacturer')['filename'].nunique())
+        # # Print number of filenames for unique values of the MagneticFieldStrength
+        # print(df.groupby('MagneticFieldStrength')['filename'].nunique())
+
+        logger.info('')
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
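# A standalone sanity check of the participant-id regex used in fetch_participant_details
# above (hypothetical filename):
import re

fname = 'sub-5416/anat/sub-5416_T2w.nii.gz'
participant = re.search('sub-(.*?)[_/]', fname)
print(participant.group(0)[:-1] if participant else "")  # -> 'sub-5416'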
diff --git a/scripts/generate_tables.py b/scripts/generate_tables.py
new file mode 100644
index 0000000..9d00b39
--- /dev/null
+++ b/scripts/generate_tables.py
@@ -0,0 +1,104 @@
+import os, re
+import numpy as np
+import pandas as pd
+
+
+CONTRASTS = {
+    "t1map": ["T1map"],
+    "mp2rage": ["inv-1_part-mag_MP2RAGE", "inv-2_part-mag_MP2RAGE"],
+    "t1w": ["T1w", "space-other_T1w", "acq-lowresSag_T1w"],
+    "t2w": ["T2w", "space-other_T2w", "acq-lowresSag_T2w", "acq-highresSag_T2w"],
+    "t2star": ["T2star", "space-other_T2star"],
+    "dwi": ["rec-average_dwi", "acq-dwiMean_dwi"],
+    "mt-on": ["flip-1_mt-on_space-other_MTS", "acq-MTon_MTR"],
+    "mt-off": ["flip-2_mt-off_space-other_MTS"],
+    "unit1": ["UNIT1"],
+    "psir": ["PSIR"],
+    "stir": ["STIR"]
+}
+
+
+def generate_table(df, path_save):
+    contrast_stats = {}
+
+    # replace 'coronal' with 'axial' in imgOrientation
+    df.loc[df['imgOrientation'] == 'coronal', 'imgOrientation'] = 'axial'
+
+    # get unique imgOrientations
+    img_orientations = df['imgOrientation'].unique()
+
+    for contrast in CONTRASTS.keys():
+        contrast_stats[contrast] = {}
+
+        for orientation in img_orientations:
+            contrast_stats[contrast][orientation] = {"n": 0, "spacing_min": [], "spacing_max": [], "size_min": [], "size_max": []}
+
+            # get the number of images with the contrast and orientation
+            contrast_stats[contrast][orientation]['n'] = len(df[(df['contrastID'] == contrast) & (df['imgOrientation'] == orientation)])
+
+            if contrast_stats[contrast][orientation]['n'] == 0:
+                # if there are no images with this contrast and orientation, remove the key
+                del contrast_stats[contrast][orientation]
+                continue
+
+            # create a temp list to store the spacings and sizes of the images
+            all_spacings, all_sizes = [], []
+
+            for i in range(contrast_stats[contrast][orientation]['n']):
+                # get the spacings and sizes of the images
+                all_spacings.append(df[(df['contrastID'] == contrast) & (df['imgOrientation'] == orientation)]['spacing'].iloc[i])
+                all_sizes.append(df[(df['contrastID'] == contrast) & (df['imgOrientation'] == orientation)]['shape'].iloc[i])
+
+            # Convert the list of strings to a numpy array
+            all_spacings = np.array([np.fromstring(s.strip('[]'), sep=' ') for s in all_spacings])
+            all_sizes = np.array([np.fromstring(s.strip('[]'), sep=',') for s in all_sizes])
+
+            # get the min and max spacings across the respective dimensions
+            contrast_stats[contrast][orientation]['spacing_min'] = np.min(all_spacings, axis=0)
+            contrast_stats[contrast][orientation]['spacing_max'] = np.max(all_spacings, axis=0)
+
+            # get the min and max sizes across the respective dimensions
+            contrast_stats[contrast][orientation]['size_min'] = np.min(all_sizes, axis=0)
+            contrast_stats[contrast][orientation]['size_max'] = np.max(all_sizes, axis=0)
+
+    # create a dataframe from contrast_stats
+    df_img_stats = pd.DataFrame.from_dict({(i, j): contrast_stats[i][j]
+                                           for i in contrast_stats.keys()
+                                           for j in contrast_stats[i].keys()},)
+    df_img_stats = df_img_stats.T
+    # NOTE: the assembled table is currently only printed; path_save is not used yet
+    print(df_img_stats)
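# A toy illustration of the nested-dict -> MultiIndex pattern that generate_table uses to
# build df_img_stats (made-up numbers):
import pandas as pd

stats = {
    't2w': {'sagittal': {'n': 10, 'spacing_min': [0.5, 0.5, 0.5]}},
    'dwi': {'axial': {'n': 4, 'spacing_min': [0.9, 0.9, 4.0]}},
}
df_toy = pd.DataFrame.from_dict({(c, o): stats[c][o] for c in stats for o in stats[c]}).T
print(df_toy)  # rows indexed by (contrast, orientation) pairs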
+
+
+def main(datalists_root, contrasts_dict, path_save):
+
+    # create a unified dataframe combining all datasets
+    csvs = [os.path.join(datalists_root, file) for file in os.listdir(datalists_root) if file.endswith('_seed50.csv')]
+    unified_df = pd.concat([pd.read_csv(csv) for csv in csvs], ignore_index=True)
+
+    # sort the dataframe by the dataset column
+    unified_df = unified_df.sort_values(by='datasetName', ascending=True)
+
+    # dropna
+    unified_df = unified_df.dropna(subset=['pathologyID'])
+
+    contrasts_final = list(contrasts_dict.keys())
+    # rename the contrasts column as per contrasts_final
+    for c in unified_df['contrastID'].unique():
+        for cf in contrasts_final:
+            if re.search(cf, c.lower()):
+                unified_df.loc[unified_df['contrastID'] == c, 'contrastID'] = cf
+                break
+
+    # NOTE: acq-MTon_MTR is the same contrast as flip-1_mt-on_space-other_MTS, but its name
+    # does not match the 'mt-on' pattern above, so it is renamed manually
+    unified_df.loc[unified_df['contrastID'] == 'acq-MTon_MTR', 'contrastID'] = 'mt-on'
+
+    generate_table(unified_df, path_save)
+
+
+if __name__ == "__main__":
+
+    datalists_root = "/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/v2-final-aggregation-20241022"
+    main(datalists_root, CONTRASTS, datalists_root)
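# A toy check of the contrast-renaming loop in main() above (made-up contrast IDs;
# contrast_keys is a hypothetical subset of CONTRASTS.keys(), in the same order):
import re

contrast_keys = ['t1w', 't2w', 't2star', 'dwi']
for c in ['acq-lowresSag_T2w', 'space-other_T2star', 'rec-average_dwi']:
    for cf in contrast_keys:
        if re.search(cf, c.lower()):
            print(f"{c} -> {cf}")  # e.g. acq-lowresSag_T2w -> t2w
            break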