forked from Element-Research/dpnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Module.lua
625 lines (552 loc) · 19.1 KB
/
Module.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
local _ = require 'moses'
local Module = nn.Module
-- Returns the parameter tables consumed by updateParameters and
-- zeroGradParameters. This default simply forwards every value returned
-- by self:parameters(). Callers in this file read up to three results
-- (params, gradParams, scales).
function Module:sparseParameters()
   -- tail call so that all return values of parameters() are passed through
   return self:parameters()
end
-- Applies one gradient step to every parameter returned by
-- sparseParameters(): param = param - learningRate * scale * gradParam.
-- learningRate : base learning rate.
-- The optional third result of sparseParameters() holds per-parameter
-- learning-rate scales; a missing scale counts as 1.
function Module:updateParameters(learningRate)
   local weights, grads, lrScales = self:sparseParameters()
   if not weights then
      return
   end
   -- pairs (not ipairs) because sparse parameter tables may have holes
   for idx, weight in pairs(weights) do
      local multiplier = lrScales and lrScales[idx]
      if not multiplier then
         multiplier = 1
      end
      weight:add(-learningRate * multiplier, grads[idx])
   end
end
-- Zeroes every gradient tensor returned (as the second result) by
-- sparseParameters(). Does nothing when there are no gradients.
function Module:zeroGradParameters()
   local grads = select(2, self:sparseParameters())
   if not grads then
      return
   end
   -- pairs (not ipairs) because sparse parameter tables may have holes
   for idx, grad in pairs(grads) do
      grad:zero()
   end
