diff --git a/deepcell/data/tracking.py b/deepcell/data/tracking.py index 1d9f1817..9e56d3f4 100644 --- a/deepcell/data/tracking.py +++ b/deepcell/data/tracking.py @@ -405,6 +405,33 @@ def random_translate(X, y, range=512): return X, y +def random_reflect(X, y): + horizontal = tf.random.uniform([1], 0, 2, dtype=tf.int32) + vertical = tf.random.uniform([1], 0, 2, dtype=tf.int32) + print(horizontal, vertical) + + appearances = X['appearances'] + centroids = X['centroids'] + + old_shape = tf.shape(appearances) + new_shape = [-1, old_shape[2], old_shape[3], old_shape[4]] + img = tf.reshape(appearances, new_shape) + + if horizontal == 1: + centroids[:, :, 1] = centroids[:, :, 1] * -1 + img = tf.image.flip_left_right(img) + + if vertical == 1: + centroids[:, :, 0] = centroids[:, :, 0] * -1 + img = tf.image.flip_up_down(img) + + img = tf.reshape(img, old_shape) + X['appearances'] = img + X['centroids'] = centroids + + return X, y + + def prepare_dataset(track_info, batch_size=32, buffer_size=256, seed=None, track_length=8, rotation_range=0,