Print warning when using tf.keras.Model.fit() with a tf.data dataset and batch_size argument is supplied #96

Open
maxfisher-g opened this issue Aug 15, 2023 · 4 comments

maxfisher-g commented Aug 15, 2023

System information.

TensorFlow version (you are using): 2.15.0
Are you willing to contribute it (Yes/No) : No

Describe the feature and the current behavior/state.

The documentation for tf.keras.Model.fit() says, "do not specify the batch_size if your data is in the form of datasets [...] (since they generate batches)."

However, when batch_size is specified together with a dataset, no warning is printed and the argument is silently ignored. If the dataset has also not been batched with .batch(), fit() fails with a cryptic error about inconsistent dimensions in the model, with no mention of a missing batch size or an ignored argument.

I think it's uncommon for an API to silently ignore an argument depending on the datatype of the input, and it took me a few hours to figure out the solution. I think a helpful warning in this case would improve the UX for new Keras users.

The message could say something like: "Warning: batch_size is ignored if x has type tf.data.Dataset. Please use Dataset.batch() to set the batch size."
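
A minimal sketch of the proposed check, for illustration only. It is not Keras code: `FakeDataset` is a hypothetical stand-in for tf.data.Dataset so the example runs without TensorFlow, `warn_if_batch_size_ignored` is a made-up helper name, and it assumes that `None` means the caller did not pass batch_size.

```python
import warnings


class FakeDataset:
    """Hypothetical stand-in for tf.data.Dataset (keeps the sketch TensorFlow-free)."""


def warn_if_batch_size_ignored(x, batch_size):
    """Warn when an explicit batch_size is passed together with a dataset.

    Returns True if a warning was issued, False otherwise.
    """
    # Assumption: None means "batch_size was not supplied by the caller".
    if isinstance(x, FakeDataset) and batch_size is not None:
        warnings.warn(
            "batch_size is ignored if x has type tf.data.Dataset. "
            "Please use Dataset.batch() to set the batch size.",
            UserWarning,
            stacklevel=2,  # point the warning at the caller, not this helper
        )
        return True
    return False
```

Calling `warn_if_batch_size_ignored(FakeDataset(), 64)` issues the warning, while passing a plain list or leaving batch_size as `None` stays silent.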

Will this change the current api? How?

No

Who will benefit from this feature?
New users of Keras, and users who are migrating from x, y tensors to a tf.data dataset object for training.

Contributing

  • Do you want to contribute a PR? (yes/no): No
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing):

maxfisher-g commented Aug 16, 2023

Note: this also applies to Model.evaluate() and Model.predict()

fchollet (Member) commented

Thanks for the suggestion. batch_size has a default value, so to print a warning when it is passed together with a tf.data.Dataset object, we would need to special-case the default value (32) and not print the warning in that case.

We could also batch the dataset if it isn't batched, but detecting that a dataset isn't batched can be quite tricky.

maxfisher-g (Author) commented

Hm, it does seem like detecting whether the dataset is batched or not could give a better indication of when a warning should be printed.

Another thought: could the default value for batch_size be applied only when x is not a dataset, generator, or keras.utils.Sequence?
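
One way this deferred default could look, as a rough sketch rather than the actual Keras implementation. It assumes the signature default is a `None` sentinel that is only resolved to 32 for array-like input. `BatchedSource` and `resolve_batch_size` are hypothetical names introduced for this example.

```python
import warnings

DEFAULT_BATCH_SIZE = 32  # default value quoted in this thread


class BatchedSource:
    """Hypothetical marker for inputs that already yield batches
    (a dataset, generator, or keras.utils.Sequence in the real API)."""


def resolve_batch_size(x, batch_size=None):
    """Return the batch size to use, or None if x produces its own batches."""
    if isinstance(x, BatchedSource):
        # Any non-None value here was passed explicitly, so it is safe to warn,
        # even when the caller happened to pass 32.
        if batch_size is not None:
            warnings.warn(
                "batch_size is ignored because x already yields batches.",
                UserWarning,
                stacklevel=2,
            )
        return None
    # Array-like input: apply the default only at this point.
    return DEFAULT_BATCH_SIZE if batch_size is None else batch_size
```

Because the default is applied late, an explicit `batch_size=32` can be told apart from an omitted argument, so no special case for the value 32 is needed.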

maxfisher-g (Author) commented

> We could also batch the dataset if it isn't batched, but detecting that a dataset isn't batched can be quite tricky.

I'm sure I'm missing something, but I saw the function keras.utils.is_batched(dataset) while reading some other source code in the same file. Could this be a way of detecting whether the passed dataset is batched?
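
To illustrate why detection can be ambiguous, here is a hypothetical shape-based heuristic. It does not describe how any Keras helper works; `looks_unbatched` is a made-up name, and shapes are passed as plain tuples so the sketch runs without TensorFlow.

```python
def looks_unbatched(element_shape, model_input_shape):
    """Heuristic: True if dataset elements seem to lack a batch dimension.

    element_shape: shape of one dataset element, e.g. (28, 28).
    model_input_shape: the model's input shape including the batch axis,
        with None for the batch size, e.g. (None, 28, 28).
    """
    # Drop the leading batch axis to get the shape of a single sample.
    per_sample_shape = tuple(model_input_shape[1:])
    # An unbatched element has exactly the per-sample shape.
    return tuple(element_shape) == per_sample_shape
```

An element of shape (28, 28) looks unbatched for a model expecting (None, 28, 28), but the same element is a valid batch of 28 samples for a model expecting (None, 28). The dataset alone cannot settle this; the check also needs the model's input shape.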

@fchollet fchollet self-assigned this Aug 17, 2023
@tilakrayal tilakrayal removed their assignment Aug 25, 2023
@sachinprasadhs sachinprasadhs transferred this issue from keras-team/keras Sep 22, 2023