end
------------------------ clone and type --------------------------------
-- Attribute names that sharedClone removes from each module and shares
-- with the clone when its shareParams argument is enabled.
Module.dpnn_parameters = {'weight', 'bias'}
-- Attribute names that sharedClone shares when its shareGradParams
-- argument is enabled. Module.dpnn_lightEmpty references this same table.
Module.dpnn_gradParameters = {'gradWeight', 'gradBias'}
-- Clones this module while sharing selected tensors with the original
-- instead of copying them.
-- efficient version of :
-- clone = self:clone()
-- clone:share(self, paramNames, gradParamNames)
-- Note that this method is the very bane of my existence.
-- I have worked on it too many times...
--
-- shareParams     : share the attributes listed in dpnn_parameters
--                   (defaults to true when nil).
-- shareGradParams : share the attributes listed in dpnn_gradParameters
--                   (defaults to true when nil).
-- stepClone       : when truthy, modules having a truthy dpnn_stepclone
--                   attribute are reused as-is instead of being cloned
--                   (if self is such a module, self is returned).
-- Returns the clone.
--
-- The work happens in three steps: (1) detach shared tensors and nested
-- modules from self, remembering them in a parallel "moduleTree";
-- (2) call self:clone() on what is left; (3) re-attach everything that
-- was detached to both self and the clone.
function Module:sharedClone(shareParams, shareGradParams, stepClone)
-- nil means "use the default" (true); an explicit false is preserved
shareParams = (shareParams == nil) and true or shareParams
shareGradParams = (shareGradParams == nil) and true or shareGradParams
if stepClone and self.dpnn_stepclone then
-- this is for AbstractRecurrent modules (in rnn)
return self
end
-- set keyed by the storage data pointer of every detached tensor
local pointers = {} -- to params/gradParams (dont clone params/gradParams)
-- maps torch.pointer(table) -> its moduleTree; prevents visiting a table
-- twice in step 1 and marks tables still to be restored in step 3
local scdone = {}
-- 1. remove all params/gradParams
-- Returns two values: the moduleTree entry describing what was detached
-- from obj (nil for non-table values), and the value that should stay in
-- the parent's slot (nil when obj itself has to be detached).
local function recursiveRemove(obj) -- remove modules
local moduleTree
local isTable = type(obj) == 'table'
if torch.isTypeOf(obj, 'nn.Module') then
assert(isTable)
if stepClone and obj.dpnn_stepclone then
-- this is for AbstractRecurrent modules (in rnn)
-- the whole module is detached and later re-attached unchanged
moduleTree = obj
obj = nil
isTable = false
elseif obj.dpnn_sharedClone then
-- allow to use a custom sharedClone method on one module
moduleTree = obj
obj = nil
isTable = false
elseif scdone[torch.pointer(obj)] then
-- module already processed: reuse its moduleTree
moduleTree = scdone[torch.pointer(obj)]
else
-- remove the params, gradParams. Save for later.
local params = {}
if shareParams then
for i,paramName in ipairs(obj.dpnn_parameters) do
local param = obj[paramName]
if param then
params[paramName] = param
obj[paramName] = nil
-- remember the storage so aliases of it can be found below
if torch.isTensor(param) and param.storage and param:storage() then
pointers[torch.pointer(param:storage():data())] = true
end
end
end
end
if shareGradParams then
for i,paramName in ipairs(obj.dpnn_gradParameters) do
local gradParam = obj[paramName]
if gradParam then
params[paramName] = gradParam
obj[paramName] = nil
if torch.isTensor(gradParam) and gradParam.storage and gradParam:storage() then
pointers[torch.pointer(gradParam:storage():data())] = true
end
end
end
end
-- find all obj.attribute tensors that share storage with the shared params
for paramName, param in pairs(obj) do
if torch.isTensor(param) and param:storage() then
if pointers[torch.pointer(param:storage():data())] then
params[paramName] = param
obj[paramName] = nil
end
end
end
moduleTree = params
-- registered before recursing into the module's attributes
scdone[torch.pointer(obj)] = moduleTree
-- recurse into the remaining attributes; detached values end up in
-- moduleTree[k] while obj[k] keeps whatever may be cloned
for k,v in pairs(obj) do
moduleTree[k], obj[k] = recursiveRemove(v)
end
end
elseif isTable then
-- plain (non-module) table: recurse into its entries
if scdone[torch.pointer(obj)] then
moduleTree = scdone[torch.pointer(obj)]
else
assert(not moduleTree)
moduleTree = {}
for k,v in pairs(obj) do
moduleTree[k], obj[k] = recursiveRemove(v)
end
-- registered only after its entries have been processed
scdone[torch.pointer(obj)] = moduleTree
end
end
return moduleTree, obj
end
local moduleTree, original = recursiveRemove(self)
assert(original)
-- 2. clone everything but parameters, gradients and modules (removed above)
local clone = self:clone()
-- 3. add back to self/clone everything that was removed in step 1
-- Walks clone, original and moduleTree in parallel. Only tables still
-- registered in scdone are processed; the entry is cleared afterwards.
local function recursiveSet(clone, original, moduleTree)
assert(clone)
assert(original)
if scdone[torch.pointer(original)] then
for k,param in pairs(moduleTree) do
if torch.isTypeOf(param,'nn.Module') then
if param.dpnn_sharedClone then
-- Call the custom sharedClone
clone[k] = param:dpnn_sharedClone()
else
-- AbstractRecurrent instances branch here with stepClone = true
clone[k] = param
end
original[k] = param
elseif torch.isTensor(param) then
if param.storage then
-- the clone gets a new tensor object that is set() to the original
-- tensor; the original gets its own tensor back
clone[k] = param.new():set(param)
original[k] = param
else -- for torch.MultiCudaTensor
clone[k] = param
original[k] = param
end
elseif type(param) == 'table' then
-- nested moduleTree: descend into the matching sub-tables
recursiveSet(clone[k], original[k], param)
end
end
scdone[torch.pointer(original)] = nil
end
end
recursiveSet(clone, self, moduleTree)
return clone
end
-- Overrides getParameters so that hidden modules are included in the call.
-- Hidden modules are common for recurrent modules that keep internal
-- references to modules sharing parameters with the main modules.
-- They must take part in the same getParameters() call in order to
-- maintain shared storage for tensors.
-- Returns the flattened parameters and the flattened gradParameters.
function Module:getParameters()
   local collector = nn.Container()
   collector:add(self)

   -- Depth-first walk over every attribute (modules, sharedclones, etc.).
   -- Modules are flagged with dpnn_getParameters_found once collected so
   -- that each of them is added and visited a single time.
   local visit
   visit = function(owner)
      for key, item in pairs(owner) do
         local isModule = torch.isTypeOf(item, 'nn.Module')
         if isModule and not item.dpnn_getParameters_found then
            collector:add(item)
            item.dpnn_getParameters_found = true
            visit(item)
         elseif (not isModule) and torch.type(item) == 'table' then
            visit(item)
         end
      end
   end
   visit(self)

   -- remove the temporary visit flags again
   for idx, collected in ipairs(collector.modules) do
      collected.dpnn_getParameters_found = nil
   end

   -- get ALL parameters
   local allParams, allGradParams = collector:parameters()
   local flatParams = Module.flatten(allParams)
   return flatParams, Module.flatten(allGradParams)
