-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcsa_pmj_extrapolation.py
121 lines (100 loc) · 5.22 KB
/
csa_pmj_extrapolation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# -*- coding: utf-8
# Functions to get distance from PMJ for processing segmentation data
# Author: Sandrine Bédard
import copy
import logging

import numpy as np

from spinalcordtoolbox.image import Image
from spinalcordtoolbox.centerline.core import ParamCenterline, get_centerline
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Voxel values strictly greater than this threshold are treated as part of the segmentation.
NEAR_ZERO_THRESHOLD = 1.0e-6
def get_slices_for_pmj_distance(segmentation, pmj, distance, extent, param_centerline=None, verbose=1):
    """
    Compute the distance from the PMJ projection along the centerline, and select the slices at a given distance.

    The segmentation is reoriented to RPI and its centerline is extrapolated with get_centerline(). The PMJ label is
    projected onto that centerline, and the distance along the centerline from the projected PMJ is computed for every
    point below it. A mask of the segmentation is then generated, restricted to the slices within ``extent`` mm
    centered on ``distance`` mm from the PMJ.

    :param segmentation: input segmentation. Could be either an Image or a file name.
    :param pmj: label of PMJ. Passed to Image() like ``segmentation``; must have the same shape as the segmentation.
    :param distance: float: Distance from Ponto-Medullary Junction (PMJ) in mm.
    :param extent: float: length in mm of the selected region along the centerline, centered on ``distance``
        (``extent / 2`` on each side).
    :param param_centerline: see centerline.core.ParamCenterline(). If None, default parameters are used. The object
        passed by the caller is copied and left unmodified.
    :param verbose: verbosity level forwarded to get_centerline().
    :return im_ctl: centerline image returned by get_centerline(), in the native orientation of the segmentation.
    :return mask: copy of the segmentation with every slice outside [zmin, zmax - 1] set to 0, in the native
        orientation of the segmentation.
    :return slices: str: "zmin:zmax-1", the RPI z indexes of the slices kept in the mask.
    :return arr_ctl: centerline coordinates returned by get_centerline(), in RPI orientation.
    :raises RuntimeError: if the segmentation and the PMJ label do not have the same shape.
    :raises ValueError: if the segmentation or the PMJ label is empty, if ``distance`` is out of bound, or if the
        selected slices are not fully covered by the segmentation.
    """
    # Read the native orientation BEFORE reorienting; otherwise it is always 'RPI' and the outputs are never
    # restored to the input's orientation.
    im_seg = Image(segmentation)
    native_orientation = im_seg.orientation
    im_seg.change_orientation('RPI')
    im_pmj = Image(pmj).change_orientation('RPI')
    if not im_seg.data.shape == im_pmj.data.shape:
        raise RuntimeError("Segmentation and pmj should be in the same space coordinate.")
    _, _, _, _, px, py, pz, _ = im_seg.dim

    # Extract the highest slice containing the segmentation
    z_seg = np.nonzero(im_seg.data > NEAR_ZERO_THRESHOLD)[2]
    if z_seg.size == 0:
        raise ValueError("The segmentation is empty.")
    max_z_index = int(z_seg.max())
    # Remove the topmost slice before fitting the centerline.
    # NOTE(review): presumably to keep a partial top slice from biasing the extrapolation — confirm. Because im_seg
    # itself is modified, this slice is also excluded from the coverage check and from the output mask.
    im_seg.data[:, :, max_z_index:max_z_index + 1] = 0

    # Compute the spinal cord centerline based on the spinal cord segmentation. Work on a copy of the parameters so
    # that the caller's object is not mutated (and so that the None default works).
    param_centerline = ParamCenterline() if param_centerline is None else copy.copy(param_centerline)
    param_centerline.minmax = False  # Set to false to extrapolate centerline
    im_ctl, arr_ctl, _, _ = get_centerline(im_seg, param=param_centerline, verbose=verbose)
    im_ctl.change_orientation(native_orientation)

    # Get coordinate of PMJ label (first non-zero voxel)
    pmj_coords = np.argwhere(im_pmj.data != 0)
    if pmj_coords.size == 0:
        raise ValueError("The PMJ label is empty.")
    pmj_coord = pmj_coords[0]

    # Get Z index of PMJ projection on the extrapolated centerline
    pmj_index = get_min_distance(pmj_coord, arr_ctl, px, py, pz)
    # Compute distance from PMJ along centerline. Row 0 decreases with the index: entry 0 is the farthest point and
    # the last entry is the projected PMJ itself.
    arr_length = get_distance_from_pmj(arr_ctl, pmj_index, px, py, pz)
    dist_from_pmj = arr_length[0]

    # Check if distance is out of bound
    if distance > dist_from_pmj[0]:
        raise ValueError("Input distance of " + str(distance) + " mm is out of bound for maximum distance of " + str(dist_from_pmj[0]) + " mm")
    if distance < dist_from_pmj[-1]:
        raise ValueError("Input distance of " + str(distance) + " mm is out of bound for minimum distance of " + str(dist_from_pmj[-1]) + " mm")

    # Indexes whose distance from the PMJ is closest to each end of [distance - extent/2, distance + extent/2].
    # The farther end gives the lower slice (zmin).
    zmin = int(np.argmin(np.abs(dist_from_pmj - distance - extent / 2)))
    zmax = int(np.argmin(np.abs(dist_from_pmj - distance + extent / 2)))

    # Check if the range of selected slices are covered by the segmentation
    if not all(np.any(im_seg.data[:, :, z]) for z in range(zmin, zmax)):
        raise ValueError(f"The requested distances from the PMJ are not fully covered by the segmentation.\n"
                         f"The range of slices are: [{zmin}, {zmax}]")

    # Create mask from segmentation centered on distance from PMJ and with extent length on z axis.
    mask = im_seg.copy()
    mask.data[:, :, :zmin] = 0
    mask.data[:, :, zmax:] = 0
    mask.change_orientation(native_orientation)

    # Get corresponding slices (the mask keeps slices zmin to zmax - 1)
    slices = "{}:{}".format(zmin, zmax - 1)  # TODO check if we include last slice
    return im_ctl, mask, slices, arr_ctl
