From b3c820d21248998504d0789060755624cb8b92e3 Mon Sep 17 00:00:00 2001 From: willgraf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Oct 2021 16:21:58 -0700 Subject: [PATCH] Speed up `trks_stats` by a factor of 5 (#84) * No need to iterate over dict.keys() * Refactor `trks_stats` to avoid unnecessary computation. * test trks_stats on a trks file, not the dumped .trk file. * Bump version to 0.5.2 --- deepcell_tracking/tracking_test.py | 2 +- deepcell_tracking/utils.py | 53 ++++++++++++++---------------- setup.py | 2 +- 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/deepcell_tracking/tracking_test.py b/deepcell_tracking/tracking_test.py index 5002307..29dd551 100644 --- a/deepcell_tracking/tracking_test.py +++ b/deepcell_tracking/tracking_test.py @@ -199,4 +199,4 @@ def test_track_cells(self, tmpdir): data = utils.load_trks(os.path.join(tempdir, 'all.trks')) # test trks_stats - utils.trks_stats(os.path.join(tempdir, 'test.trk')) + utils.trks_stats(os.path.join(tempdir, 'all.trks')) diff --git a/deepcell_tracking/utils.py b/deepcell_tracking/utils.py index 00933b5..36127ac 100644 --- a/deepcell_tracking/utils.py +++ b/deepcell_tracking/utils.py @@ -315,51 +315,46 @@ def trks_stats(filename): """ ext = os.path.splitext(filename)[-1].lower() if ext not in {'.trks', '.trk'}: - raise ValueError('`trks_stats` expects a .trk or .trks but found a ' + - str(ext)) + raise ValueError( + '`trks_stats` expects a .trk or .trks but found a {}'.format(ext)) training_data = load_trks(filename) X = training_data['X'] y = training_data['y'] - daughters = [{cell: fields['daughters'] - for cell, fields in tracks.items()} - for tracks in training_data['lineages']] + lineages = training_data['lineages'] print('Dataset Statistics: ') print('Image data shape: ', X.shape) - print('Number of lineages (should equal batch size): ', - len(training_data['lineages'])) + print('Number of lineages (should equal batch size): ', len(lineages)) + + total_tracks = 0 + total_divisions = 0 # Calculate cell density frame_area = X.shape[2] * X.shape[3] avg_cells_in_frame = [] + avg_frame_counts_in_batches = [] for batch in range(y.shape[0]): + tracks = lineages[batch] + total_tracks += len(tracks) + num_frames_per_track = [] + + for cell_lineage in tracks.values(): + num_frames_per_track.append(len(cell_lineage['frames'])) + if cell_lineage.get('daughters', []): + total_divisions += 1 + avg_frame_counts_in_batches.append(np.average(num_frames_per_track)) + num_cells_in_frame = [] - for frame in y[batch]: - cells_in_frame = len(np.unique(frame)) - 1 # unique returns 0 (BKGD) - num_cells_in_frame.append(cells_in_frame) + for frame in range(len(y[batch])): + y_frame = y[batch, frame] + cells_in_frame = np.unique(y_frame) + cells_in_frame = np.delete(cells_in_frame, 0) # rm background + num_cells_in_frame.append(len(cells_in_frame)) avg_cells_in_frame.append(np.average(num_cells_in_frame)) - avg_cells_per_sq_pixel = np.average(avg_cells_in_frame) / frame_area - # Calculate division information - total_tracks = 0 - total_divisions = 0 - avg_frame_counts_in_batches = [] - for batch, daughter_batch in enumerate(daughters): - num_tracks_in_batch = len(daughter_batch) - num_div_in_batch = len([c for c in daughter_batch if daughter_batch[c]]) - total_tracks = total_tracks + num_tracks_in_batch - total_divisions = total_divisions + num_div_in_batch - frame_counts = [] - for cell_id in daughter_batch.keys(): - frame_count = 0 - for frame in y[batch]: - cells_in_frame = np.unique(frame) - if cell_id in cells_in_frame: - frame_count += 1 - frame_counts.append(frame_count) - avg_frame_counts_in_batches.append(np.average(frame_counts)) + avg_cells_per_sq_pixel = np.average(avg_cells_in_frame) / frame_area avg_num_frames_per_track = np.average(avg_frame_counts_in_batches) print('Total number of unique tracks (cells) - ', total_tracks) diff --git a/setup.py b/setup.py index ddfda86..3e53dad 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ readme = f.read() -VERSION = '0.5.1' +VERSION = '0.5.2' NAME = 'DeepCell_Tracking' DESCRIPTION = 'Tracking cells and lineage with deep learning.' LICENSE = 'LICENSE'