diff --git a/ai/src/itwinai/backend/tensorflow/trainer.py b/ai/src/itwinai/backend/tensorflow/trainer.py index 913dd19a..7b430866 100644 --- a/ai/src/itwinai/backend/tensorflow/trainer.py +++ b/ai/src/itwinai/backend/tensorflow/trainer.py @@ -53,8 +53,8 @@ def train(self, data): train, test = data # Set batch size to the dataset - train = train.batch(self.batch_size * self.num_devices, drop_remainder=True) - test = test.batch(self.batch_size * self.num_devices, drop_remainder=True) + train = train.batch(self.batch_size, drop_remainder=True) + test = test.batch(self.batch_size, drop_remainder=True) # Number of samples n_train = train.cardinality().numpy()