From ea8076e92937e9da53640da4d022c3f77f1cedab Mon Sep 17 00:00:00 2001 From: willgraf <7930703+willgraf@users.noreply.github.com> Date: Mon, 11 Oct 2021 15:45:49 -0700 Subject: [PATCH] Migrate `Track` and `concat_tracks` to `deepcell.data.tracking`. (#79) --- deepcell_tracking/utils.py | 220 -------------------------------- deepcell_tracking/utils_test.py | 50 -------- 2 files changed, 270 deletions(-) diff --git a/deepcell_tracking/utils.py b/deepcell_tracking/utils.py index 39fc9e5..fbf0305 100644 --- a/deepcell_tracking/utils.py +++ b/deepcell_tracking/utils.py @@ -618,223 +618,3 @@ def get_image_features(X, y, appearance_dim=32): 'morphologies': morphologies, # 'adj_matrix': adj_matrix, } - - -def concat_tracks(tracks): - """Join an iterable of Track objects into a single dictionary of features. - - Args: - tracks (iterable): Iterable of tracks. - - Returns: - dict: A dictionary of tracked features. - - Raises: - TypeError: ``tracks`` is not iterable. - """ - try: - list(tracks) # check if iterable - except TypeError: - raise TypeError('concatenate_tracks requires an iterable input.') - - def get_array_of_max_shape(lst): - # find max dimensions of all arrs in lst. - shape = None - size = 0 - for arr in lst: - if shape is None: - shape = [0] * len(arr.shape[1:]) - for i, dim in enumerate(arr.shape[1:]): - if dim > shape[i]: - shape[i] = dim - size += arr.shape[0] - # add batch dimension - shape = [size] + shape - return np.zeros(shape, dtype='float32') - - # insert small array into larger array - # https://stackoverflow.com/a/50692782 - def paste_slices(tup): - pos, w, max_w = tup - wall_min = max(pos, 0) - wall_max = min(pos + w, max_w) - block_min = -min(pos, 0) - block_max = max_w - max(pos + w, max_w) - block_max = block_max if block_max != 0 else None - return slice(wall_min, wall_max), slice(block_min, block_max) - - def paste(wall, block, loc): - loc_zip = zip(loc, block.shape, wall.shape) - wall_slices, block_slices = zip(*map(paste_slices, loc_zip)) - wall[wall_slices] = block[block_slices] - - # TODO: these keys must match the Track attributes. - track_info = { - 'appearances': get_array_of_max_shape((t.appearances for t in tracks)), - 'centroids': get_array_of_max_shape((t.centroids for t in tracks)), - 'morphologies': get_array_of_max_shape((t.morphologies for t in tracks)), - 'adj_matrices': get_array_of_max_shape((t.adj_matrices for t in tracks)), - 'norm_adj_matrices': get_array_of_max_shape( - (t.norm_adj_matrices for t in tracks)), - 'temporal_adj_matrices': get_array_of_max_shape( - (t.temporal_adj_matrices for t in tracks)) - } - - for track in tracks: - for k in track_info: - feature = getattr(track, k) - paste(track_info[k], feature, (0,) * len(feature.shape)) - - return track_info - - -class Track(object): # pylint: disable=useless-object-inheritance - - def __init__(self, path=None, tracked_data=None, - appearance_dim=32, distance_threshold=64): - if tracked_data: - training_data = tracked_data - elif path: - training_data = load_trks(path) - else: - raise ValueError('One of `tracked_data` or `path` is required') - self.X = training_data['X'].astype('float32') - self.y = training_data['y'].astype('int32') - self.lineages = training_data['lineages'] - self.appearance_dim = appearance_dim - self.distance_threshold = distance_threshold - - # Correct lineages and remove bad batches - self._correct_lineages() - - # Create feature dictionaries - features_dict = self._get_features() - self.appearances = features_dict['appearances'] - self.morphologies = features_dict['morphologies'] - self.centroids = features_dict['centroids'] - self.adj_matrices = features_dict['adj_matrix'] - self.norm_adj_matrices = normalize_adj_matrix(self.adj_matrices) - self.temporal_adj_matrices = features_dict['temporal_adj_matrix'] - self.mask = features_dict['mask'] - self.track_length = features_dict['track_length'] - - def _correct_lineages(self): - """Ensure valid lineages and sequential labels for all batches""" - new_X = [] - new_y = [] - new_lineages = [] - - for batch in range(self.y.shape[0]): - if is_valid_lineage(self.y[batch], self.lineages[batch]): - - y_relabel, new_lineage = relabel_sequential_lineage( - self.y[batch], self.lineages[batch]) - - new_X.append(self.X[batch]) - new_y.append(y_relabel) - new_lineages.append(new_lineage) - - self.X = np.stack(new_X, axis=0) - self.y = np.stack(new_y, axis=0) - self.lineages = new_lineages - - def _get_features(self): - """ - Extract the relevant features from the label movie - Appearance, morphologies, centroids, and adjacency matrices - """ - max_tracks = get_max_cells(self.y) - n_batches = self.X.shape[0] - n_frames = self.X.shape[1] - n_channels = self.X.shape[-1] - - batch_shape = (n_batches, n_frames, max_tracks) - - appearance_shape = (self.appearance_dim, self.appearance_dim, n_channels) - - appearances = np.zeros(batch_shape + appearance_shape, dtype='float32') - - morphologies = np.zeros(batch_shape + (3,), dtype='float32') - - centroids = np.zeros(batch_shape + (2,), dtype='float32') - - adj_matrix = np.zeros(batch_shape + (max_tracks,), dtype='float32') - - temporal_adj_matrix = np.zeros((n_batches, - n_frames - 1, - max_tracks, - max_tracks, - 3), dtype='float32') - - mask = np.zeros(batch_shape, dtype='float32') - - track_length = np.zeros((n_batches, max_tracks, 2), dtype='int32') - - for batch in range(n_batches): - for frame in range(n_frames): - - frame_features = get_image_features( - self.X[batch, frame], self.y[batch, frame], - appearance_dim=self.appearance_dim) - - track_ids = frame_features['labels'] - 1 - centroids[batch, frame, track_ids] = frame_features['centroids'] - morphologies[batch, frame, track_ids] = frame_features['morphologies'] - appearances[batch, frame, track_ids] = frame_features['appearances'] - mask[batch, frame, track_ids] = 1 - - # Get adjacency matrix, cannot filter on track ids. - cent = centroids[batch, frame] - distance = cdist(cent, cent, metric='euclidean') - distance = distance < self.distance_threshold - adj_matrix[batch, frame] = distance.astype(np.float32) - - # Get track length and temporal adjacency matrix - for label in self.lineages[batch]: - # Get track length - start_frame = self.lineages[batch][label]['frames'][0] - end_frame = self.lineages[batch][label]['frames'][-1] - - track_id = label - 1 - track_length[batch, track_id, 0] = start_frame - track_length[batch, track_id, 1] = end_frame - - # Get temporal adjacency matrix - frames = self.lineages[batch][label]['frames'] - - # Assign same - for f0, f1 in zip(frames[0:-1], frames[1:]): - if f1 - f0 == 1: - temporal_adj_matrix[batch, f0, track_id, track_id, 0] = 1 - - # Assign daughter - # WARNING: This wont work if there's a time gap between mother - # cell disappearing and daughter cells appearing - last_frame = frames[-1] - daughters = self.lineages[batch][label]['daughters'] - for daughter in daughters: - daughter_id = daughter - 1 - temporal_adj_matrix[batch, last_frame, track_id, daughter_id, 2] = 1 - - # Assign different - same_prob = temporal_adj_matrix[batch, ..., 0] - daughter_prob = temporal_adj_matrix[batch, ..., 2] - temporal_adj_matrix[batch, ..., 1] = 1 - same_prob - daughter_prob - - # Identify padding - for i in range(temporal_adj_matrix.shape[2]): - # index + 1 is the cell label - if i + 1 not in self.lineages[batch]: - temporal_adj_matrix[batch, :, i] = -1 - temporal_adj_matrix[batch, :, :, i] = -1 - - feature_dict = {} - feature_dict['adj_matrix'] = adj_matrix - feature_dict['appearances'] = appearances - feature_dict['morphologies'] = morphologies - feature_dict['centroids'] = centroids - feature_dict['temporal_adj_matrix'] = temporal_adj_matrix - feature_dict['mask'] = mask - feature_dict['track_length'] = track_length - - return feature_dict diff --git a/deepcell_tracking/utils_test.py b/deepcell_tracking/utils_test.py index 06dcb9a..2cdaefa 100644 --- a/deepcell_tracking/utils_test.py +++ b/deepcell_tracking/utils_test.py @@ -437,57 +437,7 @@ def test_get_image_features(self): assert labels.shape == expected_shape np.testing.assert_array_equal(labels, np.array(list(range(1, num_labels + 1)))) - def test_concat_tracks(self): - num_labels = 3 - - data = get_dummy_data(num_labels) - track_1 = utils.Track(tracked_data=data) - track_2 = utils.Track(tracked_data=data) - - data = utils.concat_tracks([track_1, track_2]) - - for k, v in data.items(): - starting_batch = 0 - for t in (track_1, track_2): - assert hasattr(t, k) - w = getattr(t, k) - # data is put into top left corner of array - v_sub = v[ - starting_batch:starting_batch + w.shape[0], - 0:w.shape[1], - 0:w.shape[2], - 0:w.shape[3] - ] - np.testing.assert_array_equal(v_sub, w) - - # test that input must be iterable - with pytest.raises(TypeError): - utils.concat_tracks(track_1) - def test_trks_stats(self): # Test bad extension with pytest.raises(ValueError): utils.trks_stats('bad-extension.npz') - - -class TestTrack(object): - - def test_init(self, mocker): - num_labels = 3 - - data = get_dummy_data(num_labels) - - # invalidate one lineage - mocker.patch('deepcell_tracking.utils.load_trks', - lambda x: data) - - track1 = utils.Track(tracked_data=data) - track2 = utils.Track(path='path/to/data') - - np.testing.assert_array_equal(track1.appearances, track2.appearances) - np.testing.assert_array_equal( - track1.temporal_adj_matrices, - track2.temporal_adj_matrices) - - with pytest.raises(ValueError): - utils.Track()