Skip to content

Commit

Permalink
Correct the axis which temporal merges are applied to in tracking mod…
Browse files Browse the repository at this point in the history
…el (#718)

* Apply temporal merge to the time dimensions, not the cell/batch dimension

* Make the image norm layer in appearance head configurable

* Remove import for a tracking config object

* Fix spelling typo

Co-authored-by: Ross Barnowski <[email protected]>

* Fix long line

---------

Co-authored-by: Ross Barnowski <[email protected]>
  • Loading branch information
msschwartz21 and rossbar authored Jun 25, 2024
1 parent 2ef52e0 commit 802ea72
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion deepcell/layers/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self, encoder_dim=64, **kwargs):
def call(self, inputs):
input_shape = tf.shape(inputs)
# reshape away the temporal axis
x = tf.reshape(inputs, [-1, input_shape[2], self.encoder_dim])
x = tf.reshape(inputs, [-1, input_shape[1], self.encoder_dim])
x = self.lstm(x)
output_shape = [-1, input_shape[1], input_shape[2], self.encoder_dim]
x = tf.reshape(x, output_shape)
Expand Down
11 changes: 8 additions & 3 deletions deepcell/model_zoo/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ class GNNTrackingModel:
``<layer name>-kwarg:value-kwarg:value``
appearance_shape (tuple): shape of each object's appearance tensor
norm_layer (str): Must be one of {'layer', 'batch'}
appearance_norm (bool): Whether to apply an input normalization layer
to the appearance head
"""
def __init__(self,
max_cells=39,
Expand All @@ -230,14 +232,16 @@ def __init__(self,
n_layers=3,
graph_layer='gcs',
appearance_shape=(32, 32, 1),
norm_layer='batch'):
norm_layer='batch',
appearance_norm=True):

self.n_filters = n_filters
self.encoder_dim = encoder_dim
self.embedding_dim = embedding_dim
self.n_layers = n_layers
self.max_cells = max_cells
self.track_length = track_length
self.appearance_norm = appearance_norm

if len(appearance_shape) != 3:
raise ValueError('appearanace_shape should be a '
Expand Down Expand Up @@ -309,8 +313,9 @@ def get_appearance_encoder(self):
inputs = Input(shape=app_shape, name='encoder_app_input')

x = inputs
x = TimeDistributed(ImageNormalization2D(norm_method='whole_image',
name='imgnrm_ae'))(x)
if self.appearance_norm:
x = TimeDistributed(ImageNormalization2D(norm_method='whole_image',
name='imgnrm_ae'))(x)

for i in range(int(math.log(app_shape[1], 2))):
x = Conv3D(self.n_filters,
Expand Down

0 comments on commit 802ea72

Please sign in to comment.