diff --git a/deepcell/layers/temporal.py b/deepcell/layers/temporal.py
index 68135754..1b14afd1 100644
--- a/deepcell/layers/temporal.py
+++ b/deepcell/layers/temporal.py
@@ -97,7 +97,7 @@ def __init__(self, encoder_dim=64, **kwargs):
     def call(self, inputs):
         input_shape = tf.shape(inputs)
         # reshape away the temporal axis
-        x = tf.reshape(inputs, [-1, input_shape[2], self.encoder_dim])
+        x = tf.reshape(inputs, [-1, input_shape[1], self.encoder_dim])
         x = self.lstm(x)
         output_shape = [-1, input_shape[1], input_shape[2], self.encoder_dim]
         x = tf.reshape(x, output_shape)
diff --git a/deepcell/model_zoo/tracking.py b/deepcell/model_zoo/tracking.py
index 4f2aacb3..78f4bf21 100644
--- a/deepcell/model_zoo/tracking.py
+++ b/deepcell/model_zoo/tracking.py
@@ -220,6 +220,8 @@ class GNNTrackingModel:
             ``-kwarg:value-kwarg:value``
         appearance_shape (tuple): shape of each object's appearance tensor
         norm_layer (str): Must be one of {'layer', 'batch'}
+        appearance_norm (bool): Whether to apply an input normalization layer
+            to the appearance head
     """
     def __init__(self,
                  max_cells=39,
@@ -230,7 +232,8 @@ def __init__(self,
                  n_layers=3,
                  graph_layer='gcs',
                  appearance_shape=(32, 32, 1),
-                 norm_layer='batch'):
+                 norm_layer='batch',
+                 appearance_norm=True):
 
         self.n_filters = n_filters
         self.encoder_dim = encoder_dim
@@ -238,6 +241,7 @@ def __init__(self,
         self.n_layers = n_layers
         self.max_cells = max_cells
         self.track_length = track_length
+        self.appearance_norm = appearance_norm
 
         if len(appearance_shape) != 3:
             raise ValueError('appearanace_shape should be a '
@@ -309,8 +313,9 @@ def get_appearance_encoder(self):
         inputs = Input(shape=app_shape, name='encoder_app_input')
 
         x = inputs
-        x = TimeDistributed(ImageNormalization2D(norm_method='whole_image',
-                                                 name='imgnrm_ae'))(x)
+        if self.appearance_norm:
+            x = TimeDistributed(ImageNormalization2D(norm_method='whole_image',
+                                                     name='imgnrm_ae'))(x)
 
         for i in range(int(math.log(app_shape[1], 2))):
             x = Conv3D(self.n_filters,