end
----------------- serialization (see nn.Serial) -------------------
-- Attribute names emptied by mediumSerial (lightSerial empties them too).
Module.dpnn_mediumEmpty = {'output', 'gradInput', 'momGradParams', 'dpnn_input'}
-- Extra attribute names emptied by lightSerial. This is the very same
-- table object as Module.dpnn_gradParameters, not a copy.
Module.dpnn_lightEmpty = Module.dpnn_gradParameters
-- defaults to heavy serialization
-- (an empty list : getSerialState empties no attribute)
Module.dpnn_serialEmpty = {}
-- Sets the serialization behavior of the entire module structure.
-- empty : table listing the attribute names that getSerialState should
--         empty; it is stored in dpnn_serialEmpty of this module and of
--         every module reachable through its attributes.
-- Returns self.
function Module:serialMode(empty)
   assert(torch.type(empty) == 'table', "Expecting table at arg 1")
   self.dpnn_serialEmpty = empty

   -- propagate the mode to all encapsulated modules, looking through
   -- plain tables as well
   local propagate
   propagate = function(holder)
      for key, member in pairs(holder) do
         if torch.isTypeOf(member, 'nn.Module') then
            member:serialMode(empty)
         elseif torch.type(member) == 'table' then
            propagate(member)
         end
      end
   end
   propagate(self)

   return self
end
-- serialMode : serialize everything.
-- Implemented by installing an empty list of attributes to empty.
function Module:heavySerial()
   local nothingEmptied = {}
   return self:serialMode(nothingEmptied)
end
-- serialMode : serialize everything except dpnn_mediumEmpty attributes.
-- Applied to this module and to every module reachable through its
-- attributes. Returns self.
function Module:mediumSerial()
   self.dpnn_serialEmpty = self.dpnn_mediumEmpty

   -- propagate the mode to all encapsulated modules, looking through
   -- plain tables as well
   local propagate
   propagate = function(holder)
      for key, member in pairs(holder) do
         if torch.isTypeOf(member, 'nn.Module') then
            member:mediumSerial()
         elseif torch.type(member) == 'table' then
            propagate(member)
         end
      end
   end
   propagate(self)

   return self
