Skip to content

Commit

Permalink
Merge pull request #99 from DCAN-Labs/expand-lr-mask
Browse files Browse the repository at this point in the history
Iterative chirality correction
  • Loading branch information
tjhendrickson authored Feb 19, 2024
2 parents bfaf8d6 + 0a80704 commit 9cafd56
Showing 1 changed file with 146 additions and 38 deletions.
184 changes: 146 additions & 38 deletions src/postbibsnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import json
from scipy import ndimage
import csv

from src.logger import LOGGER

Expand Down Expand Up @@ -55,6 +56,13 @@ def run_postBIBSnet(j_args):
else: # if j_args["ID"]["has_T2w"]:
t1or2 = 2

LOGGER.info("Generating crude L/R mask for first iteration of chirality correction")
# Generate crude chirality correction mask file first
crude_left_right_mask_nifti_fpath = create_crude_LR_mask(
sub_ses, j_args
)

LOGGER.info("Generating L/R mask from registration using templates for second iteration of chirality correction")
# Run left/right registration script and chirality correction
left_right_mask_nifti_fpath = run_left_right_registration(
sub_ses, tmpl_age, t1or2, j_args
Expand All @@ -68,7 +76,11 @@ def run_postBIBSnet(j_args):
left_right_mask_nifti_fpath
)
LOGGER.info("Finished dilating left/right segmentation mask")
nifti_file_paths, chiral_out_dir, xfm_ref_img_dict = run_correct_chirality(dilated_LRmask_fpath, j_args)

LOGGER.info("Running chirality correction")
nifti_file_paths, chiral_out_dir, xfm_ref_img_dict = run_correct_chirality(crude_left_right_mask_nifti_fpath, dilated_LRmask_fpath, j_args)

