Skip to content

Commit

Permalink
Make the image norm layer in appearance head configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Jun 5, 2024
1 parent 09ba804 commit aaf0d44
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions deepcell/model_zoo/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

from spektral.layers import GCSConv, GCNConv, GATConv

from deepcell_tracking import TrackingConfig

from deepcell.layers import ImageNormalization2D
from deepcell.layers import Comparison, DeltaReshape, Unmerge, TemporalMerge

Expand Down Expand Up @@ -220,6 +222,7 @@ class GNNTrackingModel:
``<layer name>-kwarg:value-kwarg:value``
appearance_shape (tuple): shape of each object's appearance tensor
norm_layer (str): Must be one of {'layer', 'batch'}
appearance_norm (bool): Whether to apply an input normalization layer to the apperance head
"""
def __init__(self,
max_cells=39,
Expand All @@ -230,14 +233,16 @@ def __init__(self,
n_layers=3,
graph_layer='gcs',
appearance_shape=(32, 32, 1),
norm_layer='batch'):
norm_layer='batch',
appearance_norm=True):

self.n_filters = n_filters
self.encoder_dim = encoder_dim
self.embedding_dim = embedding_dim
self.n_layers = n_layers
self.max_cells = max_cells
self.track_length = track_length
self.appearance_norm = appearance_norm

if len(appearance_shape) != 3:
raise ValueError('appearanace_shape should be a '
Expand Down Expand Up @@ -309,8 +314,9 @@ def get_appearance_encoder(self):
inputs = Input(shape=app_shape, name='encoder_app_input')

x = inputs
x = TimeDistributed(ImageNormalization2D(norm_method='whole_image',
name='imgnrm_ae'))(x)
if self.appearance_norm:
x = TimeDistributed(ImageNormalization2D(norm_method='whole_image',
name='imgnrm_ae'))(x)

for i in range(int(math.log(app_shape[1], 2))):
x = Conv3D(self.n_filters,
Expand Down

0 comments on commit aaf0d44

Please sign in to comment.