Could you outline how to write the simplest RNN module? #31
Comments
Hi, would you take a look at our code example (https://github.com/google/objax/blob/master/examples/rnn/shakespeare.py) and see if it resolves your concern?
@jli05, I would recommend looking at https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py. It includes implementations of the common recurrent cores (e.g., `VanillaRNN`, `LSTM`, `GRU`). You can also see how to unroll the sampling process at https://github.com/deepmind/dm-haiku/blob/master/examples/rnn/train.py#L98. If you want to know more about how to unroll a loop, you can also look at the JAX flow-control ops (e.g., `jax.lax.scan`, `jax.lax.fori_loop`). A short sketch follows below.
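To make that concrete, here is a minimal sketch of unrolling a Haiku recurrent core with `hk.dynamic_unroll` (the hidden size, shapes, and RNG seed are illustrative assumptions, not values from this thread):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(seq):
    # seq has shape [time, batch, features]; hk.dynamic_unroll is time-major.
    core = hk.VanillaRNN(hidden_size=64)
    initial_state = core.initial_state(batch_size=seq.shape[1])
    # Unrolls the core over the time axis (lax.scan under the hood).
    outputs, final_state = hk.dynamic_unroll(core, seq, initial_state)
    return outputs

model = hk.transform(forward)
seq = jnp.zeros((10, 8, 32))                      # [time, batch, features]
params = model.init(jax.random.PRNGKey(42), seq)  # initialize parameters
outputs = model.apply(params, None, seq)          # [time, batch, hidden]
```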
Thanks, yes, the references helped. The Shakespeare example trains slowly on a MacBook CPU: it takes more than 5 minutes to train 10 epochs, and I killed it before it ran to the end. I also found that one can pickle-dump a jitted network but cannot pickle-load it back; maybe jitted code was never intended for pickling.
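For anyone hitting the same pickling issue: jitted callables wrap compiled XLA executables and generally do not survive a pickle round trip, whereas the parameter pytree is plain arrays and pickles fine. A minimal sketch of that workaround, with a stand-in `params` pytree (the names here are illustrative):

```python
import pickle
import jax
import jax.numpy as jnp

params = {"A": jnp.zeros((16, 24)), "b": jnp.zeros(16)}  # stand-in pytree

# Save: pull arrays back to host memory, then pickle the plain pytree.
with open("params.pkl", "wb") as fh:
    pickle.dump(jax.device_get(params), fh)

# Load: the pytree round-trips; re-create and re-jit your functions separately.
with open("params.pkl", "rb") as fh:
    restored = pickle.load(fh)
```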
@aterzis-google Any help would be appreciated!
I am working on this and will have an update by Monday (Sep 28).
If you could share any lessons or findings about the JAX framework itself from your work, it'd be much appreciated. JAX is reportedly faster than TensorFlow, but it's quite young. I read some of JAX's internal code and the …
@jli05 I can say that many projects in Google Research use JAX and the frameworks built on top of it. If you can say a little more about your project, then maybe I can give some more direct guidance. I also wanted to ask about your original question on training speed: did you consider using Colab, which has support for GPUs?
Yes, I read that documentation. The model we're building is relatively small, but it needs to be distributed and run many times for training and inference. What makes JAX and its derived frameworks attractive is that they are compact, load relatively quickly, and have no minimum hardware requirements for development. They stick to the job.
Currently we're just trying to get things right on CPU. We were concerned about frequent data I/O, so we haven't started benchmarking on more advanced hardware; we're just getting the basics right at this stage, since everything will be parallelised at a bigger scale.
I'm looking to write a basic RNN that computes f(Ax + b) at each time step.
What would be the best way to go about it? Could you outline some code to give an idea?
Can one apply JIT over the entire (unrolled) network for training/inference?
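Here is a minimal pure-JAX sketch of such a cell, reading x as the concatenation of the current input and the previous hidden state, and taking f = tanh (the function names, shapes, and initialization are illustrative assumptions, not from this thread). `jax.jit` is applied over the entire unrolled forward pass:

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim, hidden_dim):
    # A acts on the concatenation [x_t; h_{t-1}]; b is the bias.
    A = jax.random.normal(key, (hidden_dim, in_dim + hidden_dim)) * 0.01
    b = jnp.zeros(hidden_dim)
    return A, b

def cell(params, h, x):
    # One time step: h' = f(A [x; h] + b), with f = tanh here.
    A, b = params
    h_new = jnp.tanh(A @ jnp.concatenate([x, h]) + b)
    return h_new, h_new          # (new carry, per-step output) for lax.scan

@jax.jit                         # JIT compiles the entire unrolled network
def unrolled(params, h0, xs):
    step = lambda h, x: cell(params, h, x)
    final_h, outputs = jax.lax.scan(step, h0, xs)
    return outputs

key = jax.random.PRNGKey(0)
params = init_params(key, in_dim=8, hidden_dim=16)
xs = jnp.ones((100, 8))                    # a sequence of 100 input vectors
ys = unrolled(params, jnp.zeros(16), xs)   # outputs, shape (100, 16)
```

Because `jax.lax.scan` traces the step function once, the compiled program stays small even for long sequences, and the same `unrolled` function can be used inside a loss under `jax.grad` for training.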