SeqBRNN.lua
------------------------------------------------------------------------
--[[ SeqBRNN ]] --
-- Bi-directional RNN using two SeqLSTM modules.
-- Input is a tensor, e.g. time x batch x inputdim.
-- Output is a tensor of the same length, e.g. time x batch x outputdim.
-- Applies a forward RNN to the input tensor in forward order
-- and a backward RNN in reverse order.
-- Reversal of the sequence happens on the time dimension.
-- At each step, the outputs of both RNNs are merged using the
-- merge module (defaults to nn.CAddTable(), which sums the activations).
------------------------------------------------------------------------
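-- A minimal usage sketch (assumes the rnn package is installed and loaded
-- via require 'rnn'; the shapes below are illustrative):
--   local brnn = nn.SeqBRNN(10, 20)
--   local input = torch.randn(5, 3, 10) -- time x batch x inputdim
--   local output = brnn:forward(input)  -- 5 x 3 x 20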
local SeqBRNN, parent = torch.class('nn.SeqBRNN', 'nn.Container')

function SeqBRNN:__init(inputDim, hiddenDim, batchFirst, merge)
   self.forwardModule = nn.SeqLSTM(inputDim, hiddenDim)
   self.backwardModule = nn.SeqLSTM(inputDim, hiddenDim)
   self.merge = merge
   if not self.merge then
      self.merge = nn.CAddTable()
   end
   self.dim = 1 -- the time dimension, along which the sequence is reversed

   local backward = nn.Sequential()
   backward:add(nn.SeqReverseSequence(self.dim)) -- reverse
   backward:add(self.backwardModule)
   backward:add(nn.SeqReverseSequence(self.dim)) -- unreverse

   local concat = nn.ConcatTable()
   concat:add(self.forwardModule):add(backward)

   local brnn = nn.Sequential()
   brnn:add(concat)
   brnn:add(self.merge)
   if batchFirst then
      -- Insert transposes before and after the brnn, so that the caller
      -- can provide batch x time x inputdim input.
      brnn:insert(nn.Transpose({1, 2}), 1)
      brnn:insert(nn.Transpose({1, 2}))
   end

   parent.__init(self)

   self.output = torch.Tensor()
   self.gradInput = torch.Tensor()

   self.module = brnn
   -- so that it can be handled like a Container
   self.modules[1] = brnn
end
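-- Example: concatenating the two directions instead of summing them (a
-- sketch; nn.JoinTable(3) joins along the feature dimension, so the
-- output size becomes 2 x hiddenDim):
--   local brnn = nn.SeqBRNN(10, 20, false, nn.JoinTable(3))
--   local output = brnn:forward(torch.randn(5, 3, 10)) -- 5 x 3 x 40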
function SeqBRNN:updateOutput(input)
   self.output = self.module:updateOutput(input)
   return self.output
end

function SeqBRNN:updateGradInput(input, gradOutput)
   self.gradInput = self.module:updateGradInput(input, gradOutput)
   return self.gradInput
end

function SeqBRNN:accGradParameters(input, gradOutput, scale)
   self.module:accGradParameters(input, gradOutput, scale)
end

function SeqBRNN:accUpdateGradParameters(input, gradOutput, lr)
   self.module:accUpdateGradParameters(input, gradOutput, lr)
end

function SeqBRNN:sharedAccUpdateGradParameters(input, gradOutput, lr)
   self.module:sharedAccUpdateGradParameters(input, gradOutput, lr)
end

function SeqBRNN:__tostring__()
   if self.module.__tostring__ then
      return torch.type(self) .. ' @ ' .. self.module:__tostring__()
   else
      return torch.type(self) .. ' @ ' .. torch.type(self.module)
   end
end
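-- Example: batch-first input (a sketch; with batchFirst = true the
-- constructor wraps the network in nn.Transpose({1, 2}) on the way in
-- and out, so the caller provides and receives batch x time x dim
-- tensors):
--   local brnn = nn.SeqBRNN(10, 20, true)
--   local output = brnn:forward(torch.randn(3, 5, 10)) -- 3 x 5 x 20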