
Move RNN to layers.py and make it stateless. #97

Open
wants to merge 5 commits into master

Conversation

aterzis-google
Collaborator

No description provided.

objax/nn/layers.py (outdated)
@@ -327,6 +327,63 @@ def __call__(self, x: JaxArray) -> JaxArray:
        self.avg.value += (self.avg.value - x) * (self.momentum - 1)
        return self.avg.value

class RNN(Module):
Member

I think the name RNN is too generic. Pretty much any type of recurrent block (LSTM, GRU, ...) could be called an RNN. Is there a better name for it?

@david-berthelot (Contributor), Oct 14, 2020

Also, RNN refers to the architecture, not to the cell. Here's what TF/Keras does: https://www.tensorflow.org/api_docs/python/tf/compat/v1/nn/rnn_cell/RNNCell
Not sure what PyTorch does.


This is a specific RNN architecture that operates across time (so not a cell). I would call this something like SimpleRNN, and make sure it replicates Keras' SimpleRNN functionality with default arguments:

https://www.tensorflow.org/api_docs/python/tf/keras/layers/SimpleRNN

You could reserve RNN for an object that takes an RNNCell and performs a scan across time.
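
For reference, a minimal sketch of that second idea: a wrapper that scans a cell across time with jax.lax.scan. The wrapper name and the cell(x_t, state) -> (y_t, new_state) interface are illustrative assumptions, not part of this PR:

import jax
import objax


class CellRNN(objax.Module):  # hypothetical name, to avoid clashing with this PR's class
    def __init__(self, cell):
        self.cell = cell  # any module implementing cell(x_t, state) -> (y_t, new_state)

    def __call__(self, inputs, initial_state):
        # inputs has time as its leading axis; scan carries the state forward.
        def step(state, x_t):
            y_t, new_state = self.cell(x_t, state)
            return new_state, y_t

        final_state, outputs = jax.lax.scan(step, initial_state, inputs)
        return outputs, final_state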

Collaborator Author

Changed the name to SimpleRNN.


        self.output_layer = Linear(self.nstate, self.num_outputs)

    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:

Suggest adding a get_initial_state method and an optional initial_state argument here.

Collaborator Author

I added an optional initial_state argument to the call() method.

Can you clarify what the get_initial_state() method would do, considering that the state is initialized during every call() (unless explicitly passed in through the optional argument)?


There are two reasons to have a get_initial_state. One, the caller wants to know if this layer is recurrent without checking for some general instance type. Two, the caller wants to know the shapes etc. of the state without running __call__. This is useful for many reasons, like creating buffers for storing state.

Contributor

Just to clarify, does get_init_state really act like a create_init_state? Or is there an init_state stored inside the instance?


No; it's a purely functional thing that returns some arrays.

Member

As far as I understood from some of the Keras code, get_initial_state simply returns a zero array of the appropriate shape (ex: https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/layers/recurrent.py#L1948 )

It's still not very clear how useful it is. Could you point us to some example of how it's actually used (either in TensorFlow or any other framework)?

To know the shape of the state, it would be better to just call rnn_cell_layer.nstate, or maybe have a helper method get_state_shape. Using get_initial_state as a way to determine whether a layer is an RNN seems a little weird; I don't see how getattr(layer, 'get_initial_state') is better than isinstance(layer, RNNCell). If there is a need to determine whether a layer is an RNN cell, I think it's better to just make all RNN cells inherit from some base class and do an isinstance check.
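
For concreteness, a sketch of that base-class route; RNNCell, nstate and is_recurrent are hypothetical names here, not existing Objax API:

import jax.numpy as jn
import objax


class RNNCell(objax.Module):
    """Hypothetical common base class for recurrent cells; nstate is the state size."""
    nstate: int

    def get_initial_state(self):
        # Purely functional: returns a fresh zero state, stores nothing in the module.
        return jn.zeros(self.nstate)


def is_recurrent(layer):
    # The isinstance check that replaces getattr(layer, 'get_initial_state', None).
    return isinstance(layer, RNNCell)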

objax/nn/layers.py (outdated)
        only_return_final: return only the last output if ``True``, or all outputs otherwise.

    Returns:
        Output tensor with dimensions ``N * batch_size, vocabulary_size``.

Is vocabulary_size the right terminology for RNNs? Perhaps you mean nout here?

Also, why is batch_size included here? I thought you don't consider batch_size in these layers?

Collaborator Author

Changed vocabulary_size -> nout.

I include batch_size because we can process a batch of input data.


@david-berthelot do other layers "know" about batch dimensions? Does this one need to?

@ebrevdo, Nov 3, 2020

(From David on another PR: no, layers don't know about batch dimensions, so this one shouldn't either. Instead, add a unit test with this object and Vectorize.)
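
A hedged sketch of that test pattern, assuming objax.Vectorize's default of mapping over a leading batch axis; the SimpleRNN constructor arguments are guesses based on the attribute names in the diff:

import jax.numpy as jn
import objax

# Per-example layer: input shaped (time, num_inputs), no batch dimension.
rnn = objax.nn.SimpleRNN(nstate=64, num_inputs=32, num_outputs=32)  # assumed signature

# Vectorize vmaps __call__ over axis 0, adding the batch dimension externally.
batched_rnn = objax.Vectorize(rnn)

x = jn.zeros((8, 10, 32))  # (batch, time, num_inputs), illustrative shapes
outputs, state = batched_rnn(x)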

@ebrevdo left a comment

Needs a unit test.

@jli05 left a comment

What makes the RNN stand out in this lib, for me, is the code's readability and simplicity. Anyone can easily extend it.

objax/nn/layers.py
Comment on lines +375 to +392:
    jn.dot(x, self.w_xh.value)
    + jn.dot(state, self.w_hh.value)

num_inputs could be zero -- essentially empty inputs, while the internal state continues to evolve over time.

Not sure if we should use two weight matrices or one acting on the concatenated [h, x].


Typically it's more efficient to act on one concatenated [h, x], but it depends on the system and sizes. At some point you could make this an __init__ mode parameter like Keras does. For now I'd suggest using the concatenated format.
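
To make the two options concrete, a sketch with illustrative sizes; tanh is assumed as the nonlinearity, and w_hx is a hypothetical name for the fused matrix (w_xh, w_hh, b_h follow the diff):

import jax.numpy as jn

nin, nstate, batch = 32, 64, 8
x = jn.zeros((batch, nin))
state = jn.zeros((batch, nstate))
w_xh = jn.zeros((nin, nstate))
w_hh = jn.zeros((nstate, nstate))
b_h = jn.zeros(nstate)

# Two matrices, as the PR currently does:
state_a = jn.tanh(x.dot(w_xh) + state.dot(w_hh) + b_h)

# One matrix acting on the concatenated [h, x]:
w_hx = jn.concatenate([w_hh, w_xh], axis=0)  # shape (nstate + nin, nstate)
state_b = jn.tanh(jn.concatenate([state, x], axis=1).dot(w_hx) + b_h)
# state_a == state_b: the fused form computes the same update in one matmul.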

Contributor

Another nit: use x.dot(y) rather than jn.dot(x, y), since we might as well take advantage of object-oriented APIs.

            + jn.dot(state, self.w_hh.value)
            + self.b_h.value
        )
        y = self.output_layer(state)

Do we need output_layer, or can we directly return the internal state h and let the user do further transforms on it?

Collaborator Author

I opted for having an output_layer.


Question why: this is something the user can do themselves afterwards, right? So is there any purpose to adding an output_layer?

Contributor

I would drop the output layer; it forces a decision on the user about what type of output they'd want.
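
What dropping it would look like for the caller (sketch; the SimpleRNN signature without num_outputs is hypothetical):

import jax.numpy as jn
import objax

rnn = objax.nn.SimpleRNN(nstate=64, num_inputs=32)  # returns hidden states only
head = objax.nn.Linear(64, 10)                      # user-chosen output transform

x = jn.zeros((10, 32))        # (time, num_inputs), illustrative
states, final_state = rnn(x)  # states: (time, 64)
y = head(states)              # (time, 10)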

google-cla bot commented Oct 28, 2020

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.


if only_return_final:
    return y, state
else:
Contributor

No need for else.

if only_return_final:
    return y, state
else:
    return jn.concatenate(outputs, axis=0), state
Contributor

Should it be jn.stack?
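
The difference in question, with illustrative shapes:

import jax.numpy as jn

T, batch, nout = 5, 8, 16
outputs = [jn.zeros((batch, nout)) for _ in range(T)]  # per-step outputs

jn.concatenate(outputs, axis=0).shape  # (T * batch, nout): time folded into the row axis
jn.stack(outputs, axis=0).shape        # (T, batch, nout): an explicit time axis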


Comment on lines +368 to +369:
def __call__(self, inputs: JaxArray, initial_state: JaxArray = None,
             only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
Contributor

One argument per line if they don't all fit on one line.
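
That is, formatted like:

def __call__(self,
             inputs: JaxArray,
             initial_state: JaxArray = None,
             only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]: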



def loss(x, label):  # sum(label * log(softmax(logit)))
-   logit = model(x)
+   logits, _ = model(x)
    return objax.functional.loss.cross_entropy_logits(logit, label).mean()
Contributor

logits = model(x)[0]

outputs = [vocab[prefix[0]]]
get_input = lambda: one_hot(jn.array([outputs[-1]]).reshape(1, 1), len(vocab))
for y in prefix[1:]:  # Warmup state with prefix
    model(get_input())
    outputs.append(vocab[y])
for _ in range(num_predicts):  # Predict num_predicts steps
-   Y = model(get_input())
+   Y, _ = model(get_input())
Contributor

1. Uppercase is for global constants; use lowercase identifiers for variables, please.
2. Rather than doing two assigns, the better way is to just assign what you use:
   Y = model(get_input())[0]

Comment on lines +25 to +29
<<<<<<< HEAD:examples/text_generation/shakespeare_rnn.py
from objax.nn import SimpleRNN
=======
from objax.nn import RNN
>>>>>>> 2c04d4e (Move RNN to layers.py and make it stateless.):examples/rnn/shakespeare.py
Contributor

Your commit contains an unresolved merge conflict.
