diff --git a/tests/test_compute_metrics_reloaded.py b/tests/test_compute_metrics_reloaded.py
index 16afe45..c45618f 100644
--- a/tests/test_compute_metrics_reloaded.py
+++ b/tests/test_compute_metrics_reloaded.py
@@ -37,14 +37,16 @@ def tearDown(self):
 
     def assert_metrics(self, metrics_dict, expected_metrics):
         for metric in self.metrics:
-            # if value is nan, use np.isnan to check
-            if np.isnan(expected_metrics[metric]):
-                self.assertTrue(np.isnan(metrics_dict[1][metric]))
-            # if value is inf, use np.isinf to check
-            elif np.isinf(expected_metrics[metric]):
-                self.assertTrue(np.isinf(metrics_dict[1][metric]))
-            else:
-                self.assertAlmostEqual(metrics_dict[1][metric], expected_metrics[metric])
+            # Loop over labels/classes (e.g., 1, 2, ...)
+            for label in expected_metrics.keys():
+                # if value is nan, use np.isnan to check
+                if np.isnan(expected_metrics[label][metric]):
+                    self.assertTrue(np.isnan(metrics_dict[label][metric]))
+                # if value is inf, use np.isinf to check
+                elif np.isinf(expected_metrics[label][metric]):
+                    self.assertTrue(np.isinf(metrics_dict[label][metric]))
+                else:
+                    self.assertAlmostEqual(metrics_dict[label][metric], expected_metrics[label][metric])
 
     def test_empty_ref_and_pred(self):
         """
@@ -142,6 +144,42 @@ def test_non_empty_ref_and_pred(self):
         # Assert metrics
         self.assert_metrics(metrics_dict, expected_metrics)
 
+    def test_non_empty_ref_and_pred_multi_class(self):
+        """
+        Non-empty reference and non-empty prediction with partial overlap
+        Multi-class (i.e., voxels with values 1 and 2, e.g., region-based nnUNet training)
+        """
+
+        expected_metrics = {1.0: {'dsc': 0.25,
+                                  'fbeta': 0.2500000055879354,
+                                  'nsd': 0.5,
+                                  'vol_diff': 2.0,
+                                  'rel_vol_error': 200.0,
+                                  'EmptyRef': False,
+                                  'EmptyPred': False},
+                            2.0: {'dsc': 0.26666666666666666,
+                                  'fbeta': 0.26666667461395266,
+                                  'nsd': 0.5373134328358209,
+                                  'vol_diff': 3.0,
+                                  'rel_vol_error': 300.0,
+                                  'EmptyRef': False,
+                                  'EmptyPred': False}}
+
+        # Create non-empty reference
+        ref = np.zeros((10, 10, 10))
+        ref[4:5, 3:10] = 1
+        ref[4:5, 3:6] = 2 # e.g., lesion within spinal cord
+        self.create_dummy_nii(self.ref_file, ref)
+        # Create non-empty prediction
+        pred = np.zeros((10, 10, 10))
+        pred[4:8, 2:8] = 1
+        pred[4:8, 2:5] = 2 # e.g., lesion within spinal cord
+        self.create_dummy_nii(self.pred_file, pred)
+        # Compute metrics
+        metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics)
+        # Assert metrics
+        self.assert_metrics(metrics_dict, expected_metrics)
+
     def test_non_empty_ref_and_pred_with_full_overlap(self):
         """
         Non-empty reference and non-empty prediction with full overlap