Match GTs and predictions based on BIDS compatible keys (#17)
* Match the prediction and reference files based on the participant_id, acq_id, and run_id.

* remove python-app.yml from the original MetricsReloaded repo

* `get_images_in_folder` --> `get_images`

* handle no predictions/GTs

* add unittests to test the newly proposed matching based on participant_id, acq_id, and run_id

* add clarifying comment

* add session-based pairing between GT-pred

* fetch chunk id also

* fix import after changing function name

* update tests with ses_id and chunk_id

---------

Co-authored-by: Naga Karthik <[email protected]>
valosekj and naga-karthik authored Dec 11, 2024
1 parent 76dbb55 commit b3f354e
Showing 3 changed files with 191 additions and 100 deletions.
87 changes: 0 additions & 87 deletions .github/workflows/python-app.yml

This file was deleted.

77 changes: 65 additions & 12 deletions compute_metrics_reloaded.py
@@ -13,6 +13,7 @@
python compute_metrics_reloaded.py
-reference /path/to/reference
-prediction /path/to/prediction
NOTE: The prediction and reference files are matched based on the participant_id, session_id, acq_id, chunk_id, and run_id.
The metrics to be computed can be specified using the `-metrics` argument. For example, to compute only the Dice
similarity coefficient (DSC) and Normalized surface distance (NSD), use:
@@ -37,6 +38,7 @@


import os
import re
import argparse
import numpy as np
import nibabel as nib
@@ -103,25 +105,76 @@ def load_nifti_image(file_path):
return nifti_image.get_fdata()


def get_images_in_folder(prediction, reference):
def fetch_bids_compatible_keys(filename_path, prefix='sub-'):
"""
Get all files (predictions and references/ground truths) in the input directories
Get participant_id, session_id, acq_id, chunk_id, and run_id from the input BIDS-compatible filename or file path.
The function works on both bare filenames and absolute file paths.
:param filename_path: input nifti filename (e.g., sub-001_ses-01_T1w.nii.gz) or file path
(e.g., /home/user/bids/sub-001/ses-01/anat/sub-001_ses-01_T1w.nii.gz)
:param prefix: prefix of the participant ID in the filename (default: 'sub-')
:return: participant_id: participant ID (e.g., sub-001)
:return: session_id: session ID (e.g., ses-01)
:return: acq_id: acquisition ID (e.g., acq-01)
:return: chunk_id: chunk ID (e.g., chunk-1)
:return: run_id: run ID (e.g., run-01)

participant = re.search(f'{prefix}(.*?)[_/]', filename_path) # [_/] means either underscore or slash
participant_id = participant.group(0)[:-1] if participant else "" # [:-1] removes the last underscore or slash

session = re.search('ses-(.*?)[_/]', filename_path) # [_/] means either underscore or slash
session_id = session.group(0)[:-1] if session else "" # [:-1] removes the last underscore or slash

acquisition = re.search('acq-(.*?)[_/]', filename_path) # [_/] means either underscore or slash
acq_id = acquisition.group(0)[:-1] if acquisition else "" # [:-1] removes the last underscore or slash

chunk = re.search('chunk-(.*?)[_/]', filename_path) # [_/] means either underscore or slash
chunk_id = chunk.group(0)[:-1] if chunk else "" # [:-1] removes the last underscore or slash

run = re.search('run-(.*?)[_/]', filename_path) # [_/] means either underscore or slash
run_id = run.group(0)[:-1] if run else "" # [:-1] removes the last underscore or slash

# REGEX explanation
# . - match any character (except newline)
# *? - match the previous element as few times as possible (zero or more times)

return participant_id, session_id, acq_id, chunk_id, run_id
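# Usage sketch with hypothetical filenames (a minimal illustration, not taken from the repository's data):
# entities that are absent from the filename come back as empty strings, so they still compare equal
# during the prediction/reference pairing done in get_images().
#
#   >>> fetch_bids_compatible_keys('sub-001_ses-01_acq-ax_chunk-1_run-02_T2w.nii.gz')
#   ('sub-001', 'ses-01', 'acq-ax', 'chunk-1', 'run-02')
#   >>> fetch_bids_compatible_keys('sub-002_T2w.nii.gz')
#   ('sub-002', '', '', '', '')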


def get_images(prediction, reference):
"""
Get all files (predictions and references/ground truths) in the input directories.
The prediction and reference files are matched based on the participant_id, session_id, acq_id, chunk_id, and run_id.
:param prediction: path to the directory with prediction files
:param reference: path to the directory with reference (ground truth) files
:return: list of prediction files, list of reference/ground truth files
"""
# Get all files in the directories
prediction_files = [os.path.join(prediction, f) for f in os.listdir(prediction) if f.endswith('.nii.gz')]
reference_files = [os.path.join(reference, f) for f in os.listdir(reference) if f.endswith('.nii.gz')]
# Check if the number of files in the directories is the same
if len(prediction_files) != len(reference_files):
raise ValueError(f'The number of files in the directories is different. '
f'Prediction files: {len(prediction_files)}, Reference files: {len(reference_files)}')
print(f'Found {len(prediction_files)} files in the directories.')
# Sort the files
# NOTE: Hopefully, the files are named in the same order in both directories
prediction_files.sort()
reference_files.sort()

if not prediction_files:
raise FileNotFoundError(f'No prediction files found in {prediction}.')
if not reference_files:
raise FileNotFoundError(f'No reference (ground truth) files found in {reference}.')

# Create dataframe for prediction_files with participant_id, session_id, acq_id, chunk_id, and run_id
df_pred = pd.DataFrame(prediction_files, columns=['filename'])
df_pred['participant_id'], df_pred['session_id'], df_pred['acq_id'], df_pred['chunk_id'], df_pred['run_id'] = zip(*df_pred['filename'].apply(fetch_bids_compatible_keys))

# Create dataframe for reference_files with participant_id, session_id, acq_id, chunk_id, and run_id
df_ref = pd.DataFrame(reference_files, columns=['filename'])
df_ref['participant_id'], df_ref['session_id'], df_ref['acq_id'], df_ref['chunk_id'], df_ref['run_id'] = zip(*df_ref['filename'].apply(fetch_bids_compatible_keys))

# Merge the two dataframes on participant_id, session_id, acq_id, chunk_id, and run_id
df = pd.merge(df_pred, df_ref, on=['participant_id', 'session_id', 'acq_id', 'chunk_id', 'run_id'], how='outer', suffixes=('_pred', '_ref'))
# Drop 'participant_id', 'session_id', 'acq_id', 'chunk_id', 'run_id'
df.drop(['participant_id', 'session_id', 'acq_id', 'chunk_id', 'run_id'], axis=1, inplace=True)
# Drop rows with NaN values. In other words, keep only the rows where both prediction and reference files exist
df.dropna(inplace=True)

prediction_files = df['filename_pred'].tolist()
reference_files = df['filename_ref'].tolist()

return prediction_files, reference_files
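# Sketch of the pairing idea above on toy data (hypothetical filenames; pandas imported as pd, as used above):
# an outer merge on the BIDS keys followed by dropna() keeps only the scans present in BOTH folders.
#
#   >>> df_pred = pd.DataFrame({'filename': ['sub-01_T2w_pred.nii.gz', 'sub-02_T2w_pred.nii.gz'],
#   ...                         'participant_id': ['sub-01', 'sub-02']})
#   >>> df_ref = pd.DataFrame({'filename': ['sub-01_T2w.nii.gz'],
#   ...                        'participant_id': ['sub-01']})
#   >>> paired = pd.merge(df_pred, df_ref, on='participant_id', how='outer',
#   ...                   suffixes=('_pred', '_ref')).dropna()
#   >>> paired['filename_pred'].tolist(), paired['filename_ref'].tolist()
#   (['sub-01_T2w_pred.nii.gz'], ['sub-01_T2w.nii.gz'])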

@@ -236,7 +289,7 @@ def main():
# Args.prediction and args.reference are paths to folders with multiple nii.gz files (i.e., MULTIPLE subjects)
if os.path.isdir(args.prediction) and os.path.isdir(args.reference):
# Get all files in the directories
prediction_files, reference_files = get_images_in_folder(args.prediction, args.reference)
prediction_files, reference_files = get_images(args.prediction, args.reference)

# Use multiprocessing to parallelize the computation
with Pool(args.jobs) as pool:
127 changes: 126 additions & 1 deletion test/test_metrics/test_pairwise_measures_neuropoly.py
@@ -13,7 +13,7 @@
import os
import numpy as np
import nibabel as nib
from compute_metrics_reloaded import compute_metrics_single_subject
from compute_metrics_reloaded import compute_metrics_single_subject, get_images, fetch_bids_compatible_keys
import tempfile

METRICS = ['dsc', 'fbeta', 'nsd', 'vol_diff', 'rel_vol_error', 'lesion_ppv', 'lesion_sensitivity', 'lesion_f1_score',
@@ -358,6 +358,131 @@ def test_non_empty_ref_and_pred_with_full_overlap(self):
# Assert metrics
self.assert_metrics(metrics_dict, expected_metrics)

class TestGetImages(unittest.TestCase):
def setUp(self):
"""
Create temporary directories and files for testing.
"""
self.pred_dir = tempfile.TemporaryDirectory()
self.ref_dir = tempfile.TemporaryDirectory()

def tearDown(self):
"""
Cleanup temporary directories and files after tests.
"""
self.pred_dir.cleanup()
self.ref_dir.cleanup()

def create_temp_file(self, directory, filename):
"""
Create a temporary file in the given directory with the specified filename.
"""
file_path = os.path.join(directory, filename)
with open(file_path, 'w') as f:
f.write('dummy content')
return file_path

def test_matching_files(self):
"""
Test matching files based on participant_id, session_id, acq_id, chunk_id, and run_id.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_ses-01_acq-01_chunk-1_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_ses-01_acq-01_chunk-1_run-01_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)

def test_mismatched_files(self):
"""
Test when no files match based on the criteria.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_ses-01_acq-01_chunk-1_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-02_ses-01_acq-02_chunk-1_run-02_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 0)
self.assertEqual(len(ref_files), 0)

def test_ses_id_empty(self):
"""
Test when ses_id is empty.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_chunk-1_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_acq-01_chunk-1_run-01_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)
self.assertIn("sub-01_acq-01_chunk-1_run-01_pred.nii.gz", pred_files[0])
self.assertIn("sub-01_acq-01_chunk-1_run-01_ref.nii.gz", ref_files[0])

def test_acq_id_empty(self):
"""
Test when acq_id is empty.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_ses-01_chunk-1_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_ses-01_chunk-1_run-01_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)
self.assertIn("sub-01_ses-01_chunk-1_run-01_pred.nii.gz", pred_files[0])
self.assertIn("sub-01_ses-01_chunk-1_run-01_ref.nii.gz", ref_files[0])

def test_chunk_id_empty(self):
"""
Test when chunk_id is empty in the filenames.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_ses-01_acq-01_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_ses-01_acq-01_run-01_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)

# Assert the matched files
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)
self.assertIn("sub-01_ses-01_acq-01_run-01_pred.nii.gz", pred_files[0])
self.assertIn("sub-01_ses-01_acq-01_run-01_ref.nii.gz", ref_files[0])

def test_run_id_empty(self):
"""
Test when run_id is empty in the filenames.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_ses-01_acq-01_chunk-1_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_ses-01_acq-01_chunk-1_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)

# Assert the matched files
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)
self.assertIn("sub-01_ses-01_acq-01_chunk-1_pred.nii.gz", pred_files[0])
self.assertIn("sub-01_ses-01_acq-01_chunk-1_ref.nii.gz", ref_files[0])

def test_no_files(self):
"""
Test when there are no files in the directories.
Ensure that FileNotFoundError is raised.
"""
with self.assertRaises(FileNotFoundError) as context:
get_images(self.pred_dir.name, self.ref_dir.name)
# Check the exception message
self.assertIn(f'No prediction files found in {self.pred_dir.name}', str(context.exception))

def test_partial_matching(self):
"""
Test when some files match and some do not.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_acq-01_run-01_ref.nii.gz")
# The following file will not be included in the lists below as there is no matching reference (GT) file
self.create_temp_file(self.pred_dir.name, "sub-02_acq-02_run-02_pred.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)


if __name__ == '__main__':
unittest.main()
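The new `TestGetImages` cases can also be run on their own with the standard `unittest` runner. A minimal sketch, assuming it is executed from the repository root so that `compute_metrics_reloaded.py` is importable:

import unittest

# Discover only the pairwise-measures test module shown in the diff above
# (path and filename pattern are taken from the file layout of this commit).
suite = unittest.TestLoader().discover('test/test_metrics', pattern='test_pairwise_measures_neuropoly.py')
unittest.TextTestRunner(verbosity=2).run(suite)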
