start to add keras support #2
base: main
Conversation
Hi Marius, thank you for writing this out so quickly! I've looked through your extension and was surprised it was that straightforward to implement. I'd like to start a discussion on the best strategy for adding Keras support.

To start, it seems we could rewrite the entire library to depend only on Keras 3. It would then no longer be necessary to even have PyTorch installed. Crucially, the need to distinguish between backends with the many `if` cases would be gone. For example, Keras 3 has its own call operation, so it would clean up the issue you mentioned with the `Statistic` class. The code would then even work with JAX "without any extra cost". However, going fully engine-agnostic (with Keras 3) would have some drawbacks:
Your current implementation is more of a middle ground: we stay PyTorch-focused, but optionally switch to TensorFlow where useful. In the short term, I like that idea much more! I do wonder if you could write it with `keras.ops` instead, and whether there is a clean way to avoid having keras installed when running the code with PyTorch only. Also: what would the benefit of this approach be over ONNX? Would love to hear your thoughts!
Thank you for your answer. For the short term, I agree: let's continue with what I implemented, and I will try to use `keras.ops` so that installing keras is not required when running with PyTorch only. For the long term:
We could also focus on using Keras for everything. In Keras 3 the backend can be JAX, TensorFlow, or PyTorch, so a user can write a PyTorch model and the framework compiles it to Keras and runs it (see example here).
ONNX could be nice as well: the framework gets a model (TensorFlow, PyTorch, etc.), we convert it to ONNX, and then convert that to the one backend we use for training (e.g. PyTorch?); afterwards we would have to convert it back again. I am not 100% convinced by this idea, but in general, supporting ONNX would be nice to have and should also be easy to implement.
I agree that using `keras.ops` from Keras 3 is the way to go over TensorFlow at the moment.

For the short term, I have been wondering what the best strategy is in terms of software architecture. Given that research in ML is mostly done in PyTorch, it would be nice if 1) all code were readable as pure PyTorch code (without the clutter of the many if-else statements). However, it would also be nice if 2) all code were executable in any backend library (e.g. using Keras 3). I would therefore propose the following: we keep the library in pure PyTorch, but add a separate Keras module alongside it.

Of course, the biggest drawback of this solution is that it involves extensive code duplication (though we could import some backend-agnostic code from the non-keras modules). It's worth keeping in mind that a better solution should be found for the long term; maybe future libraries will more commonly code everything in `keras.ops`, or in a different interface like `onnx.ops`. However, I think the duplication approach above is maintainable for the moment. What do you think? Would this work?
Yes, I think that's the best thing to do. I was also thinking of having a separate module, so let's do it like this. I can start working on it next week.
Great, thank you!!
As discussed this week at EWAF24, I started to add Keras support for the library. For now this PR should just show how it would look and what changes are needed. I think the support should be easy to add, but it would need some `if` cases here and there.

For now, one example works for a sequential Keras model (see `test_keras.py`). It feels a bit like a "hack", since one has to override the `__call__` function of the `torch.nn.Module`; this also leads to some logging errors/warnings in Keras. Maybe some kind of "conditional inheritance" for `class Statistic(abc.ABC, torch.nn.Module):` would be cleaner. Some conditional typing like `value: torch.Tensor | tf.Tensor` would also be needed.

Looking forward to your feedback.
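The conditional typing can be expressed without making either library a runtime dependency, via `typing.TYPE_CHECKING`; a minimal sketch, where the class body is a placeholder for the real `Statistic`:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only evaluated by type checkers (mypy/pyright), never at runtime,
    # so neither torch nor tensorflow has to be installed to import this.
    import tensorflow as tf
    import torch

    TensorLike = Union["torch.Tensor", "tf.Tensor"]


class Statistic:
    """Placeholder for the real Statistic; only the annotation matters here."""

    value: TensorLike  # resolved by the type checker, a plain string at runtime
```

With `from __future__ import annotations`, annotations are stored as strings and never evaluated, so the `TensorLike` alias costs nothing for users who have only one of the two frameworks installed.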