LOGGER.info("Reverting corrected segmentation to native space")
for t in only_Ts_needed_for_bibsnet_model(j_args["ID"]):
nii_outfpath = reverse_regn_revert_to_native(
nifti_file_paths, chiral_out_dir, xfm_ref_img_dict[t], t, j_args
Expand Down Expand Up @@ -110,9 +122,13 @@ def run_postBIBSnet(j_args):

return j_args

# Write j_args out to logs
LOGGER.debug(j_args)

def run_correct_chirality(l_r_mask_nifti_fpath, j_args):
def run_correct_chirality(crude_l_r_mask_nifti_fpath, l_r_mask_nifti_fpath, j_args):
"""
:param crude_l_r_mask_nifti_fpath: String, valid path to existing crude left/right
output mask file
:param l_r_mask_nifti_fpath: String, valid path to existing left/right
registration output mask file
:param j_args: Dictionary containing all args
Expand Down Expand Up @@ -149,17 +165,69 @@ def run_correct_chirality(l_r_mask_nifti_fpath, j_args):
chiral_ref_img_fpaths.sort()
chiral_ref_img_fpaths_dict[t] = chiral_ref_img_fpaths[0]

# Run chirality correction script and return the image to native space
# Run chirality correction first using the crude LR mask applied to the segmentation output from nnUNet in the BIBSNet stage
msg = "{} running chirality correction on " + seg_BIBSnet_outfiles[0]
LOGGER.info(msg.format("Now"))
nii_fpaths = correct_chirality(
seg_BIBSnet_outfiles[0], segment_lookup_table_path,
l_r_mask_nifti_fpath, chiral_out_dir
crude_l_r_mask_nifti_fpath, chiral_out_dir, 1
)

# Run chirality correction a second time using the refined LR mask generated from registration with template files applied to the segmentation corrected with the crude LR mask
msg = "{} running chirality correction on " + nii_fpaths["crudecorrected"]
LOGGER.info(msg.format("Now"))
nii_fpaths = correct_chirality(
nii_fpaths["crudecorrected"], segment_lookup_table_path,
l_r_mask_nifti_fpath, chiral_out_dir, 2
)

LOGGER.info(msg.format("Finished"))

return nii_fpaths, chiral_out_dir, chiral_ref_img_fpaths_dict

def create_crude_LR_mask(sub_ses, j_args):
# Define paths to dirs/files used in chirality correction script
outdir_LR_reg = os.path.join(j_args["optional_out_dirs"]["postbibsnet"],
*sub_ses)
os.makedirs(outdir_LR_reg, exist_ok=True)

chiral_out_dir = os.path.join(j_args["optional_out_dirs"]["postbibsnet"],
*sub_ses, "chirality_correction") # subj_ID, session,
os.makedirs(chiral_out_dir, exist_ok=True)

# Get BIBSnet output file, and if there are multiple, then raise an error
out_BIBSnet_seg = os.path.join(j_args["optional_out_dirs"]["bibsnet"],
*sub_ses, "output", "*.nii.gz")
seg_BIBSnet_outfiles = glob(out_BIBSnet_seg)
if len(seg_BIBSnet_outfiles) != 1:
LOGGER.error(f"There must be exactly one BIBSnet segmentation file: "
"{}\nResume at postBIBSnet stage once this is fixed."
.format(out_BIBSnet_seg))
sys.exit()

crude_left_right_mask_nifti_fpath = os.path.join(outdir_LR_reg, "crude_LRmask.nii.gz")

img = nib.load(seg_BIBSnet_outfiles[0])
data = img.get_fdata()
affine = img.affine

# Determine the midpoint of X-axis and make new image
midpoint_x = data.shape[0] // 2
modified_data = np.zeros_like(data)

# Assign value 1 to right-side voxels with values greater than 0 value 2 to left-side voxels with values greater than 0 (note that these actually correspond to left and right brain hemispheres respectively)
modified_data[midpoint_x:, :, :][data[midpoint_x:, :, :] > 0] = 1
modified_data[:midpoint_x, :, :][data[:midpoint_x, :, :] > 0] = 2

#nib.save(img, seg_BIBSnet_outfiles[0])
save_nifti(modified_data, affine, crude_left_right_mask_nifti_fpath)

return crude_left_right_mask_nifti_fpath

def save_nifti(data, affine, file_path):
img = nib.Nifti1Image(data, affine)
nib.save(img, file_path)


def run_left_right_registration(sub_ses, age_months, t1or2, j_args):
"""
Expand Down Expand Up @@ -272,7 +340,7 @@ def copy_to_derivatives_dir(file_to_copy, derivs_dir, sub_ses, space, new_fname_


def correct_chirality(nifti_input_file_path, segment_lookup_table,
nii_fpath_LR_mask, chiral_out_dir):
nii_fpath_LR_mask, chiral_out_dir, iteration):
"""
Creates an output file with chirality corrections fixed.
:param nifti_input_file_path: String, path to a segmentation file with
Expand All @@ -283,42 +351,82 @@ def correct_chirality(nifti_input_file_path, segment_lookup_table,
:param xfm_ref_img: String, path to (T1w, unless running in T2w-only mode)
image to use as a reference when applying transform
:param j_args: Dictionary containing all args
:param iteration: either 1 or 2 for iteration1 or iteration2 of chirality correction
:return: Dict with paths to native and chirality-corrected images
"""
nifti_file_paths = dict()
for which_nii in ("native-T1", "native-T2", "corrected"):
nifti_file_paths[which_nii] = os.path.join(chiral_out_dir, "_".join((
which_nii, os.path.basename(nifti_input_file_path)
)))
if iteration==1:
nifti_file_paths = dict()
for which_nii in ("native-T1", "native-T2", "crudecorrected"):
nifti_file_paths[which_nii] = os.path.join(chiral_out_dir, "_".join((
which_nii, os.path.basename(nifti_input_file_path)
)))

free_surfer_label_to_region = get_id_to_region_mapping(segment_lookup_table)
segment_name_to_number = {v: k for k, v in free_surfer_label_to_region.items()}
img = nib.load(nifti_input_file_path)
data = img.get_data()
left_right_img = nib.load(nii_fpath_LR_mask)
left_right_data = left_right_img.get_data()

new_data = data.copy()
data_shape = img.header.get_data_shape()
left_right_data_shape = left_right_img.header.get_data_shape()
width = data_shape[0]
height = data_shape[1]
depth = data_shape[2]
assert \
width == left_right_data_shape[0] and height == left_right_data_shape[1] and depth == left_right_data_shape[2]
for i in range(width):
for j in range(height):
for k in range(depth):
voxel = data[i][j][k]
region = free_surfer_label_to_region[voxel]
chirality_voxel = int(left_right_data[i][j][k])
if not (region.startswith(LEFT) or region.startswith(RIGHT)):
continue
if chirality_voxel == CHIRALITY_CONST["LEFT"] or chirality_voxel == CHIRALITY_CONST["RIGHT"]:
check_and_correct_region(
chirality_voxel == CHIRALITY_CONST["LEFT"], region, segment_name_to_number, new_data, i, j, k)
fixed_img = nib.Nifti1Image(new_data, img.affine, img.header)
nib.save(fixed_img, nifti_file_paths["crudecorrected"])

elif iteration==2:
# Drop "crudecorrected_" from nifti_input_file_path to make filenames cleaner
nifti_input_file_path_mod=(os.path.basename(nifti_input_file_path)).split('_', 1)[1]
nifti_file_paths = dict()
for which_nii in ("native-T1", "native-T2", "corrected"):
nifti_file_paths[which_nii] = os.path.join(chiral_out_dir, "_".join((
which_nii, nifti_input_file_path_mod
)))

free_surfer_label_to_region = get_id_to_region_mapping(segment_lookup_table)
segment_name_to_number = {v: k for k, v in free_surfer_label_to_region.items()}
img = nib.load(nifti_input_file_path)
data = img.get_data()
left_right_img = nib.load(nii_fpath_LR_mask)
left_right_data = left_right_img.get_data()

new_data = data.copy()
data_shape = img.header.get_data_shape()
left_right_data_shape = left_right_img.header.get_data_shape()
width = data_shape[0]
height = data_shape[1]
depth = data_shape[2]
assert \
width == left_right_data_shape[0] and height == left_right_data_shape[1] and depth == left_right_data_shape[2]
for i in range(width):
for j in range(height):
for k in range(depth):
voxel = data[i][j][k]
region = free_surfer_label_to_region[voxel]
chirality_voxel = int(left_right_data[i][j][k])
if not (region.startswith(LEFT) or region.startswith(RIGHT)):
continue
if chirality_voxel == CHIRALITY_CONST["LEFT"] or chirality_voxel == CHIRALITY_CONST["RIGHT"]:
check_and_correct_region(
chirality_voxel == CHIRALITY_CONST["LEFT"], region, segment_name_to_number, new_data, i, j, k)
fixed_img = nib.Nifti1Image(new_data, img.affine, img.header)
nib.save(fixed_img, nifti_file_paths["corrected"])
free_surfer_label_to_region = get_id_to_region_mapping(segment_lookup_table)
segment_name_to_number = {v: k for k, v in free_surfer_label_to_region.items()}
img = nib.load(nifti_input_file_path)
data = img.get_data()
left_right_img = nib.load(nii_fpath_LR_mask)
left_right_data = left_right_img.get_data()

new_data = data.copy()
data_shape = img.header.get_data_shape()
left_right_data_shape = left_right_img.header.get_data_shape()
width = data_shape[0]
height = data_shape[1]
depth = data_shape[2]
assert \
width == left_right_data_shape[0] and height == left_right_data_shape[1] and depth == left_right_data_shape[2]
for i in range(width):
for j in range(height):
for k in range(depth):
voxel = data[i][j][k]
region = free_surfer_label_to_region[voxel]
chirality_voxel = int(left_right_data[i][j][k])
if not (region.startswith(LEFT) or region.startswith(RIGHT)):
continue
if chirality_voxel == CHIRALITY_CONST["LEFT"] or chirality_voxel == CHIRALITY_CONST["RIGHT"]:
check_and_correct_region(
chirality_voxel == CHIRALITY_CONST["LEFT"], region, segment_name_to_number, new_data, i, j, k)
fixed_img = nib.Nifti1Image(new_data, img.affine, img.header)
nib.save(fixed_img, nifti_file_paths["corrected"])
return nifti_file_paths


Expand Down

0 comments on commit 9cafd56

Please sign in to comment.