
Could you outline how to write the simplest RNN Module? #31

Open
jli05 opened this issue Sep 7, 2020 · 9 comments

jli05 commented Sep 7, 2020

I'm looking to write a basic RNN that computes f(Ax + b) at each time step.

What would be the best way to go about it? Could you outline some code to give an idea?

Can one apply JIT over the entire (unrolled) network for training/inference?

kihyuks (Collaborator) commented Sep 7, 2020

Hi, would you take a look at our code example (https://github.com/google/objax/blob/master/examples/rnn/shakespeare.py) and see if this resolves your concern?
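
In case a smaller starting point helps, here is a minimal sketch of the pattern (not the code from that example; the TinyRNN module and its names are made up for illustration): an Objax module that applies h = tanh(A·[x, h] + b) at each step, with the whole unrolled loop jitted.

```python
# Minimal sketch, assuming objax is installed; TinyRNN and its names are
# illustrative, not part of the Objax API or the shakespeare.py example.
import jax.numpy as jn
import objax


class TinyRNN(objax.Module):
    def __init__(self, nin: int, nstate: int):
        # objax.nn.Linear holds the trainable weight A and bias b.
        self.linear = objax.nn.Linear(nin + nstate, nstate)
        self.nstate = nstate

    def __call__(self, x):
        # x has shape (time, batch, nin); the Python loop is unrolled at trace time.
        h = jn.zeros((x.shape[1], self.nstate))
        outputs = []
        for t in range(x.shape[0]):
            h = jn.tanh(self.linear(jn.concatenate([x[t], h], axis=1)))
            outputs.append(h)
        return jn.stack(outputs)  # (time, batch, nstate)


model = TinyRNN(nin=8, nstate=16)
# JIT over the entire unrolled forward pass.
forward = objax.Jit(lambda x: model(x), model.vars())
y = forward(jn.zeros((20, 4, 8)))  # -> (20, 4, 16)
```

Because the Python loop runs at trace time, objax.Jit compiles the fully unrolled sequence into a single XLA computation, which is one answer to the "JIT over the entire (unrolled) network" question for fixed-length inputs.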

NTT123 commented Sep 10, 2020

@jli05, I would recommend looking at the dm-haiku recurrent module (note that the syntax is a bit different):

https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py

It includes a VanillaRNN module and two important functions, static_unroll and dynamic_unroll, that will help you JIT the whole sequence during training.

You can also see how to unroll the sampling process at https://github.com/deepmind/dm-haiku/blob/master/examples/rnn/train.py#L98

If you want to know more about how to unroll a loop, you can also look at the JAX control-flow ops (e.g., scan and fori_loop) at
https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators
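
For a concrete picture of dynamic unrolling, here is a hedged sketch in plain JAX (not the haiku API; rnn_step and the parameter names are illustrative) that scans an h_t = tanh(x_t A + h_{t-1} U + b) cell over time under jit:

```python
# Sketch only: rnn_step and the parameter layout are assumptions standing in
# for whatever cell you actually use.
import jax
import jax.numpy as jnp


def rnn_step(params, h, x_t):
    A, U, b = params
    h_new = jnp.tanh(x_t @ A + h @ U + b)
    return h_new, h_new  # (new carry, output collected by scan)


@jax.jit
def unroll(params, h0, xs):  # xs: (time, batch, nin)
    step = lambda h, x_t: rnn_step(params, h, x_t)
    h_last, hs = jax.lax.scan(step, h0, xs)
    return h_last, hs  # hs: (time, batch, nstate)


key = jax.random.PRNGKey(0)
nin, nstate, batch, time = 8, 16, 4, 20
params = (jax.random.normal(key, (nin, nstate)),
          jax.random.normal(key, (nstate, nstate)),
          jnp.zeros(nstate))
h_last, hs = unroll(params, jnp.zeros((batch, nstate)),
                    jnp.zeros((time, batch, nin)))
```

scan keeps the compiled program small regardless of sequence length, whereas a Python loop (static unroll) duplicates the step in the compiled graph.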

jli05 closed this as completed Sep 11, 2020
jli05 (Author) commented Sep 11, 2020

Thanks. Yes, the references helped.

The Shakespeare example trains slowly on a MacBook CPU: it takes more than 5 minutes to train 10 epochs, and I killed it before it ran to the end.

I also found that one can pickle-dump a jitted network but cannot pickle-load it. Maybe jitted code was never intended for pickling.

jli05 reopened this Sep 11, 2020
kihyuks (Collaborator) commented Sep 13, 2020

@aterzis-google Any help would be appreciated!

aterzis-google self-assigned this Sep 20, 2020
aterzis-google (Collaborator) commented:

I am working on this and will have an update by Monday (Sep 28).

jli05 (Author) commented Sep 24, 2020

If you could share any lessons/findings about the JAX framework itself from your work, it'd be much appreciated.

JAX is apparently faster than TensorFlow, but it's quite young. I've read some of JAX's internal code and the issues on the JAX repo, and I'm still undecided whether to rewrite a TensorFlow application in JAX.

david-berthelot added the "question" label Sep 24, 2020
aterzis-google (Collaborator) commented:

@jli05 I can say that many projects in Google Research use JAX and the frameworks built on top of it. If you can say a little more about your project, then maybe I can give some more direct guidance.

I also wanted to follow up on your original question about training speed:

* Did you consider using Colab, which has support for GPUs?
* Also, have you seen the section on saving and loading model weights at https://objax.readthedocs.io/en/latest/advanced/io.html#saving-and-loading-model-weights? (A minimal sketch follows below.)
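
Assuming your version of Objax provides the save_var_collection / load_var_collection helpers that page describes (please double-check the exact names and argument order against the docs), the pattern for persisting weights instead of pickling the jitted callable looks roughly like this:

```python
# Sketch only: assumes objax.io.save_var_collection / load_var_collection as
# described on the linked docs page; the Linear model here is just a stand-in.
import objax

model = objax.nn.Linear(8, 16)

# Save the variable values (the VarCollection), not the jitted callable itself.
objax.io.save_var_collection('weights.npz', model.vars())

# Later, or in another process: rebuild the module, then load values into it.
model2 = objax.nn.Linear(8, 16)
objax.io.load_var_collection('weights.npz', model2.vars())
```

The idea is to persist only the variable values and to re-create the module and its jitted functions after loading.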

jli05 (Author) commented Sep 30, 2020

Yes, I read that documentation.

The model we're making is relatively small, but it needs to be distributed and run many times for training and inference. What makes JAX and its derived frameworks attractive is compactness: small package size, relatively quick load time, and no minimum hardware requirement for development. It sticks to its job.

The jax package is about 240k; jaxlib is about 40M. If jaxlib became optional, that would be even better. I didn't give priority to exploring a 200M solution based on other frameworks if the job can be done by a ~200k one.

Currently we're just trying to get things right on CPU. We were concerned about frequent data I/O, so we haven't started benchmarking on more advanced hardware. We're just getting the basics right at this stage, since everything will be parallelised at a bigger scale.

aterzis-google (Collaborator) commented:

@jli05 Also see #97 for the ongoing work to refactor the RNN module.

aterzis-google linked a pull request Dec 17, 2020 that will close this issue