Skip to content

Commit

Permalink
Allow ApplyTransformToPointsd receive a sequence of refer keys (#8063)
Browse files Browse the repository at this point in the history
Enhance `ApplyTransformToPointsd` to receive a sequence of refer keys. 

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 4, 2024
1 parent aea46ff commit 4e70bf6
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 82 deletions.
64 changes: 37 additions & 27 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,6 +1764,30 @@ def __init__(
self.invert_affine = invert_affine
self.affine_lps_to_ras = affine_lps_to_ras

def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
"""
Compute the final affine transformation matrix to apply to the point data.
Args:
data: Input coordinates assumed to be in the shape (C, N, 2 or 3).
affine: 3x3 or 4x4 affine transformation matrix.
Returns:
Final affine transformation matrix.
"""

affine = convert_data_type(affine, dtype=torch.float64)[0]

if self.affine_lps_to_ras:
affine = orientation_ras_lps(affine)

if self.invert_affine:
affine = linalg_inv(affine)
if applied_affine is not None:
affine = affine @ applied_affine

return affine

def transform_coordinates(
self, data: torch.Tensor, affine: torch.Tensor | None = None
) -> tuple[torch.Tensor, dict]:
Expand All @@ -1780,35 +1804,25 @@ def transform_coordinates(
Transformed coordinates.
"""
data = convert_to_tensor(data, track_meta=get_track_meta())
# applied_affine is the affine transformation matrix that has already been applied to the point data
applied_affine = getattr(data, "affine", None)

if affine is None and self.invert_affine:
raise ValueError("affine must be provided when invert_affine is True.")

# applied_affine is the affine transformation matrix that has already been applied to the point data
applied_affine: torch.Tensor | None = getattr(data, "affine", None)
affine = applied_affine if affine is None else affine
affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine
original_affine: torch.Tensor = affine
if self.affine_lps_to_ras:
affine = orientation_ras_lps(affine)
if affine is None:
raise ValueError("affine must be provided if data does not have an affine matrix.")

# the final affine transformation matrix that will be applied to the point data
_affine: torch.Tensor = affine
if self.invert_affine:
_affine = linalg_inv(affine)
if applied_affine is not None:
# consider the affine transformation already applied to the data in the world space
# and compute delta affine
_affine = _affine @ linalg_inv(applied_affine)
out = apply_affine_to_points(data, _affine, dtype=self.dtype)
final_affine = self._compute_final_affine(affine, applied_affine)
out = apply_affine_to_points(data, final_affine, dtype=self.dtype)

extra_info = {
"invert_affine": self.invert_affine,
"dtype": get_dtype_string(self.dtype),
"image_affine": original_affine, # record for inverse operation
"image_affine": affine,
"affine_lps_to_ras": self.affine_lps_to_ras,
}
xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine)

xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine)
meta_info = TraceableTransform.track_transform_meta(
data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
)
Expand All @@ -1834,16 +1848,12 @@ def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):

def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
# Create inverse transform
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"]
affine = transform[TraceKeys.EXTRA_INFO]["image_affine"]
affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"]
inverse_transform = ApplyTransformToPoints(
dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
dtype=transform[TraceKeys.EXTRA_INFO]["dtype"],
invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"],
affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"],
)
# Apply inverse
with inverse_transform.trace_transform(False):
data = inverse_transform(data, affine)
data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])

return data
28 changes: 15 additions & 13 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,8 +1758,9 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
refer_key: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived.
refer_keys: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived. It can also be a
sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
Expand All @@ -1782,31 +1783,32 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
def __init__(
self,
keys: KeysCollection,
refer_key: str | None = None,
refer_keys: KeysCollection | None = None,
dtype: DtypeLike | torch.dtype = torch.float64,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.refer_key = refer_key
self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys))
self.converter = ApplyTransformToPoints(
dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)

def __call__(self, data: Mapping[Hashable, torch.Tensor]):
d = dict(data)
if self.refer_key is not None:
if self.refer_key in d:
refer_data = d[self.refer_key]
else:
raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
else:
refer_data = None
affine = getattr(refer_data, "affine", refer_data)
for key in self.key_iterator(d):
for key, refer_key in self.key_iterator(d, self.refer_keys):
coords = d[key]
affine = None # represents using affine given in constructor
if refer_key is not None:
if refer_key in d:
refer_data = d[refer_key]
else:
raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")

# use the "affine" member of refer_data, or refer_data itself, as the affine matrix
affine = getattr(refer_data, "affine", refer_data)
d[key] = self.converter(coords, affine)
return d

Expand Down
136 changes: 94 additions & 42 deletions tests/test_apply_transform_to_pointsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,72 +30,90 @@
POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])
POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])
POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])
AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])

TEST_CASES = [
[MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], # use image affine
[None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], # use point affine
[None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], # use input affine
[None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE], # use input affine
[
MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(DATA_2D, affine=AFFINE_1),
POINT_2D_WORLD,
None,
True,
False,
POINT_2D_IMAGE,
],
True,
POINT_2D_IMAGE_RAS,
], # test affine_lps_to_ras
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], # use refer_data itself
[
None,
MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(DATA_3D, affine=AFFINE_2),
MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),
None,
False,
False,
POINT_2D_WORLD,
POINT_3D_WORLD,
],
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
]
TEST_CASES_SEQUENCE = [
[
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[POINT_2D_WORLD, POINT_3D_WORLD],
None,
MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
False,
True,
False,
POINT_2D_WORLD,
],
["image_1", "image_2"],
[POINT_2D_IMAGE, POINT_3D_IMAGE],
], # use image affine
[
MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_2D_WORLD,
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[POINT_2D_WORLD, POINT_3D_WORLD],
None,
True,
True,
POINT_2D_IMAGE_RAS,
],
["image_1", "image_2"],
[POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS],
], # test affine_lps_to_ras
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_3D_WORLD,
(None, None),
[MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
None,
False,
False,
None,
[POINT_2D_WORLD, POINT_3D_WORLD],
], # use point affine
[
(None, None),
[POINT_2D_WORLD, POINT_2D_WORLD],
AFFINE_1,
True,
False,
POINT_3D_IMAGE,
],
["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
None,
[POINT_2D_IMAGE, POINT_2D_IMAGE],
], # use input affine
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
None,
False,
False,
POINT_3D_WORLD,
],
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_3D_WORLD,
None,
True,
True,
POINT_3D_IMAGE_RAS,
["image_1", "image_2"],
[POINT_2D_WORLD, POINT_3D_WORLD],
],
]

TEST_CASES_WRONG = [
[POINT_2D_WORLD, True, None],
[POINT_2D_WORLD.unsqueeze(0), False, None],
[POINT_3D_WORLD[..., 0:1], False, None],
[POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])],
[POINT_2D_WORLD, True, None, None],
[POINT_2D_WORLD.unsqueeze(0), False, None, None],
[POINT_3D_WORLD[..., 0:1], False, None, None],
[POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None],
[POINT_3D_WORLD, False, None, "image"],
[POINT_3D_WORLD, False, None, []],
]


Expand All @@ -107,10 +125,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
"point": points,
"affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
}
refer_key = "image" if (image is not None and image != "affine") else image
refer_keys = "image" if (image is not None and image != "affine") else image
transform = ApplyTransformToPointsd(
keys="point",
refer_key=refer_key,
refer_keys=refer_keys,
dtype=torch.int64,
affine=affine,
invert_affine=invert_affine,
Expand All @@ -122,11 +140,45 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
invert_out = transform.inverse(output)
self.assertTrue(torch.allclose(invert_out["point"], points))

@parameterized.expand(TEST_CASES_SEQUENCE)
def test_transform_coordinates_sequences(
self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
):
data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]}
keys = ["point_1", "point_2"]
transform = ApplyTransformToPointsd(
keys=keys,
refer_keys=refer_keys,
dtype=torch.int64,
affine=affine,
invert_affine=invert_affine,
affine_lps_to_ras=affine_lps_to_ras,
)
output = transform(data)

self.assertTrue(torch.allclose(output["point_1"], expected_output[0]))
self.assertTrue(torch.allclose(output["point_2"], expected_output[1]))
invert_out = transform.inverse(output)
self.assertTrue(torch.allclose(invert_out["point_1"], points[0]))

@parameterized.expand(TEST_CASES_WRONG)
def test_wrong_input(self, input, invert_affine, affine):
transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine)
with self.assertRaises(ValueError):
transform({"point": input})
def test_wrong_input(self, input, invert_affine, affine, refer_keys):
if refer_keys == []:
with self.assertRaises(ValueError):
ApplyTransformToPointsd(
keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
)
else:
transform = ApplyTransformToPointsd(
keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
)
data = {"point": input}
if refer_keys == "image":
with self.assertRaises(KeyError):
transform(data)
else:
with self.assertRaises(ValueError):
transform(data)


if __name__ == "__main__":
Expand Down

0 comments on commit 4e70bf6

Please sign in to comment.