def get_distance_from_pmj(centerline_points, z_index, px, py, pz):
    """
    Compute the distance along the centerline between each centerline point and the projected PMJ.

    The distance of point k is the arc length, in physical units, of the centerline between point k and the point at
    ``z_index`` (the PMJ projection), for every k in [0, z_index].

    :param centerline_points: 3xn array: Centerline in continuous coordinate (float) for each slice in RPI orientation.
    :param z_index: int: index of projected PMJ on the centerline; must be in [0, n - 1].
    :param px: x pixel size.
    :param py: y pixel size.
    :param pz: z pixel size.
    :return: 2x(z_index + 1) nd-array. Row 0 is the distance from the PMJ for centerline points 0..z_index; it
        decreases with the index and ends with 0. Row 1 is the corresponding z coordinate of the centerline.
    :raises IndexError: if ``z_index`` is not a valid index of ``centerline_points``.
    """
    points = np.asarray(centerline_points)
    n_points = points.shape[1]
    # Without this check, a negative index either fails with an opaque shape error or silently returns a wrong array.
    if not 0 <= z_index < n_points:
        raise IndexError(f"PMJ index {z_index} is out of bound for a centerline of {n_points} points.")

    # Displacement between consecutive centerline points, from index 0 up to the PMJ, scaled to physical units.
    steps = np.diff(points[:, :z_index + 1], axis=1)
    step_length = np.sqrt((steps[0] * px) ** 2 + (steps[1] * py) ** 2 + (steps[2] * pz) ** 2)

    # Accumulate starting from the PMJ and walking down the centerline, then restore increasing index order so that
    # entry k is the arc length between point k and the PMJ. The PMJ itself is at distance 0.
    dist_from_pmj = np.append(np.cumsum(step_length[::-1])[::-1], 0.0)
    return np.stack((dist_from_pmj, points[2, :z_index + 1]), axis=0)
def get_min_distance(pmj, centerline, px, py, pz):
    """
    Find the centerline point nearest to the PMJ and return its z coordinate.

    Distances are Euclidean, with each axis scaled by its pixel size before squaring.

    :param pmj: coordinate of the PMJ, indexable as (x, y, z), with RPI orientation.
    :param centerline: 3xn array: Centerline in continuous coordinate (float) for each slice in RPI orientation.
    :param px: x pixel size.
    :param py: y pixel size.
    :param pz: z pixel size.
    :return: int: z coordinate of the nearest centerline point, truncated with int(), used as a z index.
    """
    squared_distance = 0
    for axis, pixel_size in enumerate((px, py, pz)):
        squared_distance = squared_distance + ((centerline[axis, :] - pmj[axis]) * pixel_size) ** 2
    nearest = np.argmin(np.sqrt(squared_distance))
    return int(centerline[2, nearest])