LMU allows the input to have ndim>3 (and then fails later) #32

Open
arvoelke opened this issue Feb 5, 2021 · 0 comments
arvoelke (Contributor) commented Feb 5, 2021

Minimal reproducer:

import tensorflow as tf
import keras_lmu

inp = tf.keras.layers.Input(shape=(1000, 8, 128))

out = keras_lmu.LMU(
    memory_d=128,
    order=6,
    theta=10,
    hidden_cell=tf.keras.layers.Dense(100),
)(inp)

model = tf.keras.Model(inp, out)
model.summary()

Output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_0 (InputLayer)         [(None, 1000, 8, 128)]    0         
_________________________________________________________________
lmu_0 (LMU)                  (None, 100)               93326     
=================================================================
Total params: 93,326
Trainable params: 93,284
Non-trainable params: 42

It looks like the extra dimensions are being collapsed somewhere internally through a reshape, but actually using the model (e.g., calling or fitting it) can then fail with shape errors.
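Whatever the LMU does internally, one unambiguous way to feed data like this to a recurrent layer is to merge the trailing axes into a single feature axis yourself, so the input has the (batch, timesteps, features) layout RNN layers expect. A minimal NumPy sketch of that explicit reshape, assuming the (batch, 1000, 8, 128) layout from the reproducer (the batch size of 2 is arbitrary). In Keras the equivalent would be `tf.keras.layers.Reshape((1000, 8 * 128))(inp)` before the LMU; note this is a user-side workaround, not a description of the LMU's internals.

```python
import numpy as np

# Shape from the reproducer; batch size chosen arbitrarily for illustration.
batch, timesteps, rows, features = 2, 1000, 8, 128
x = np.zeros((batch, timesteps, rows, features), dtype=np.float32)

# Merge every axis after the time axis into one feature axis,
# giving a (batch, timesteps, features) array.
x_3d = x.reshape(batch, timesteps, -1)
print(x_3d.shape)  # (2, 1000, 1024)
```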

If you try the same thing with other RNNs, such as out = tf.keras.layers.SimpleRNN(128)(inp), you get the error message:

ValueError: Input 0 of layer simple_rnn_0 is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: [200, 1000, 8, 128]

I think we should perform the same kind of validation to explicitly disallow ndim>3 (unless, of course, there's a way to make this work correctly)?
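The SimpleRNN message above has the format of Keras' `InputSpec` compatibility check, so one option might be for the LMU layer to set `self.input_spec = tf.keras.layers.InputSpec(ndim=3)` and let Keras reject 4D inputs at call time. The check itself amounts to the following pure-Python sketch; `check_rnn_input_ndim` is a hypothetical helper, not part of keras_lmu, and it only mimics the wording of the Keras error.

```python
def check_rnn_input_ndim(shape, layer_name="lmu"):
    """Hypothetical check: raise unless `shape` is (batch, timesteps, features)."""
    ndim = len(shape)
    if ndim != 3:
        raise ValueError(
            f"Input 0 of layer {layer_name} is incompatible with the layer: "
            f"expected ndim=3, found ndim={ndim}. "
            f"Full shape received: {list(shape)}"
        )


# A 3D input passes silently.
check_rnn_input_ndim((None, 1000, 128))

# The 4D input from the reproducer is rejected up front instead of failing later.
try:
    check_rnn_input_ndim((None, 1000, 8, 128))
except ValueError as err:
    print(err)
```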

arvoelke changed the title from "LMU allows the input to have ndim>3" to "LMU allows the input to have ndim>3 (and then fails later)" on Apr 11, 2021.