diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index fee546bea3..bfd2f506c2 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -1764,6 +1764,30 @@ def __init__(
         self.invert_affine = invert_affine
         self.affine_lps_to_ras = affine_lps_to_ras
 
+    def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
+        """
+        Compute the final affine transformation matrix to apply to the point data.
+
+        Args:
+            affine: 3x3 or 4x4 affine transformation matrix.
+            applied_affine: affine transformation matrix that has already been applied to the point data, if any.
+
+        Returns:
+            Final affine transformation matrix.
+        """
+
+        affine = convert_data_type(affine, dtype=torch.float64)[0]
+
+        if self.affine_lps_to_ras:
+            affine = orientation_ras_lps(affine)
+
+        if self.invert_affine:
+            affine = linalg_inv(affine)
+        if applied_affine is not None:
+            affine = affine @ applied_affine
+
+        return affine
+
     def transform_coordinates(
         self, data: torch.Tensor, affine: torch.Tensor | None = None
     ) -> tuple[torch.Tensor, dict]:
@@ -1780,35 +1804,25 @@ def transform_coordinates(
             Transformed coordinates.
         """
         data = convert_to_tensor(data, track_meta=get_track_meta())
-        # applied_affine is the affine transformation matrix that has already been applied to the point data
-        applied_affine = getattr(data, "affine", None)
-
         if affine is None and self.invert_affine:
             raise ValueError("affine must be provided when invert_affine is True.")
-
+        # applied_affine is the affine transformation matrix that has already been applied to the point data
+        applied_affine: torch.Tensor | None = getattr(data, "affine", None)
         affine = applied_affine if affine is None else affine
-        affine = convert_data_type(affine, dtype=torch.float64)[0]  # always convert to float64 for affine
-        original_affine: torch.Tensor = affine
-        if self.affine_lps_to_ras:
-            affine = orientation_ras_lps(affine)
+        if affine is None:
+            raise ValueError("affine must be provided if data does not have an affine matrix.")
 
-        # the final affine transformation matrix that will be applied to the point data
-        _affine: torch.Tensor = affine
-        if self.invert_affine:
-            _affine = linalg_inv(affine)
-            if applied_affine is not None:
-                # consider the affine transformation already applied to the data in the world space
-                # and compute delta affine
-                _affine = _affine @ linalg_inv(applied_affine)
-        out = apply_affine_to_points(data, _affine, dtype=self.dtype)
+        final_affine = self._compute_final_affine(affine, applied_affine)
+        out = apply_affine_to_points(data, final_affine, dtype=self.dtype)
 
         extra_info = {
             "invert_affine": self.invert_affine,
             "dtype": get_dtype_string(self.dtype),
-            "image_affine": original_affine,  # record for inverse operation
+            "image_affine": affine,
             "affine_lps_to_ras": self.affine_lps_to_ras,
         }
-        xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine)
+
+        xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine)
         meta_info = TraceableTransform.track_transform_meta(
             data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
         )
@@ -1834,16 +1848,12 @@ def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):
 
     def inverse(self, data: torch.Tensor) -> torch.Tensor:
         transform = self.pop_transform(data)
-        # Create inverse transform
-        dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
-        invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"]
-        affine = transform[TraceKeys.EXTRA_INFO]["image_affine"]
-        affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"]
         inverse_transform = ApplyTransformToPoints(
-            dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
+            dtype=transform[TraceKeys.EXTRA_INFO]["dtype"],
+            invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"],
+            affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"],
         )
-        # Apply inverse
         with inverse_transform.trace_transform(False):
-            data = inverse_transform(data, affine)
+            data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])
         return data
 
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 1279ca93ab..db5f19c0de 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -1758,8 +1758,9 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
     Args:
         keys: keys of the corresponding items to be transformed.
             See also: monai.transforms.MapTransform
-        refer_key: The key of the reference item used for transformation.
-            It can directly refer to an affine or an image from which the affine can be derived.
+        refer_keys: The key of the reference item used for transformation.
+            It can directly refer to an affine or an image from which the affine can be derived. It can also be a
+            sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.
         dtype: The desired data type for the output.
         affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
             from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
@@ -1782,7 +1783,7 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
     def __init__(
         self,
         keys: KeysCollection,
-        refer_key: str | None = None,
+        refer_keys: KeysCollection | None = None,
         dtype: DtypeLike | torch.dtype = torch.float64,
         affine: torch.Tensor | None = None,
         invert_affine: bool = True,
@@ -1790,23 +1791,24 @@ def __init__(
         allow_missing_keys: bool = False,
     ):
         MapTransform.__init__(self, keys, allow_missing_keys)
-        self.refer_key = refer_key
+        self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys))
         self.converter = ApplyTransformToPoints(
             dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
         )
 
     def __call__(self, data: Mapping[Hashable, torch.Tensor]):
         d = dict(data)
-        if self.refer_key is not None:
-            if self.refer_key in d:
-                refer_data = d[self.refer_key]
-            else:
-                raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
-        else:
-            refer_data = None
-        affine = getattr(refer_data, "affine", refer_data)
-        for key in self.key_iterator(d):
+        for key, refer_key in self.key_iterator(d, self.refer_keys):
             coords = d[key]
+            affine = None  # represents using affine given in constructor
+            if refer_key is not None:
+                if refer_key in d:
+                    refer_data = d[refer_key]
+                else:
+                    raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
+
+                # use the "affine" member of refer_data, or refer_data itself, as the affine matrix
+                affine = getattr(refer_data, "affine", refer_data)
             d[key] = self.converter(coords, affine)
 
         return d
diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py
index 4cedfa9d66..978113931c 100644
--- a/tests/test_apply_transform_to_pointsd.py
+++ b/tests/test_apply_transform_to_pointsd.py
@@ -30,72 +30,90 @@
 POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])
 POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])
 POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])
+AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
+AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])
 
 TEST_CASES = [
+    [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE],  # use image affine
+    [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD],  # use point affine
+    [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD],  # use input affine
+    [None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE],  # use input affine
     [
-        MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
+        MetaTensor(DATA_2D, affine=AFFINE_1),
         POINT_2D_WORLD,
         None,
         True,
-        False,
-        POINT_2D_IMAGE,
-    ],
+        True,
+        POINT_2D_IMAGE_RAS,
+    ],  # test affine_lps_to_ras
+    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
+    ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],  # use refer_data itself
     [
-        None,
-        MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
+        MetaTensor(DATA_3D, affine=AFFINE_2),
+        MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),
         None,
         False,
         False,
-        POINT_2D_WORLD,
+        POINT_3D_WORLD,
     ],
+    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
+    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
+]
+TEST_CASES_SEQUENCE = [
     [
+        (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
+        [POINT_2D_WORLD, POINT_3D_WORLD],
         None,
-        MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
-        False,
+        True,
         False,
-        POINT_2D_WORLD,
-    ],
+        ["image_1", "image_2"],
+        [POINT_2D_IMAGE, POINT_3D_IMAGE],
+    ],  # use image affine
     [
-        MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        POINT_2D_WORLD,
+        (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
+        [POINT_2D_WORLD, POINT_3D_WORLD],
         None,
         True,
         True,
-        POINT_2D_IMAGE_RAS,
-    ],
+        ["image_1", "image_2"],
+        [POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS],
+    ],  # test affine_lps_to_ras
     [
-        MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        POINT_3D_WORLD,
+        (None, None),
+        [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
         None,
+        False,
+        False,
+        None,
+        [POINT_2D_WORLD, POINT_3D_WORLD],
+    ],  # use point affine
+    [
+        (None, None),
+        [POINT_2D_WORLD, POINT_2D_WORLD],
+        AFFINE_1,
         True,
         False,
-        POINT_3D_IMAGE,
-    ],
-    ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
+        None,
+        [POINT_2D_IMAGE, POINT_2D_IMAGE],
+    ],  # use input affine
     [
-        MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
+        (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
+        [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
         None,
         False,
         False,
-        POINT_3D_WORLD,
-    ],
-    [
-        MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
-        POINT_3D_WORLD,
-        None,
-        True,
-        True,
-        POINT_3D_IMAGE_RAS,
+        ["image_1", "image_2"],
+        [POINT_2D_WORLD, POINT_3D_WORLD],
     ],
 ]
 
 TEST_CASES_WRONG = [
-    [POINT_2D_WORLD, True, None],
-    [POINT_2D_WORLD.unsqueeze(0), False, None],
-    [POINT_3D_WORLD[..., 0:1], False, None],
-    [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])],
+    [POINT_2D_WORLD, True, None, None],
+    [POINT_2D_WORLD.unsqueeze(0), False, None, None],
+    [POINT_3D_WORLD[..., 0:1], False, None, None],
+    [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None],
+    [POINT_3D_WORLD, False, None, "image"],
+    [POINT_3D_WORLD, False, None, []],
 ]
 
 
@@ -107,10 +125,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
             "point": points,
             "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
         }
-        refer_key = "image" if (image is not None and image != "affine") else image
+        refer_keys = "image" if (image is not None and image != "affine") else image
         transform = ApplyTransformToPointsd(
             keys="point",
-            refer_key=refer_key,
+            refer_keys=refer_keys,
             dtype=torch.int64,
             affine=affine,
             invert_affine=invert_affine,
@@ -122,11 +140,45 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
         invert_out = transform.inverse(output)
         self.assertTrue(torch.allclose(invert_out["point"], points))
 
+    @parameterized.expand(TEST_CASES_SEQUENCE)
+    def test_transform_coordinates_sequences(
+        self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
+    ):
+        data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]}
+        keys = ["point_1", "point_2"]
+        transform = ApplyTransformToPointsd(
+            keys=keys,
+            refer_keys=refer_keys,
+            dtype=torch.int64,
+            affine=affine,
+            invert_affine=invert_affine,
+            affine_lps_to_ras=affine_lps_to_ras,
+        )
+        output = transform(data)
+
+        self.assertTrue(torch.allclose(output["point_1"], expected_output[0]))
+        self.assertTrue(torch.allclose(output["point_2"], expected_output[1]))
+        invert_out = transform.inverse(output)
+        self.assertTrue(torch.allclose(invert_out["point_1"], points[0]))
+
     @parameterized.expand(TEST_CASES_WRONG)
-    def test_wrong_input(self, input, invert_affine, affine):
-        transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine)
-        with self.assertRaises(ValueError):
-            transform({"point": input})
+    def test_wrong_input(self, input, invert_affine, affine, refer_keys):
+        if refer_keys == []:
+            with self.assertRaises(ValueError):
+                ApplyTransformToPointsd(
+                    keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
+                )
+        else:
+            transform = ApplyTransformToPointsd(
+                keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
+            )
+            data = {"point": input}
+            if refer_keys == "image":
+                with self.assertRaises(KeyError):
+                    transform(data)
+            else:
+                with self.assertRaises(ValueError):
+                    transform(data)
 
 
 if __name__ == "__main__":
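Usage sketch (not part of the patch): a minimal example of the refer_keys sequence support added above, mirroring the TEST_CASES_SEQUENCE data. It assumes a MONAI checkout with this change applied; the dictionary key names ("image_1", "point_1", ...) are illustrative only.

# minimal sketch: map two point sets into the spaces of two different images via refer_keys
import torch

from monai.data import MetaTensor
from monai.transforms import ApplyTransformToPointsd

affine_1 = torch.tensor([[2.0, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
affine_2 = torch.tensor([[1.0, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])

data = {
    "image_1": MetaTensor(torch.rand(1, 64, 64), affine=affine_1),
    "image_2": MetaTensor(torch.rand(1, 64, 64, 64), affine=affine_2),
    "point_1": torch.tensor([[[2.0, 2], [2, 4], [4, 6]]]),  # (C, N, 2) points in world space
    "point_2": torch.tensor([[[2.0, 4, 6], [8, 10, 12]]]),  # (C, N, 3) points in world space
}

# each entry of refer_keys supplies the affine for the matching entry of keys
transform = ApplyTransformToPointsd(keys=["point_1", "point_2"], refer_keys=["image_1", "image_2"])
output = transform(data)  # points mapped from world space into each image's space (invert_affine=True by default)
restored = transform.inverse(output)  # round-trips back to the original world-space coordinates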