diff --git a/src/postbibsnet.py b/src/postbibsnet.py index 8cccaaa..9c7460e 100755 --- a/src/postbibsnet.py +++ b/src/postbibsnet.py @@ -8,6 +8,7 @@ import numpy as np import json from scipy import ndimage +import csv from src.logger import LOGGER @@ -55,6 +56,13 @@ def run_postBIBSnet(j_args): else: # if j_args["ID"]["has_T2w"]: t1or2 = 2 + LOGGER.info("Generating crude L/R mask for first iteration of chirality correction") + # Generate crude chirality correction mask file first + crude_left_right_mask_nifti_fpath = create_crude_LR_mask( + sub_ses, j_args + ) + + LOGGER.info("Generating L/R mask from registration using templates for second iteration of chirality correction") # Run left/right registration script and chirality correction left_right_mask_nifti_fpath = run_left_right_registration( sub_ses, tmpl_age, t1or2, j_args @@ -68,7 +76,11 @@ def run_postBIBSnet(j_args): left_right_mask_nifti_fpath ) LOGGER.info("Finished dilating left/right segmentation mask") - nifti_file_paths, chiral_out_dir, xfm_ref_img_dict = run_correct_chirality(dilated_LRmask_fpath, j_args) + + LOGGER.info("Running chirality correction") + nifti_file_paths, chiral_out_dir, xfm_ref_img_dict = run_correct_chirality(crude_left_right_mask_nifti_fpath, dilated_LRmask_fpath, j_args) + + LOGGER.info("Reverting corrected segmentation to native space") for t in only_Ts_needed_for_bibsnet_model(j_args["ID"]): nii_outfpath = reverse_regn_revert_to_native( nifti_file_paths, chiral_out_dir, xfm_ref_img_dict[t], t, j_args @@ -110,9 +122,13 @@ def run_postBIBSnet(j_args): return j_args + # Write j_args out to logs + LOGGER.debug(j_args) -def run_correct_chirality(l_r_mask_nifti_fpath, j_args): +def run_correct_chirality(crude_l_r_mask_nifti_fpath, l_r_mask_nifti_fpath, j_args): """ + :param crude_l_r_mask_nifti_fpath: String, valid path to existing crude left/right + output mask file :param l_r_mask_nifti_fpath: String, valid path to existing left/right registration output mask file :param j_args: Dictionary containing all args @@ -149,17 +165,69 @@ def run_correct_chirality(l_r_mask_nifti_fpath, j_args): chiral_ref_img_fpaths.sort() chiral_ref_img_fpaths_dict[t] = chiral_ref_img_fpaths[0] - # Run chirality correction script and return the image to native space + # Run chirality correction first using the crude LR mask applied to the segmentation output from nnUNet in the BIBSNet stage msg = "{} running chirality correction on " + seg_BIBSnet_outfiles[0] LOGGER.info(msg.format("Now")) nii_fpaths = correct_chirality( seg_BIBSnet_outfiles[0], segment_lookup_table_path, - l_r_mask_nifti_fpath, chiral_out_dir + crude_l_r_mask_nifti_fpath, chiral_out_dir, 1 + ) + + # Run chirality correction a second time using the refined LR mask generated from registration with template files applied to the segmentation corrected with the crude LR mask + msg = "{} running chirality correction on " + nii_fpaths["crudecorrected"] + LOGGER.info(msg.format("Now")) + nii_fpaths = correct_chirality( + nii_fpaths["crudecorrected"], segment_lookup_table_path, + l_r_mask_nifti_fpath, chiral_out_dir, 2 ) + LOGGER.info(msg.format("Finished")) return nii_fpaths, chiral_out_dir, chiral_ref_img_fpaths_dict +def create_crude_LR_mask(sub_ses, j_args): + # Define paths to dirs/files used in chirality correction script + outdir_LR_reg = os.path.join(j_args["optional_out_dirs"]["postbibsnet"], + *sub_ses) + os.makedirs(outdir_LR_reg, exist_ok=True) + + chiral_out_dir = os.path.join(j_args["optional_out_dirs"]["postbibsnet"], + *sub_ses, "chirality_correction") # subj_ID, session, + os.makedirs(chiral_out_dir, exist_ok=True) + + # Get BIBSnet output file, and if there are multiple, then raise an error + out_BIBSnet_seg = os.path.join(j_args["optional_out_dirs"]["bibsnet"], + *sub_ses, "output", "*.nii.gz") + seg_BIBSnet_outfiles = glob(out_BIBSnet_seg) + if len(seg_BIBSnet_outfiles) != 1: + LOGGER.error(f"There must be exactly one BIBSnet segmentation file: " + "{}\nResume at postBIBSnet stage once this is fixed." + .format(out_BIBSnet_seg)) + sys.exit() + + crude_left_right_mask_nifti_fpath = os.path.join(outdir_LR_reg, "crude_LRmask.nii.gz") + + img = nib.load(seg_BIBSnet_outfiles[0]) + data = img.get_fdata() + affine = img.affine + + # Determine the midpoint of X-axis and make new image + midpoint_x = data.shape[0] // 2 + modified_data = np.zeros_like(data) + + # Assign value 1 to right-side voxels with values greater than 0 value 2 to left-side voxels with values greater than 0 (note that these actually correspond to left and right brain hemispheres respectively) + modified_data[midpoint_x:, :, :][data[midpoint_x:, :, :] > 0] = 1 + modified_data[:midpoint_x, :, :][data[:midpoint_x, :, :] > 0] = 2 + + #nib.save(img, seg_BIBSnet_outfiles[0]) + save_nifti(modified_data, affine, crude_left_right_mask_nifti_fpath) + + return crude_left_right_mask_nifti_fpath + +def save_nifti(data, affine, file_path): + img = nib.Nifti1Image(data, affine) + nib.save(img, file_path) + def run_left_right_registration(sub_ses, age_months, t1or2, j_args): """ @@ -272,7 +340,7 @@ def copy_to_derivatives_dir(file_to_copy, derivs_dir, sub_ses, space, new_fname_ def correct_chirality(nifti_input_file_path, segment_lookup_table, - nii_fpath_LR_mask, chiral_out_dir): + nii_fpath_LR_mask, chiral_out_dir, iteration): """ Creates an output file with chirality corrections fixed. :param nifti_input_file_path: String, path to a segmentation file with @@ -283,42 +351,82 @@ def correct_chirality(nifti_input_file_path, segment_lookup_table, :param xfm_ref_img: String, path to (T1w, unless running in T2w-only mode) image to use as a reference when applying transform :param j_args: Dictionary containing all args + :param iteration: either 1 or 2 for iteration1 or iteration2 of chirality correction :return: Dict with paths to native and chirality-corrected images """ - nifti_file_paths = dict() - for which_nii in ("native-T1", "native-T2", "corrected"): - nifti_file_paths[which_nii] = os.path.join(chiral_out_dir, "_".join(( - which_nii, os.path.basename(nifti_input_file_path) - ))) + if iteration==1: + nifti_file_paths = dict() + for which_nii in ("native-T1", "native-T2", "crudecorrected"): + nifti_file_paths[which_nii] = os.path.join(chiral_out_dir, "_".join(( + which_nii, os.path.basename(nifti_input_file_path) + ))) + + free_surfer_label_to_region = get_id_to_region_mapping(segment_lookup_table) + segment_name_to_number = {v: k for k, v in free_surfer_label_to_region.items()} + img = nib.load(nifti_input_file_path) + data = img.get_data() + left_right_img = nib.load(nii_fpath_LR_mask) + left_right_data = left_right_img.get_data() + + new_data = data.copy() + data_shape = img.header.get_data_shape() + left_right_data_shape = left_right_img.header.get_data_shape() + width = data_shape[0] + height = data_shape[1] + depth = data_shape[2] + assert \ + width == left_right_data_shape[0] and height == left_right_data_shape[1] and depth == left_right_data_shape[2] + for i in range(width): + for j in range(height): + for k in range(depth): + voxel = data[i][j][k] + region = free_surfer_label_to_region[voxel] + chirality_voxel = int(left_right_data[i][j][k]) + if not (region.startswith(LEFT) or region.startswith(RIGHT)): + continue + if chirality_voxel == CHIRALITY_CONST["LEFT"] or chirality_voxel == CHIRALITY_CONST["RIGHT"]: + check_and_correct_region( + chirality_voxel == CHIRALITY_CONST["LEFT"], region, segment_name_to_number, new_data, i, j, k) + fixed_img = nib.Nifti1Image(new_data, img.affine, img.header) + nib.save(fixed_img, nifti_file_paths["crudecorrected"]) + + elif iteration==2: + # Drop "crudecorrected_" from nifti_input_file_path to make filenames cleaner + nifti_input_file_path_mod=(os.path.basename(nifti_input_file_path)).split('_', 1)[1] + nifti_file_paths = dict() + for which_nii in ("native-T1", "native-T2", "corrected"): + nifti_file_paths[which_nii] = os.path.join(chiral_out_dir, "_".join(( + which_nii, nifti_input_file_path_mod + ))) - free_surfer_label_to_region = get_id_to_region_mapping(segment_lookup_table) - segment_name_to_number = {v: k for k, v in free_surfer_label_to_region.items()} - img = nib.load(nifti_input_file_path) - data = img.get_data() - left_right_img = nib.load(nii_fpath_LR_mask) - left_right_data = left_right_img.get_data() - - new_data = data.copy() - data_shape = img.header.get_data_shape() - left_right_data_shape = left_right_img.header.get_data_shape() - width = data_shape[0] - height = data_shape[1] - depth = data_shape[2] - assert \ - width == left_right_data_shape[0] and height == left_right_data_shape[1] and depth == left_right_data_shape[2] - for i in range(width): - for j in range(height): - for k in range(depth): - voxel = data[i][j][k] - region = free_surfer_label_to_region[voxel] - chirality_voxel = int(left_right_data[i][j][k]) - if not (region.startswith(LEFT) or region.startswith(RIGHT)): - continue - if chirality_voxel == CHIRALITY_CONST["LEFT"] or chirality_voxel == CHIRALITY_CONST["RIGHT"]: - check_and_correct_region( - chirality_voxel == CHIRALITY_CONST["LEFT"], region, segment_name_to_number, new_data, i, j, k) - fixed_img = nib.Nifti1Image(new_data, img.affine, img.header) - nib.save(fixed_img, nifti_file_paths["corrected"]) + free_surfer_label_to_region = get_id_to_region_mapping(segment_lookup_table) + segment_name_to_number = {v: k for k, v in free_surfer_label_to_region.items()} + img = nib.load(nifti_input_file_path) + data = img.get_data() + left_right_img = nib.load(nii_fpath_LR_mask) + left_right_data = left_right_img.get_data() + + new_data = data.copy() + data_shape = img.header.get_data_shape() + left_right_data_shape = left_right_img.header.get_data_shape() + width = data_shape[0] + height = data_shape[1] + depth = data_shape[2] + assert \ + width == left_right_data_shape[0] and height == left_right_data_shape[1] and depth == left_right_data_shape[2] + for i in range(width): + for j in range(height): + for k in range(depth): + voxel = data[i][j][k] + region = free_surfer_label_to_region[voxel] + chirality_voxel = int(left_right_data[i][j][k]) + if not (region.startswith(LEFT) or region.startswith(RIGHT)): + continue + if chirality_voxel == CHIRALITY_CONST["LEFT"] or chirality_voxel == CHIRALITY_CONST["RIGHT"]: + check_and_correct_region( + chirality_voxel == CHIRALITY_CONST["LEFT"], region, segment_name_to_number, new_data, i, j, k) + fixed_img = nib.Nifti1Image(new_data, img.affine, img.header) + nib.save(fixed_img, nifti_file_paths["corrected"]) return nifti_file_paths