forked from Element-Research/dpnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ZipTableOneToMany.lua
37 lines (32 loc) · 1.1 KB
/
ZipTableOneToMany.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
--- nn.ZipTableOneToMany: pairs one shared element with every entry of a table.
-- Adapted from ZipTable in dpnn.
--   input  : { v, {a, b, c} }
--   output : { {v,a}, {v,b}, {v,c} }
-- NOTE(review): the parent is nn.Container although no child modules are
-- ever added; kept unchanged since the parent is part of the class identity.
local ZipTableOneToMany, parent = torch.class('nn.ZipTableOneToMany', 'nn.Container')

--- Constructor.
-- Replaces the tensor output/gradInput set up by the parent with tables, and
-- allocates the tensor buffer that updateGradInput resizes and fills.
function ZipTableOneToMany:__init()
   parent.__init(self)
   -- persistent buffer for the gradient w.r.t. the shared element
   self.gradInputEl = torch.Tensor()
   self.gradInput = {}
   self.output = {}
end
--- Forward pass: zip the shared element with each entry of the table.
-- @param input table `{v, {a, b, c, ...}}`
-- @return a new table `{{v, a}, {v, b}, ...}`, also stored in self.output;
--   the pairs hold references to the input values, nothing is copied
function ZipTableOneToMany:updateOutput(input)
   assert(#input == 2, "input must be table of element and table")
   local shared, others = input[1], input[2]
   local zipped = {}
   self.output = zipped
   for idx, item in ipairs(others) do
      zipped[idx] = {shared, item}
   end
   return zipped
end
--- Backward pass.
-- The shared element `v` appears in every output pair, so its gradient is the
-- sum of the first components of all gradOutput pairs. The gradient of each
-- table entry is the second component of the matching pair.
-- @param input      table `{v, {a, b, c, ...}}` (as given to updateOutput)
-- @param gradOutput table `{{dv1, da}, {dv2, db}, ...}`, one pair per entry
--   of input[2]
-- @return `{dv1 + dv2 + ..., {da, db, ...}}`, also stored in self.gradInput.
--   The first entry is the reusable buffer self.gradInputEl; the inner table
--   holds references to gradOutput's values (no copies).
function ZipTableOneToMany:updateGradInput(input, gradOutput)
   assert(#input == 2, "input must be table of element and table")
   local inputEl, inputTable = input[1], input[2]
   local nPairs = #inputTable
   -- A gradOutput of a different length would silently drop (or invent)
   -- gradient terms and hand upstream a gradient table whose length no
   -- longer matches input[2]; fail loudly instead.
   assert(#gradOutput == nPairs,
      "gradOutput must contain one pair per element of the input table")
   self.gradInputEl:resizeAs(inputEl):zero()
   local gradInputTable = {}
   for i = 1, nPairs do
      local gradPair = gradOutput[i]
      -- accumulate: v contributed to every output pair
      self.gradInputEl:add(gradPair[1])
      gradInputTable[i] = gradPair[2]
   end
   self.gradInput = {self.gradInputEl, gradInputTable}
   return self.gradInput
end