Skip to content

Commit

Permalink
Speed up trks_stats by a factor of 5 (#84)
Browse files Browse the repository at this point in the history
* No need to iterate over dict.keys()

* Refactor `trks_stats` to avoid unnecessary computation.

* Test `trks_stats` on a .trks file, not the dumped .trk file.

* Bump version to 0.5.2
  • Loading branch information
willgraf authored Oct 25, 2021
1 parent 1c6d160 commit b3c820d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 31 deletions.
2 changes: 1 addition & 1 deletion deepcell_tracking/tracking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,4 @@ def test_track_cells(self, tmpdir):
data = utils.load_trks(os.path.join(tempdir, 'all.trks'))

# test trks_stats
utils.trks_stats(os.path.join(tempdir, 'test.trk'))
utils.trks_stats(os.path.join(tempdir, 'all.trks'))
53 changes: 24 additions & 29 deletions deepcell_tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,51 +315,46 @@ def trks_stats(filename):
"""
ext = os.path.splitext(filename)[-1].lower()
if ext not in {'.trks', '.trk'}:
raise ValueError('`trks_stats` expects a .trk or .trks but found a ' +
str(ext))
raise ValueError(
'`trks_stats` expects a .trk or .trks but found a {}'.format(ext))

training_data = load_trks(filename)
X = training_data['X']
y = training_data['y']
daughters = [{cell: fields['daughters']
for cell, fields in tracks.items()}
for tracks in training_data['lineages']]
lineages = training_data['lineages']

print('Dataset Statistics: ')
print('Image data shape: ', X.shape)
print('Number of lineages (should equal batch size): ',
len(training_data['lineages']))
print('Number of lineages (should equal batch size): ', len(lineages))

total_tracks = 0
total_divisions = 0

# Calculate cell density
frame_area = X.shape[2] * X.shape[3]

avg_cells_in_frame = []
avg_frame_counts_in_batches = []
for batch in range(y.shape[0]):
tracks = lineages[batch]
total_tracks += len(tracks)
num_frames_per_track = []

for cell_lineage in tracks.values():
num_frames_per_track.append(len(cell_lineage['frames']))
if cell_lineage.get('daughters', []):
total_divisions += 1
avg_frame_counts_in_batches.append(np.average(num_frames_per_track))

num_cells_in_frame = []
for frame in y[batch]:
cells_in_frame = len(np.unique(frame)) - 1 # unique returns 0 (BKGD)
num_cells_in_frame.append(cells_in_frame)
for frame in range(len(y[batch])):
y_frame = y[batch, frame]
cells_in_frame = np.unique(y_frame)
cells_in_frame = np.delete(cells_in_frame, 0) # rm background
num_cells_in_frame.append(len(cells_in_frame))
avg_cells_in_frame.append(np.average(num_cells_in_frame))
avg_cells_per_sq_pixel = np.average(avg_cells_in_frame) / frame_area

# Calculate division information
total_tracks = 0
total_divisions = 0
avg_frame_counts_in_batches = []
for batch, daughter_batch in enumerate(daughters):
num_tracks_in_batch = len(daughter_batch)
num_div_in_batch = len([c for c in daughter_batch if daughter_batch[c]])
total_tracks = total_tracks + num_tracks_in_batch
total_divisions = total_divisions + num_div_in_batch
frame_counts = []
for cell_id in daughter_batch.keys():
frame_count = 0
for frame in y[batch]:
cells_in_frame = np.unique(frame)
if cell_id in cells_in_frame:
frame_count += 1
frame_counts.append(frame_count)
avg_frame_counts_in_batches.append(np.average(frame_counts))
avg_cells_per_sq_pixel = np.average(avg_cells_in_frame) / frame_area
avg_num_frames_per_track = np.average(avg_frame_counts_in_batches)

print('Total number of unique tracks (cells) - ', total_tracks)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
readme = f.read()


VERSION = '0.5.1'
VERSION = '0.5.2'
NAME = 'DeepCell_Tracking'
DESCRIPTION = 'Tracking cells and lineage with deep learning.'
LICENSE = 'LICENSE'
Expand Down

0 comments on commit b3c820d

Please sign in to comment.