
start to add keras support #2

Open · wants to merge 1 commit into main

Conversation

makoeppel

As discussed this week at EWAF24, I started to add Keras support to the library. For now, this PR should just show how it would look and what changes are needed to add this support. I think it should be easy to add, but it would need some if-cases here and there.

For now, one example works for a sequential Keras model (see test_keras.py). It feels a bit like a "hack", since one has to override the `__call__` function of `torch.nn.Module`. This also leads to some logging errors/warnings in Keras. Maybe some kind of conditional inheritance for the class `Statistic(abc.ABC, torch.nn.Module)` would be cleaner. Some conditional typing like `value: torch.Tensor | tf.Tensor` would also be needed.

Looking forward to your feedback.

@maartenbuyl
Member

Hi Marius,

Thank you for writing this out so quickly!

I've looked through your extension and was surprised it was that straightforward to implement. I'd like to start a discussion on the best strategy to further add support for Keras.

To start, it seems we could rewrite the entire library to depend only on Keras 3. It would then no longer be necessary to even have PyTorch installed. Certainly, all the need to distinguish between backends with many 'if' cases would be gone. E.g., Keras 3 has its own call operation, so it would clean up the issue you mentioned with the Statistic class. The code would then even work with JAX 'without any extra cost'. However, going fully engine-agnostic (with Keras 3) would have some drawbacks:

  1. The clarity of the code would degrade, as Keras and PyTorch have some small but important interface differences that may lead to bugs down the road. Keras also doesn't have engine-agnostic type hinting for tensors, for example (Backend-Agnostic Types keras-team/keras#19230).
  2. It's unclear to me how PyTorch-specific libraries like pytorch-lightning integrate with Keras code. Can Keras layers/models just be thrown into such pipelines without problems?
  3. The maintenance cost would inevitably increase, as the library would implicitly be promising to actually support all of PyTorch, TensorFlow, and JAX. This kind of commitment right now may make it difficult to grow the library organically.

Your current implementation is more of a middle ground: we stay PyTorch-focused but optionally switch to TensorFlow where useful. In the short term, I like that idea much more! I do wonder if you could write it with `keras.ops` instead, and whether there is a clean way to avoid having Keras installed when running the code with PyTorch only. Also: what would the benefit be of this approach vs. ONNX?
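One way the "no Keras required for PyTorch-only use" part could look is a guarded, deferred import. A minimal sketch, assuming `keras` may or may not be installed; `HAS_KERAS` and `as_backend_tensor` are hypothetical names, not fairret API:

```python
import importlib.util

# Check for keras without importing it, so PyTorch-only installs never load it.
HAS_KERAS = importlib.util.find_spec("keras") is not None

def as_backend_tensor(value):
    """Convert via keras.ops when Keras is available, else pass through."""
    if HAS_KERAS:
        from keras import ops  # deferred import: only runs if keras exists
        return ops.convert_to_tensor(value)
    return value  # assumed to already be a torch.Tensor in the PyTorch-only case
```

The `find_spec` check keeps Keras an optional dependency: the import only happens inside the branch that needs it.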

Would love to hear your thoughts!

@makoeppel
Author

Thank you for your answer,

I also think that, in the short term, we should continue with what I implemented, and I will try to go with `keras.ops` so that installing Keras is not needed.

For the long term:

> To start, it seems we could rewrite the entire library to depend only on Keras 3. It would then no longer be necessary to even have PyTorch installed. Certainly, all the need to distinguish between backends with many 'if' cases would be gone. E.g., Keras 3 has its own call operation, so it would clean up the issue you mentioned with the Statistic class. The code would then even work with JAX 'without any extra cost'. However, going fully engine-agnostic (with Keras 3) would have some drawbacks:

We could also focus on using Keras for everything, since in Keras 3 the backend can be JAX, TensorFlow, or PyTorch. The user could then write a PyTorch model, and the framework would compile it to Keras and run it (see example here).

> Also: what would the benefit be of this approach vs. ONNX?

ONNX could be nice as well: the framework gets a model (TensorFlow, PyTorch, etc.), we convert it to ONNX, then convert it to one backend which we use to train it (e.g. PyTorch?), and finally convert it back again. I am not 100% convinced by this idea, but in general, supporting ONNX would be nice to have and should also be easy to implement.

@maartenbuyl
Member

I agree that using `keras.ops` from Keras 3 is the way to go over TensorFlow at the moment.

For the short term, I have been wondering what the best strategy is in terms of software architecture. Given that research in ML is mostly done in PyTorch, it would be nice if 1) all code were readable as pure PyTorch code (without the clutter of many if-else statements). However, it would also be nice if 2) all code were executable with any backend library (e.g. using Keras 3).

I would therefore propose the following: we keep the library in pure PyTorch but add a `fairret.keras` module. We then duplicate all other code currently in `fairret` to `fairret.keras`, but rewrite it to use `keras.ops`. This accomplishes the two goals above and creates a nice UX: if people are using the fairret code and want to use a different backend, they can simply append `.keras` to all occurrences of `fairret` in their import statements.

Of course, the biggest drawback of this solution is that it involves extensive code duplication (though we could import some backend-agnostic code from the non-Keras modules). It's worth keeping in mind that a better solution should be found for the long term; e.g., maybe future libraries will more commonly code everything in `keras.ops`, or in a different interface like `onnx.ops`. However, I think the duplication approach is maintainable for the moment.
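The proposed mirror-module layout might look roughly like this (a sketch with hypothetical file and class names, not the actual fairret structure):

```
fairret/
  statistic.py           # pure PyTorch implementation
  loss.py
  keras/
    statistic.py         # same public classes, rewritten with keras.ops
    loss.py

# User-facing switch is just the import path:
#   from fairret.statistic import PositiveRate        # PyTorch backend
#   from fairret.keras.statistic import PositiveRate  # Keras 3 backend
```

Shared, backend-agnostic helpers (e.g. validation or bookkeeping code) could live in the top-level modules and be imported by the `keras` subpackage to limit the duplication.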

What do you think? Would this work?

@makoeppel
Author

Yes, I think that's the best thing to do. I was also thinking of having a separate module, so let's do it like this. I can start working on it next week.

@maartenbuyl
Member

Great, thank you!!
