-
Notifications
You must be signed in to change notification settings - Fork 313
/
init.lua
66 lines (54 loc) · 1.86 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
require 'dpnn'
require 'torchx'

-- rnn requires a dpnn release whose version is greater than 1.
-- `dpnn.version` may be absent (presumably in older releases), so a missing
-- field is treated as version 0. The value is read into a local so that this
-- check does not write into dpnn's own table.
local dpnn_version = dpnn.version or 0
assert(dpnn_version > 1, "Please update dpnn : luarocks install dpnn")

-- create global rnn table:
-- (deliberately global; presumably extended by the files loaded afterwards
-- via torch.include)
rnn = {}
-- 2.1 added [get,set][Grad]HiddenState(step) (previous version: 2)
rnn.version = 2.1

-- Lua 5.1 provides a global `unpack`; Lua 5.2+ moved it to `table.unpack`.
-- Expose a global `unpack` on every version, presumably so that code loaded
-- later can keep calling `unpack` unqualified.
unpack = unpack or table.unpack
-- Load the rnn source files, one torch.include('rnn', <file>) call per entry.
-- The order below is significant and must be preserved: later files
-- presumably depend on definitions from earlier ones (e.g. the Abstract*
-- files are listed before the modules that follow them).
do
   local sources = {
      'recursiveUtils.lua',
      -- extensions to nn.Module
      'Module.lua',
      -- override nn.Dropout
      'Dropout.lua',
      -- for testing:
      'test/test.lua',
      'test/bigtest.lua',
      -- support modules
      'ZeroGrad.lua',
      'LinearNoBias.lua',
      'SAdd.lua',
      'CopyGrad.lua',
      -- recurrent modules
      'LookupTableMaskZero.lua',
      'MaskZero.lua',
      'TrimZero.lua',
      'AbstractRecurrent.lua',
      'Recurrent.lua',
      'LSTM.lua',
      'FastLSTM.lua',
      'GRU.lua',
      'Mufuru.lua',
      'Recursor.lua',
      'Recurrence.lua',
      'NormStabilizer.lua',
      -- sequencer modules
      'AbstractSequencer.lua',
      'Repeater.lua',
      'Sequencer.lua',
      'BiSequencer.lua',
      'BiSequencerLM.lua',
      'RecurrentAttention.lua',
      -- sequencer + recurrent modules
      'SeqLSTM.lua',
      'SeqLSTMP.lua',
      'SeqGRU.lua',
      'SeqReverseSequence.lua',
      'SeqBRNN.lua',
      -- recurrent criterions:
      'SequencerCriterion.lua',
      'RepeaterCriterion.lua',
      'MaskZeroCriterion.lua',
   }

   for i = 1, #sources do
      torch.include('rnn', sources[i])
   end
end
-- prevent likely name conflicts: the package table is also reachable as
-- nn.rnn, in addition to the global `rnn`.
nn.rnn = rnn

-- Return the package table so that `local rnn = require 'rnn'` receives it.
-- Without an explicit return value, require stores and returns `true`.
return rnn