Skip to content

Commit

Permalink
Add random reflection augmentation for tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Mar 6, 2024
1 parent e6d7af8 commit ae1d345
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions deepcell/data/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,33 @@ def random_translate(X, y, range=512):
return X, y


def random_reflect(X, y):
horizontal = tf.random.uniform([1], 0, 2, dtype=tf.int32)
vertical = tf.random.uniform([1], 0, 2, dtype=tf.int32)
print(horizontal, vertical)

appearances = X['appearances']
centroids = X['centroids']

old_shape = tf.shape(appearances)
new_shape = [-1, old_shape[2], old_shape[3], old_shape[4]]
img = tf.reshape(appearances, new_shape)

if horizontal == 1:
centroids[:, :, 1] = centroids[:, :, 1] * -1
img = tf.image.flip_left_right(img)

if vertical == 1:
centroids[:, :, 0] = centroids[:, :, 0] * -1
img = tf.image.flip_up_down(img)

img = tf.reshape(img, old_shape)
X['appearances'] = img
X['centroids'] = centroids

return X, y


def prepare_dataset(track_info,
batch_size=32, buffer_size=256,
seed=None, track_length=8, rotation_range=0,
Expand Down

0 comments on commit ae1d345

Please sign in to comment.