diff --git a/deepcell_tracking/utils.py b/deepcell_tracking/utils.py index f70955a..121bfc3 100644 --- a/deepcell_tracking/utils.py +++ b/deepcell_tracking/utils.py @@ -409,18 +409,27 @@ def is_valid_lineage(y, lineage): return is_valid # if unchanged, all cell lineages are valid! -def get_image_features(X, y, appearance_dim=32): +def get_image_features(X, y, appearance_dim=32, crop_mode='resize', norm=True): """Return features for every object in the array. Args: X (np.array): a 3D numpy array of raw data of shape (x, y, c). y (np.array): a 3D numpy array of integer labels of shape (x, y, 1). appearance_dim (int): The resized shape of the appearance feature. + crop_mode (str): Whether to do a fixed crop or to crop and resize + to create the appearance features + norm (bool): Whether to remove non cell features and normalize the + foreground pixels by zero-meaning and dividing by the standard + deviation. Applies to fixed crop mode only. Returns: dict: A dictionary of feature names to np.arrays of shape (n, c) or (n, x, y, c) where n is the number of objects. """ + + if crop_mode not in ['resize', 'fixed']: + raise ValueError('crop_mode must be either resize or fixed') + appearance_dim = int(appearance_dim) # each feature will be ordered based on the label. @@ -432,9 +441,22 @@ def get_image_features(X, y, appearance_dim=32): appearances = np.zeros((num_labels, appearance_dim, appearance_dim, X.shape[-1]), dtype='float32') + if crop_mode == 'fixed': + # Zero-pad the X array for fixed crop mode + pad_width = ((appearance_dim, appearance_dim), + (appearance_dim, appearance_dim), + (0, 0)) + X_padded = np.pad(X, pad_width=pad_width) + y_padded = np.pad(y, pad_width=pad_width) + + props = regionprops(y_padded[..., 0], cache=False) + # iterate over all objects in y - props = regionprops(y[..., 0], cache=False) + if crop_mode == 'resize': + props = regionprops(y[..., 0], cache=False) + for i, prop in enumerate(props): + # Get label labels[i] = prop.label @@ -450,23 +472,43 @@ def get_image_features(X, y, appearance_dim=32): ]) morphologies[i] = morphology - # Get appearance - minr, minc, maxr, maxc = prop.bbox - appearance = np.copy(X[minr:maxr, minc:maxc, :]) - resize_shape = (appearance_dim, appearance_dim) - appearance = resize(appearance, resize_shape) - appearances[i] = appearance - - # Get adjacency matrix - # distance = cdist(centroids, centroids, metric='euclidean') < distance_threshold - # adj_matrix = distance.astype('float32') + if crop_mode == 'resize': + # Get appearance + minr, minc, maxr, maxc = prop.bbox + appearance = np.copy(X[minr:maxr, minc:maxc, :]) + resize_shape = (appearance_dim, appearance_dim) + appearance = resize(appearance, resize_shape) + appearances[i] = appearance + + if crop_mode == 'fixed': + cent = np.array(prop.centroid) + delta = appearance_dim // 2 + minr = int(cent[0]) - delta + maxr = int(cent[0]) + delta + minc = int(cent[1]) - delta + maxc = int(cent[1]) + delta + + app = np.copy(X_padded[minr:maxr, minc:maxc, :]) + label = np.copy(y_padded[minr:maxr, minc:maxc]) + + if norm: + # Use label as a mask to zero out non-label information + app = app * (label == prop.label) + idx = np.nonzero(app) + + # Check data and normalize + if len(idx) > 0: + mean = np.mean(app[idx]) + std = np.std(app[idx]) + app[idx] = (app[idx] - mean) / std + + appearances[i] = app return { 'appearances': appearances, 'centroids': centroids, 'labels': labels, 'morphologies': morphologies, - # 'adj_matrix': adj_matrix, } diff --git a/deepcell_tracking/utils_test.py b/deepcell_tracking/utils_test.py index fd69c2b..0f040d2 100644 --- a/deepcell_tracking/utils_test.py +++ b/deepcell_tracking/utils_test.py @@ -367,6 +367,13 @@ def test_get_image_features(self): assert labels.shape == expected_shape np.testing.assert_array_equal(labels, np.array(list(range(1, num_labels + 1)))) + # test appearance - fixed crop + features = utils.get_image_features(X, y, appearance_dim, + crop_mode='fixed', norm=True) + appearances = features['appearances'] + expected_shape = (num_labels, appearance_dim, appearance_dim, X.shape[-1]) + assert appearances.shape == expected_shape + def test_trks_stats(self): # Test bad extension with pytest.raises(ValueError):