Skip to content

Commit

Permalink
Migrate Track and concat_tracks to deepcell.data.tracking. (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
willgraf authored Oct 11, 2021
1 parent 5a1465e commit ea8076e
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 270 deletions.
220 changes: 0 additions & 220 deletions deepcell_tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,223 +618,3 @@ def get_image_features(X, y, appearance_dim=32):
'morphologies': morphologies,
# 'adj_matrix': adj_matrix,
}


def concat_tracks(tracks):
"""Join an iterable of Track objects into a single dictionary of features.
Args:
tracks (iterable): Iterable of tracks.
Returns:
dict: A dictionary of tracked features.
Raises:
TypeError: ``tracks`` is not iterable.
"""
try:
list(tracks) # check if iterable
except TypeError:
raise TypeError('concatenate_tracks requires an iterable input.')

def get_array_of_max_shape(lst):
# find max dimensions of all arrs in lst.
shape = None
size = 0
for arr in lst:
if shape is None:
shape = [0] * len(arr.shape[1:])
for i, dim in enumerate(arr.shape[1:]):
if dim > shape[i]:
shape[i] = dim
size += arr.shape[0]
# add batch dimension
shape = [size] + shape
return np.zeros(shape, dtype='float32')

# insert small array into larger array
# https://stackoverflow.com/a/50692782
def paste_slices(tup):
pos, w, max_w = tup
wall_min = max(pos, 0)
wall_max = min(pos + w, max_w)
block_min = -min(pos, 0)
block_max = max_w - max(pos + w, max_w)
block_max = block_max if block_max != 0 else None
return slice(wall_min, wall_max), slice(block_min, block_max)

def paste(wall, block, loc):
loc_zip = zip(loc, block.shape, wall.shape)
wall_slices, block_slices = zip(*map(paste_slices, loc_zip))
wall[wall_slices] = block[block_slices]

# TODO: these keys must match the Track attributes.
track_info = {
'appearances': get_array_of_max_shape((t.appearances for t in tracks)),
'centroids': get_array_of_max_shape((t.centroids for t in tracks)),
'morphologies': get_array_of_max_shape((t.morphologies for t in tracks)),
'adj_matrices': get_array_of_max_shape((t.adj_matrices for t in tracks)),
'norm_adj_matrices': get_array_of_max_shape(
(t.norm_adj_matrices for t in tracks)),
'temporal_adj_matrices': get_array_of_max_shape(
(t.temporal_adj_matrices for t in tracks))
}

for track in tracks:
for k in track_info:
feature = getattr(track, k)
paste(track_info[k], feature, (0,) * len(feature.shape))

return track_info


class Track(object): # pylint: disable=useless-object-inheritance

def __init__(self, path=None, tracked_data=None,
appearance_dim=32, distance_threshold=64):
if tracked_data:
training_data = tracked_data
elif path:
training_data = load_trks(path)
else:
raise ValueError('One of `tracked_data` or `path` is required')
self.X = training_data['X'].astype('float32')
self.y = training_data['y'].astype('int32')
self.lineages = training_data['lineages']
self.appearance_dim = appearance_dim
self.distance_threshold = distance_threshold

# Correct lineages and remove bad batches
self._correct_lineages()

# Create feature dictionaries
features_dict = self._get_features()
self.appearances = features_dict['appearances']
self.morphologies = features_dict['morphologies']
self.centroids = features_dict['centroids']
self.adj_matrices = features_dict['adj_matrix']
self.norm_adj_matrices = normalize_adj_matrix(self.adj_matrices)
self.temporal_adj_matrices = features_dict['temporal_adj_matrix']
self.mask = features_dict['mask']
self.track_length = features_dict['track_length']

def _correct_lineages(self):
"""Ensure valid lineages and sequential labels for all batches"""
new_X = []
new_y = []
new_lineages = []

for batch in range(self.y.shape[0]):
if is_valid_lineage(self.y[batch], self.lineages[batch]):

y_relabel, new_lineage = relabel_sequential_lineage(
self.y[batch], self.lineages[batch])

new_X.append(self.X[batch])
new_y.append(y_relabel)
new_lineages.append(new_lineage)

self.X = np.stack(new_X, axis=0)
self.y = np.stack(new_y, axis=0)
self.lineages = new_lineages

def _get_features(self):
"""
Extract the relevant features from the label movie
Appearance, morphologies, centroids, and adjacency matrices
"""
max_tracks = get_max_cells(self.y)
n_batches = self.X.shape[0]
n_frames = self.X.shape[1]
n_channels = self.X.shape[-1]

batch_shape = (n_batches, n_frames, max_tracks)

appearance_shape = (self.appearance_dim, self.appearance_dim, n_channels)

appearances = np.zeros(batch_shape + appearance_shape, dtype='float32')