end
-- serialMode : serialize everything except dpnn_mediumEmpty and
-- dpnn_lightEmpty attributes.
-- Applied to this module and to every module reachable through its
-- attributes. Returns self.
function Module:lightSerial()
   -- start from a copy of the medium list so that the class-level table
   -- is not modified, then append the light list
   local emptied = _.clone(self.dpnn_mediumEmpty)
   self.dpnn_serialEmpty = emptied
   for idx, attrName in ipairs(self.dpnn_lightEmpty) do
      emptied[#emptied + 1] = attrName
   end

   -- propagate the mode to all encapsulated modules, looking through
   -- plain tables as well
   local propagate
   propagate = function(holder)
      for key, member in pairs(holder) do
         if torch.isTypeOf(member, 'nn.Module') then
            member:lightSerial()
         elseif torch.type(member) == 'table' then
            propagate(member)
         end
      end
   end
   propagate(self)

   return self
end
-- Returns the object structure of this module as plain tables
-- (i.e. without metatables), with the attributes listed in
-- dpnn_serialEmpty replaced by empty placeholders.
-- states : optional cache mapping already converted modules/tables to
--          their state, so that nothing is converted twice.
-- The returned state has a dpnn_typename field holding torch.type(self)
-- so that the module can be reconstructed from it.
function Module:getSerialState(states)
   states = states or {}

   -- dont get the serial state of the same module twice (reuse existing)
   local cached = states[self]
   if cached then
      return cached
   end

   -- placeholder for an attribute that must be "emptied"
   local function emptied(value)
      if torch.type(value) == 'table' then
         return {} -- empty table
      end
      if torch.isTensor(value) then
         return value.new() -- empty tensor
      end
      -- not table nor tensor? then serialize as is
      return value
   end

   local snapshot
   snapshot = function(source)
      local converted = _.map(source,
         function(key, value)
            if torch.isTypeOf(source, 'nn.Module') and _.contains(source.dpnn_serialEmpty, key) then
               return emptied(value)
            end
            if torch.isTypeOf(value, 'nn.Module') then
               -- recursive, yet can be overwritten
               return value:getSerialState(states)
            end
            if torch.type(value) ~= 'table' then
               return value
            end
            -- in case it is a table of modules
            local known = states[value]
            if not known then
               known = snapshot(value)
               states[value] = known
            end
            return known
         end
      )
      return converted
   end

   local state = snapshot(self)
   -- include typename so that module can be reconstructed from the state
   state.dpnn_typename = torch.type(self)
   states[self] = state
   return state
end
-- Decorates self with nn.Serial.
-- tensortype : forwarded unchanged as the second argument of nn.Serial.
-- Returns whatever nn.Serial(self, tensortype) returns.
function Module:Serial(tensortype)
   return nn.Serial(self, tensortype)
end
----------------------- for training -----------------------------
-- Useful to get the output size : forwards a random input of size insize
-- through updateOutput and returns the size of the result.
-- insize : either a table of dimensions (unpacked into torch.randn) or
--          any single value accepted directly by torch.randn.
-- The method name was chosen because it is less likely to be overriden.
function Module:outside(insize)
   -- torch.randn returns a tensor (always truthy), so and/or is safe here
   local probe = torch.type(insize) == 'table'
      and torch.randn(table.unpack(insize))
      or torch.randn(insize)
   local result = self:updateOutput(probe)
   return result:size()
end
-- For those interested in implementing the visitor design pattern.
-- visitor : object with a visit method; it is called with this module.
function Module:accept(visitor)
   visitor:visit(self)
end
-- Can be used as a regularizer instead of weight decay.
-- Assumes that parameters are arranged (output dim x ... x input dim).
-- maxOutNorm : max norm applied along the first dimension of each param.
-- maxInNorm  : max norm applied along the last dimension of each param.
-- A module's own maxOutNorm / maxInNorm attributes take precedence over
-- the arguments. Containers (modules with a `modules` table) delegate to
-- their children; nothing is returned.
function Module:maxParamNorm(maxOutNorm, maxInNorm)
   -- this allows each module to set its own max[Out,In]Norm
   maxOutNorm = self.maxOutNorm or maxOutNorm
   maxInNorm = self.maxInNorm or maxInNorm
   if not (maxOutNorm or maxInNorm) then
      return
   end

   if self.modules then
      for i, child in ipairs(self.modules) do
         child:maxParamNorm(maxOutNorm, maxInNorm)
      end
      return
   end

   -- Only the parameters are needed here. The previous guard also read an
   -- undeclared global named gradParams, so an unrelated global of that
   -- name would have silently disabled the constraint.
   local params = self:parameters()
   if not params then
      return
   end

   for k, param in pairs(params) do -- pairs for sparse params
      -- By default, only affects non-1D params.
      if param:dim() > 1 then
         if maxOutNorm and maxOutNorm > 0 then
            -- rows feed into output neurons
            param:renorm(2, 1, maxOutNorm)
         end
         if maxInNorm and maxInNorm > 0 then
            -- cols feed out from input neurons
            param:renorm(2, param:dim(), maxInNorm)
         end
      end
   end
end
-- Similar to maxParamNorm, but norm is global to Module for which
-- this is called. Unless moduleLocal is true, in which case, the
-- norm constraint is applied to the norm of all parameters in each
-- component (non-container) module.
-- cutoffNorm  : gradients are rescaled when their norm exceeds this value.
-- moduleLocal : see above.
-- A module's own cutoffNorm / moduleLocal attributes take precedence over
-- the arguments.
-- Returns the gradient norm measured before any rescaling, or nothing
-- when cutoffNorm <= 0 (clipping disabled).
function Module:gradParamClip(cutoffNorm, moduleLocal)
   -- this allows each module to set its own cutoffNorm
   cutoffNorm = self.cutoffNorm or cutoffNorm
   if cutoffNorm <= 0 then
      return
   end
   if self.moduleLocal ~= nil then
      moduleLocal = self.moduleLocal
   end

   -- runs action on the device owning tensor (support multi-device models)
   local function onTensorDevice(tensor, action)
      if torch.type(tensor) == 'torch.CudaTensor' then
         cutorch.withDevice(tensor:getDevice(), action)
      else
         action()
      end
   end

   local norm = 0
   if moduleLocal and self.modules then
      for i, child in ipairs(self.modules) do
         -- a child that disabled clipping through its own cutoffNorm <= 0
         -- returns nothing; count it as 0 instead of crashing math.pow
         local childNorm = child:gradParamClip(cutoffNorm, moduleLocal) or 0
         norm = norm + math.pow(childNorm, 2)
      end
      norm = math.sqrt(norm)
   else
      local params, gradParams = self:parameters()
      if not (params and gradParams) then
         return norm
      end
      for k, gradParam in pairs(gradParams) do -- pairs for sparse params
         onTensorDevice(gradParam, function()
            norm = norm + math.pow(gradParam:norm(), 2)
         end)
      end
      norm = math.sqrt(norm)
      if norm > cutoffNorm then
         -- rescale gradParams to obtain desired cutoffNorm
         for k, gradParam in pairs(gradParams) do
            onTensorDevice(gradParam, function()
               gradParam:mul(cutoffNorm / norm)
            end)
         end
      end
   end
   return norm
end
-- Adds weight decay to the gradients of every param having at least
-- wdMinDim dimensions : gradParam = gradParam + wdFactor * param.
-- wdFactor : decay factor; nothing is done when it is <= 0.
-- wdMinDim : minimum number of dimensions (defaults to 2).
-- A module's own wdFactor / wdMinDim attributes take precedence over the
-- arguments. Containers delegate to their children.
-- TODO : allow inplace weightDecay (before calling accUpdateGradParameters)
function Module:weightDecay(wdFactor, wdMinDim)
   -- this allows each module to set its own hyper-parameters
   wdFactor = self.wdFactor or wdFactor
   if wdFactor <= 0 then
      return
   end
   wdMinDim = self.wdMinDim or wdMinDim or 2

   local children = self.modules
   if children then
      for idx, child in ipairs(children) do
         child:weightDecay(wdFactor, wdMinDim)
      end
      return
   end

   local weights, grads = self:parameters()
   if (not weights) or (not grads) then
      return
   end
   for idx, weight in pairs(weights) do -- pairs for sparse params
      local decays = weight:dim() >= wdMinDim
      if decays then
         grads[idx]:add(wdFactor, weight)
      end
   end
end
-- Returns the table of momentum buffers (self.momGradParams), creating it
-- on first use as a copy of each current gradParam.
-- Returns nothing when the module has no gradParams.
function Module:momentumGradParameters()
   local existing = self.momGradParams
   if existing and not _.isEmpty(existing) then
      return existing
   end

   local grads = select(2, self:parameters())
   if not grads or _.isEmpty(grads) then
      return
   end

   self.momGradParams = {}
   for idx, grad in pairs(grads) do
      local function snapshot()
         self.momGradParams[idx] = grad.new():resizeAs(grad):copy(grad)
      end
      if torch.type(grad) == 'torch.CudaTensor' then
         -- support multi-device models
         cutorch.withDevice(grad:getDevice(), snapshot)
      else
         snapshot()
      end
   end
   return self.momGradParams
end
-- Uses momentum learning to update gradParams.
-- momFactor   : momentum factor; nothing is done when it is <= 0.
-- momDamp     : dampening (defaults to momFactor).
-- momNesterov : when truthy, momFactor * momentum is added to each
--               gradParam; otherwise each gradParam is overwritten with
--               its momentum buffer.
-- A module's own momFactor / momDamp / momNesterov attributes take
-- precedence over the arguments. Containers delegate to their children.
function Module:updateGradParameters(momFactor, momDamp, momNesterov)
   -- this allows each module to set its own hyper-parameters
   momFactor = self.momFactor or momFactor
   if momFactor <= 0 then
      return
   end
   momDamp = self.momDamp or momDamp or momFactor
   if self.momNesterov ~= nil then
      momNesterov = self.momNesterov
   end

   local children = self.modules
   if children then
      for idx, child in ipairs(children) do
         child:updateGradParameters(momFactor, momDamp, momNesterov)
      end
      return
   end

   local weights, grads = self:parameters()
   if (not weights) or _.isEmpty(weights) then
      return
   end

   local velocities = self:momentumGradParameters()
   -- first pass : velocity = momFactor * velocity + (1 - momDamp) * grad
   for idx, grad in pairs(grads) do
      local velocity = velocities[idx]
      velocity:mul(momFactor)
      velocities[idx]:add(1 - momDamp, grad)
   end
   -- second pass : write the result back into the gradients
   for idx, grad in pairs(grads) do
      if momNesterov then
         grad:add(momFactor, velocities[idx])
      else
         grad:copy(velocities[idx])
      end
   end
end
-- Raises an error as soon as the sum of one of the module's parameters
-- is NaN. Does nothing for modules without parameters.
function Module:checkParameters()
   local params = self:parameters()
   if not params then
      return
   end
   for key, tensor in pairs(params) do
      local total = tensor:sum()
      if _.isNaN(total) then
         error("NaN Error for param at index" ..key)
      end
   end
end
-- Disables the backward pass of this instance by replacing its
-- updateGradInput, accGradParameters and accUpdateGradParameters methods
-- with functions that do nothing. Returns self.
function Module:dontBackward()
   local disabled = {'updateGradInput', 'accGradParameters', 'accUpdateGradParameters'}
   for idx, methodName in ipairs(disabled) do
      self[methodName] = function() end
   end
   return self
end
-- Returns a contiguous version of input.
-- input    : tensor to check.
-- backward : when truthy, no copy is made : the buffer filled by an
--            earlier call (self.dpnn_cinput) is returned if there is one,
--            otherwise input itself.
-- A non-contiguous input is copied into self.dpnn_cinput, which is
-- created on first use and reused afterwards.
function Module:contiguousInput(input, backward)
   if backward then
      return self.dpnn_cinput or input
   end
   if input:isContiguous() then
      return input
   end
   local buffer = self.dpnn_cinput
   if not buffer then
      buffer = input.new()
   end
   self.dpnn_cinput = buffer
   buffer:resizeAs(input):copy(input)
   return buffer
end
-- Views a single (non-batch) sample as a batch of size one.
-- tensor   : input tensor.
-- nDim     : number of dimensions of a single sample.
-- batchDim : position where the singleton dimension is inserted
--            (defaults to 1).
-- Records in self.dpnn_online whether the input was a single sample, so
-- that fromBatch can undo the transformation. Tensors whose number of
-- dimensions differs from nDim are returned unchanged.
function Module:toBatch(tensor, nDim, batchDim)
   batchDim = batchDim or 1
   local isSingleSample = tensor:dim() == nDim
   self.dpnn_online = isSingleSample
   if not isSingleSample then
      return tensor
   end
   local shape = tensor:size():totable()
   table.insert(shape, batchDim, 1)
   local batched = tensor:view(table.unpack(shape))
   return batched
end
-- Undoes toBatch : when the last toBatch call flagged a single sample
-- (self.dpnn_online), the singleton batch dimension is removed again.
-- tensor   : tensor to convert back.
-- batchDim : position of the singleton dimension (defaults to 1, the same
--            default as toBatch).
-- Returns the tensor viewed without the batch dimension, or the tensor
-- unchanged when dpnn_online is not set.
function Module:fromBatch(tensor, batchDim)
   -- without this default, table.remove(size, nil) would drop the LAST
   -- dimension, which does not match toBatch inserting at dimension 1
   batchDim = batchDim or 1
   if self.dpnn_online then
      local size = tensor:size():totable()
      -- keep the side effect outside of assert
      local removed = table.remove(size, batchDim)
      assert(removed == 1)
      tensor = tensor:view(table.unpack(size))
   end
   return tensor
end
-- Extrapolates the tensor type of the module from its parameters.
-- Returns the torch type name used by the largest number of parameters
-- (the type seen first wins ties), or nil when the module has no
-- parameters.
function Module:extrapolateType()
   -- query this module (the code used to read an undeclared global `module`)
   local params = self:parameters()
   if not params then
      return nil --unknown otherwise
   end

   -- count the params per tensor type, remembering first-seen order so
   -- that ties are resolved deterministically
   local counts, seenOrder = {}, {}
   for i, param in ipairs(params) do
      local tensorType = torch.type(param)
      if not counts[tensorType] then
         counts[tensorType] = 0
         seenOrder[#seenOrder + 1] = tensorType
      end
      counts[tensorType] = counts[tensorType] + 1
   end

   -- keep the winner in a local (a misspelled `maxtype` used to leak into
   -- the globals, so nil was always returned)
   local maxCount = 0
   local maxType
   for i, tensorType in ipairs(seenOrder) do
      local count = counts[tensorType]
      if count > maxCount then
         maxType = tensorType
         maxCount = count
      end
   end
   return maxType
end
-- Sets the dpnn_profile flag on this module, after recursively doing the
-- same for every child listed in self.modules.
function Module:profile()
   for idx, child in ipairs(self.modules or {}) do
      child:profile()
   end
   self.dpnn_profile = true
end
-- Forwards reward to the reinforce method of every child listed in
-- self.modules. Modules without children do nothing.
-- reward : passed unchanged to each child.
function Module:reinforce(reward)
   local children = self.modules
   if not children then
      return
   end
   for idx, child in ipairs(children) do
      child:reinforce(reward)
   end
end