Merge pull request #100 from sct-pipeline/nk/simplify-inference
Simplify inference script
naga-karthik authored Jan 23, 2024
2 parents 187d8ea + e43165a commit 5aea617
Showing 1 changed file with 9 additions and 37 deletions.
46 changes: 9 additions & 37 deletions monai/run_inference_single_image.py
@@ -16,7 +16,7 @@
 from time import time
 
 from monai.inferers import sliding_window_inference
-from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch)
+from monai.data import (DataLoader, Dataset, decollate_batch)
 from monai.transforms import (Compose, EnsureTyped, Invertd, SaveImage, Spacingd,
                               LoadImaged, NormalizeIntensityd, EnsureChannelFirstd,
                               DivisiblePadd, Orientationd, ResizeWithPadOrCropd)
@@ -177,33 +177,9 @@ def create_nnunet_from_plans(plans, num_input_channels: int, num_classes: int, d
 # ===========================================================================
 # Prepare temporary dataset for inference
 # ===========================================================================
-def prepare_data(path_image, path_out, crop_size=(64, 160, 320)):
+def prepare_data(path_image, crop_size=(64, 160, 320)):
 
-    # create a temporary datalist containing the image
-    # boiler plate keys to be defined in the MSD-style datalist
-    params = {}
-    params["description"] = "my-awesome-SC-image"
-    params["labels"] = {
-        "0": "background",
-        "1": "soft-sc-seg"
-    }
-    params["modality"] = {
-        "0": "MRI"
-    }
-    params["tensorImageSize"] = "3D"
-    params["test"] = [
-        {
-            "image": path_image
-        }
-    ]
-
-    final_json = json.dumps(params, indent=4, sort_keys=True)
-    jsonFile = open(path_out + "/" + f"temp_msd_datalist.json", "w")
-    jsonFile.write(final_json)
-    jsonFile.close()
-
-    dataset = os.path.join(path_out, f"temp_msd_datalist.json")
-    test_files = load_decathlon_datalist(dataset, True, "test")
+    test_file = [{"image": path_image}]
 
     # define test transforms
     transforms_test = inference_transforms_single_image(crop_size=crop_size)
@@ -217,7 +193,7 @@ def prepare_data(path_image, path_out, crop_size=(64, 160, 320)):
             meta_keys=["pred_meta_dict"],
             nearest_interp=False, to_tensor=True),
     ])
-    test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.75, num_workers=8)
+    test_ds = Dataset(data=test_file, transform=transforms_test)
 
     return test_ds, test_post_pred

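For reference, the simplified prepare_data now builds the test dataset directly from a one-item, in-memory list of dicts instead of writing and re-reading a temporary MSD-style datalist JSON, and wraps it in a plain MONAI Dataset rather than a CacheDataset. A minimal sketch of the idea (the transform list below is illustrative; the actual script builds its pipeline via inference_transforms_single_image):

from monai.data import Dataset
from monai.transforms import (Compose, LoadImaged, EnsureChannelFirstd,
                              Orientationd, ResizeWithPadOrCropd)

def prepare_data_sketch(path_image, crop_size=(64, 160, 320)):
    # a single-sample "datalist" kept in memory; no temporary JSON on disk
    test_file = [{"image": path_image}]
    # stand-in transforms; the real script uses inference_transforms_single_image(crop_size)
    transforms_test = Compose([
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RPI"),    # orientation code is an assumption
        ResizeWithPadOrCropd(keys=["image"], spatial_size=crop_size),
    ])
    # plain Dataset: no caching overhead for one-off, single-image inference
    return Dataset(data=test_file, transform=transforms_test)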
@@ -254,7 +230,7 @@ def main():
     inference_roi_size = (64, 192, 320)
 
     # define the dataset and dataloader
-    test_ds, test_post_pred = prepare_data(path_image, results_path, crop_size=crop_size)
+    test_ds, test_post_pred = prepare_data(path_image, crop_size=crop_size)
     test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
 
     # define model
@@ -305,13 +281,10 @@ def main():
 
         pred = post_test_out[0]['pred'].cpu()
 
-        # clip the prediction between 0.5 and 1
-        # turns out this sets the background to 0.5 and the SC to 1 (which is not correct)
-        # details: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/71
-        pred = torch.clamp(pred, 0.5, 1)
-        # set background values to 0
-        pred[pred <= 0.5] = 0
+        # binarize the prediction with a threshold of 0.5
+        pred[pred >= 0.5] = 1
+        pred[pred < 0.5] = 0
 
         # get subject name
         subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "")
         logger.info(f"Saving subject: {subject_name}")
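The effect of the replaced post-processing is easiest to see on a toy tensor: clamping to [0.5, 1] and then zeroing values at or below 0.5 leaves the foreground soft, while the new two-sided threshold yields a true binary mask. A small self-contained illustration:

import torch

pred = torch.tensor([0.1, 0.49, 0.5, 0.9])

# previous behaviour: clamp, then zero out everything at or below 0.5
old = torch.clamp(pred, 0.5, 1)   # tensor([0.5000, 0.5000, 0.5000, 0.9000])
old[old <= 0.5] = 0               # tensor([0.0000, 0.0000, 0.0000, 0.9000]) -- foreground stays soft

# new behaviour: binarize with a 0.5 threshold
new = pred.clone()
new[new >= 0.5] = 1
new[new < 0.5] = 0                # tensor([0., 0., 1., 1.])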
@@ -355,7 +328,6 @@ def main():
     # free up memory
     test_step_outputs.clear()
     test_summary.clear()
-    os.remove(os.path.join(results_path, "temp_msd_datalist.json"))
 
 
 if __name__ == "__main__":