morphologies = np.zeros(batch_shape + (3,), dtype='float32')

centroids = np.zeros(batch_shape + (2,), dtype='float32')

adj_matrix = np.zeros(batch_shape + (max_tracks,), dtype='float32')

temporal_adj_matrix = np.zeros((n_batches,
n_frames - 1,
max_tracks,
max_tracks,
3), dtype='float32')

mask = np.zeros(batch_shape, dtype='float32')

track_length = np.zeros((n_batches, max_tracks, 2), dtype='int32')

for batch in range(n_batches):
for frame in range(n_frames):

frame_features = get_image_features(
self.X[batch, frame], self.y[batch, frame],
appearance_dim=self.appearance_dim)

track_ids = frame_features['labels'] - 1
centroids[batch, frame, track_ids] = frame_features['centroids']
morphologies[batch, frame, track_ids] = frame_features['morphologies']
appearances[batch, frame, track_ids] = frame_features['appearances']
mask[batch, frame, track_ids] = 1

# Get adjacency matrix, cannot filter on track ids.
cent = centroids[batch, frame]
distance = cdist(cent, cent, metric='euclidean')
distance = distance < self.distance_threshold
adj_matrix[batch, frame] = distance.astype(np.float32)

# Get track length and temporal adjacency matrix
for label in self.lineages[batch]:
# Get track length
start_frame = self.lineages[batch][label]['frames'][0]
end_frame = self.lineages[batch][label]['frames'][-1]

track_id = label - 1
track_length[batch, track_id, 0] = start_frame
track_length[batch, track_id, 1] = end_frame

# Get temporal adjacency matrix
frames = self.lineages[batch][label]['frames']

# Assign same
for f0, f1 in zip(frames[0:-1], frames[1:]):
if f1 - f0 == 1:
temporal_adj_matrix[batch, f0, track_id, track_id, 0] = 1

# Assign daughter
# WARNING: This wont work if there's a time gap between mother
# cell disappearing and daughter cells appearing
last_frame = frames[-1]
daughters = self.lineages[batch][label]['daughters']
for daughter in daughters:
daughter_id = daughter - 1
temporal_adj_matrix[batch, last_frame, track_id, daughter_id, 2] = 1

# Assign different
same_prob = temporal_adj_matrix[batch, ..., 0]
daughter_prob = temporal_adj_matrix[batch, ..., 2]
temporal_adj_matrix[batch, ..., 1] = 1 - same_prob - daughter_prob

# Identify padding
for i in range(temporal_adj_matrix.shape[2]):
# index + 1 is the cell label
if i + 1 not in self.lineages[batch]:
temporal_adj_matrix[batch, :, i] = -1
temporal_adj_matrix[batch, :, :, i] = -1

feature_dict = {}
feature_dict['adj_matrix'] = adj_matrix
feature_dict['appearances'] = appearances
feature_dict['morphologies'] = morphologies
feature_dict['centroids'] = centroids
feature_dict['temporal_adj_matrix'] = temporal_adj_matrix
feature_dict['mask'] = mask
feature_dict['track_length'] = track_length

return feature_dict
50 changes: 0 additions & 50 deletions deepcell_tracking/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,57 +437,7 @@ def test_get_image_features(self):
assert labels.shape == expected_shape
np.testing.assert_array_equal(labels, np.array(list(range(1, num_labels + 1))))

def test_concat_tracks(self):
num_labels = 3

data = get_dummy_data(num_labels)
track_1 = utils.Track(tracked_data=data)
track_2 = utils.Track(tracked_data=data)

data = utils.concat_tracks([track_1, track_2])

for k, v in data.items():
starting_batch = 0
for t in (track_1, track_2):
assert hasattr(t, k)
w = getattr(t, k)
# data is put into top left corner of array
v_sub = v[
starting_batch:starting_batch + w.shape[0],
0:w.shape[1],
0:w.shape[2],
0:w.shape[3]
]
np.testing.assert_array_equal(v_sub, w)

# test that input must be iterable
with pytest.raises(TypeError):
utils.concat_tracks(track_1)

def test_trks_stats(self):
# Test bad extension
with pytest.raises(ValueError):
utils.trks_stats('bad-extension.npz')


class TestTrack(object):

def test_init(self, mocker):
num_labels = 3

data = get_dummy_data(num_labels)

# invalidate one lineage
mocker.patch('deepcell_tracking.utils.load_trks',
lambda x: data)

track1 = utils.Track(tracked_data=data)
track2 = utils.Track(path='path/to/data')

np.testing.assert_array_equal(track1.appearances, track2.appearances)
np.testing.assert_array_equal(
track1.temporal_adj_matrices,
track2.temporal_adj_matrices)

with pytest.raises(ValueError):
utils.Track()

0 comments on commit ea8076e

Please sign in to comment.