Print warning when using tf.keras.Model.fit() with a tf.data dataset and batch_size argument is supplied #96
System information.
TensorFlow version (you are using): 2.15.0
Are you willing to contribute it (Yes/No) : No
Describe the feature and the current behavior/state.
The documentation for tf.keras.Model.fit() says, "do not specify the batch_size if your data is in the form of datasets [...] (since they generate batches)." However, rather than printing a warning if batch_size is specified, the argument is silently ignored. If the .batch() function is also not used on the dataset object, fit() prints a cryptic error message about inconsistent dimensions in the model, and does not say anything about missing batch size or an ignored argument.
I think it's uncommon for an API to silently ignore arguments depending on the datatype of the input, and it took me a few hours to figure out the solution. I think it could improve the UX for new Keras users in the future if a helpful warning was printed in this case.
The message could say something like,
Warning: batch_size is ignored if x has type tf.data.Dataset. Please use Dataset.batch() to set the batch size
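A rough sketch of the kind of check that could emit this message is below. It is not Keras code: `warn_if_batch_size_ignored`, its `dataset_types` parameter, and the `FakeDataset` stand-in are hypothetical names made up for illustration, and the stand-in is only there so the snippet runs without TensorFlow installed. In practice the tuple would presumably be `(tf.data.Dataset,)`.

```python
import warnings


def warn_if_batch_size_ignored(x, batch_size, dataset_types):
    """Warn when batch_size is passed together with a dataset-like input.

    Hypothetical helper, not part of Keras. `dataset_types` is a tuple of
    classes that already produce batches; with TensorFlow available this
    would presumably be (tf.data.Dataset,).
    Returns True if a warning was emitted, False otherwise.
    """
    if batch_size is not None and isinstance(x, dataset_types):
        warnings.warn(
            "batch_size is ignored if x has type tf.data.Dataset. "
            "Please use Dataset.batch() to set the batch size.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False


class FakeDataset:
    """Stand-in for tf.data.Dataset so the sketch runs without TensorFlow."""


# Record warnings so the behaviour can be inspected instead of printed.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Dataset-like input plus batch_size: should warn.
    warned = warn_if_batch_size_ignored(FakeDataset(), 32, (FakeDataset,))
    # Plain list input plus batch_size: should stay silent.
    silent = warn_if_batch_size_ignored([1.0, 2.0], 32, (FakeDataset,))
    # Dataset-like input without batch_size: should stay silent.
    unset = warn_if_batch_size_ignored(FakeDataset(), None, (FakeDataset,))
```

Using `warnings.warn` rather than raising an exception would keep existing scripts working, which matches the "No" answer to the API-change question in this request.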
Will this change the current api? How?
No
Who will benefit from this feature?
New users of Keras and ones who are migrating from using x, y tensors to a tf.data dataset object for training.