require 'rnn'

local net = nn.Sequencer(
   nn.Sequential()
      :add(nn.LSTM(100,100))
      :add(nn.Linear(100,100))
      :add(nn.LSTM(100,100))
)

local inputs = {}
for i = 1, 10 do
   table.insert(inputs, torch.randn(100))
end
net:forward(inputs) -- This should create 10 clones of my network
net:clearState()
The last line clears all 10 clones of nn.Sequential. Each of these clones in turn clears all of the nn.LSTM clones. Since the nn.LSTM module manages its clones internally, the same 10 LSTM clones end up being cleared 10 times.
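Conceptually, the traversal looks like this (pseudo-structure, not actual rnn source; names like sequentialClones, lstm1 and lstm2 are made up for illustration):

for _, seqClone in ipairs(sequentialClones) do        -- 10 nn.Sequential clones
   for _, lstm in ipairs({lstm1, lstm2}) do           -- the same 2 LSTMs every time
      for _, lstmClone in pairs(lstm.sharedClones) do -- 10 internal LSTM clones
         lstmClone:clearState()                       -- 10 x 2 x 10 = 200 calls
      end
   end
end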
By adding a few prints in AbstractRecurrent:clearState() we get something like this:
clearState nn.Recursor @ nn.Sequential {
[input -> (1) -> (2) -> (3) -> output]
(1): nn.LSTM(100 -> 100)
(2): nn.Linear(100 -> 100)
(3): nn.LSTM(100 -> 100)
}
clearState nn.LSTM(100 -> 100)
cleared clone 1 in 0.00020408630371094
cleared clone 2 in 0.00022411346435547
cleared clone 3 in 0.0002291202545166
cleared clone 4 in 0.00021004676818848
cleared clone 5 in 0.00021100044250488
cleared clone 6 in 0.00018811225891113
cleared clone 7 in 0.00021791458129883
cleared clone 8 in 0.0002140998840332
cleared clone 9 in 0.00021195411682129
cleared clone 10 in 0.00020408630371094
clearState nn.LSTM(100 -> 100)
cleared clone 1 in 0.0001978874206543
cleared clone 2 in 0.00060486793518066
cleared clone 3 in 0.00049901008605957
cleared clone 4 in 0.0002589225769043
cleared clone 5 in 0.00022697448730469
cleared clone 6 in 0.00019097328186035
cleared clone 7 in 0.00020694732666016
cleared clone 8 in 0.00022196769714355
cleared clone 9 in 0.00023078918457031
cleared clone 10 in 0.00024318695068359
cleared clone 1 in 0.0056848526000977 <-- The first nn.Sequential clone
...
clearState nn.LSTM(100 -> 100)
cleared clone 1 in 0.00015807151794434
cleared clone 2 in 0.00019478797912598
cleared clone 3 in 0.00017786026000977
cleared clone 4 in 0.00020194053649902
cleared clone 5 in 0.00017094612121582
cleared clone 6 in 0.00017809867858887
cleared clone 7 in 0.00016403198242188
cleared clone 8 in 0.00015807151794434
cleared clone 9 in 0.00016117095947266
cleared clone 10 in 0.00016188621520996
clearState nn.LSTM(100 -> 100)
cleared clone 1 in 0.00016117095947266
cleared clone 2 in 0.00016403198242188
cleared clone 3 in 0.00015997886657715
cleared clone 4 in 0.00016307830810547
cleared clone 5 in 0.00016498565673828
cleared clone 6 in 0.00015807151794434
cleared clone 7 in 0.00015902519226074
cleared clone 8 in 0.00016093254089355
cleared clone 9 in 0.00016188621520996
cleared clone 10 in 0.00016617774963379
cleared clone 10 in 0.0038068294525146 <-- The last nn.Sequential clone
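For reference, the prints were added roughly like this (a sketch of the instrumented method, not the real source; the actual AbstractRecurrent:clearState() routes through includingSharedClones(), and sharedClones is the package's internal clone table):

require 'sys' -- for sys.clock()

function nn.AbstractRecurrent:clearState()
   print('clearState ' .. tostring(self))
   -- time the clearing of each internal clone
   for i, clone in pairs(self.sharedClones) do
      local t0 = sys.clock()
      clone:clearState()
      print('cleared clone ' .. i .. ' in ' .. (sys.clock() - t0))
   end
   return nn.Module.clearState(self) -- finally clear this module's own buffers
end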
This might not seem like a big deal, but it means #clones x #clones x 2 (there are 2 LSTMs in this example) calls to clearState. When dealing with longer sequences, like documents, this can take a very long time to finish. I sometimes have sequences of 10k inputs (I'm experimenting with stuff...), which means 10k x 10k calls, each taking ~0.0002 seconds: roughly 5.5 hours just to run clearState() before saving the model to disk.
Because nn.LSTM manages its clones internally and is contained inside the nn.Sequential, the same clones are cleared again and again, as explained at the beginning. Is there a way I could clear those LSTMs only once, effectively reducing the number of calls from O(n^2) to O(n)?
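For what it's worth, a monkey-patch along these lines seems to work around it, assuming the nn.Sequential clones all reference the same nn.LSTM instances (which the repeated log blocks above suggest). This is a hypothetical sketch, not part of the rnn API:

local origClearState = nn.AbstractRecurrent.clearState
local epoch = 0 -- bumped once per top-level clearState pass

function nn.AbstractRecurrent:clearState()
   if self._clearStateEpoch == epoch then
      return self -- this instance (and its internal clones) already cleared
   end
   self._clearStateEpoch = epoch
   return origClearState(self)
end

-- usage: bump the epoch, then clear as usual
epoch = epoch + 1
net:clearState() -- each nn.LSTM now clears its 10 clones only once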
I just found out that this actually seems to happen for a few other methods that iterate through clones with AbstractRecurrent:includingSharedClones(f). To reproduce this issue, all you need to do is wrap a nn.Container that contains any nn.AbstractRecurrent module in a nn.Recursor, as in the sketch below.
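A minimal reproduction along those lines:

require 'rnn'

-- a plain nn.Container (nn.Sequential) holding an nn.AbstractRecurrent
-- module (nn.LSTM), wrapped in nn.Recursor
local net = nn.Recursor(
   nn.Sequential()
      :add(nn.LSTM(10,10))
)

for step = 1, 5 do
   net:forward(torch.randn(10)) -- each step creates a shared clone
end
net:clearState() -- every nn.Sequential clone re-clears the same nn.LSTM clones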
Actually, why is clearState not removing all clones? After all, clones are mostly used to keep those extra output and gradInput buffers. Calling clearState IMO means: "Remove all buffers from memory. I will reallocate afterwards if necessary."
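Under that interpretation, clearState on an AbstractRecurrent could simply drop its clone table instead of walking it. A rough sketch of those semantics (not the actual rnn implementation; resetting sharedClones like this is my assumption):

function nn.AbstractRecurrent:clearState()
   self.sharedClones = {} -- forget all step clones; they are rebuilt lazily on the next forward()
   return nn.Container.clearState(self) -- clear the prototype's own buffers
end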