Skip to content

Commit

Permalink
FIX: Distributed Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
User3574 committed Aug 9, 2023
1 parent 26c7c0f commit b5da5c3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
6 changes: 3 additions & 3 deletions ai/src/itwinai/backend/tensorflow/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ def __init__(
print(f"Strategy is working with: {self.num_devices} devices")

def train(self, data):
# TODO: Batch_per_worker? Validation_steps? Train_steps?
# TODO: FIX Steps sizes in model.fit
train, test = data

# Set batch size to the dataset
train = train.batch(self.batch_size * self.num_devices, drop_remainder=True)
test = test.batch(self.batch_size * self.num_devices, drop_remainder=True)
train = train.batch(self.batch_size * self.num_devices, drop_remainder=True).repeat()
test = test.batch(self.batch_size * self.num_devices, drop_remainder=True).repeat()

# Number of samples
n_train = train.cardinality().numpy()
Expand Down
4 changes: 0 additions & 4 deletions ai/src/itwinai/models/tensorflow/cyclegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,6 @@ def test_step(self, inputs):
"D_Y_loss": disc_Y_loss,
}

def call(self, inputs):
# TODO: Fix validation loss, how?
return self.gen_G(inputs, training=False)

def get_config(self):
config = super().get_config().copy()
config.update({
Expand Down

0 comments on commit b5da5c3

Please sign in to comment.