diff --git a/use-cases/zebra2horse/dataloader.py b/use-cases/zebra2horse/dataloader.py index 25d2914c..ff414495 100644 --- a/use-cases/zebra2horse/dataloader.py +++ b/use-cases/zebra2horse/dataloader.py @@ -44,25 +44,25 @@ def preproc_test_fn(img, label): # Apply the preprocessing operations to the training data train_horses = ( train_horses.map(preproc_train_fn, num_parallel_calls=tf.data.AUTOTUNE) - .cache() .shuffle(self.buffer_size) + .cache() ) train_zebras = ( train_zebras.map(preproc_train_fn, num_parallel_calls=tf.data.AUTOTUNE) - .cache() .shuffle(self.buffer_size) + .cache() ) # Apply the preprocessing operations to the test data test_horses = ( test_horses.map(preproc_test_fn, num_parallel_calls=tf.data.AUTOTUNE) - .cache() .shuffle(self.buffer_size) + .cache() ) test_zebras = ( test_zebras.map(preproc_test_fn, num_parallel_calls=tf.data.AUTOTUNE) - .cache() .shuffle(self.buffer_size) + .cache() ) return tf.data.Dataset.zip((train_horses, train_zebras)), tf.data.Dataset.zip((test_horses, test_zebras))