forked from Element-Research/dpnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Clip.lua
35 lines (31 loc) · 1.21 KB
/
Clip.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
------------------------------------------------------------------------
--[[ Clip ]]--
-- Module that clips values so they lie within minval and maxval.
-- Forward : elements greater than maxval become maxval and elements
--           smaller than minval become minval; all others are copied.
-- Backward: gradOutput is handed back as gradInput without modification.
------------------------------------------------------------------------
-- Declares the nn.Clip class with nn.Module as its parent class.
local Clip, parent = torch.class("nn.Clip", "nn.Module")
-- Constructor.
-- minval : (number) lower bound applied to every element on the forward pass
-- maxval : (number) upper bound applied to every element on the forward pass
-- Raises an error (via assert) when either bound is not a Lua number; the
-- message names the offending argument and the type that was received.
function Clip:__init(minval, maxval)
   assert(torch.type(minval) == 'number',
      'nn.Clip: minval must be a number, got ' .. torch.type(minval))
   assert(torch.type(maxval) == 'number',
      'nn.Clip: maxval must be a number, got ' .. torch.type(maxval))
   self.minval = minval
   self.maxval = maxval
   parent.__init(self)
end
-- Forward pass.
-- input   : tensor of any shape
-- returns : self.output, resized like input, holding a copy of input in
--           which every element is limited to [self.minval, self.maxval]
-- Assumes self.minval <= self.maxval; for an inverted range the result is
-- whatever the tensor's clamp implementation produces.
function Clip:updateOutput(input)
   -- work on a copy so that input itself is never modified
   self.output:resizeAs(input):copy(input)
   -- In-place clamp replaces the earlier gt/lt mask indexing. It needs no
   -- scratch mask tensors and no special case keyed on the tensor type
   -- name, so the same code path serves every tensor type that
   -- implements clamp.
   self.output:clamp(self.minval, self.maxval)
   return self.output
end
-- Backward pass.
-- input      : the tensor given to updateOutput (not read here)
-- gradOutput : gradient with respect to this module's output
-- returns    : self.gradInput after it has been `set` to gradOutput
-- The gradient is forwarded as-is: positions that were clipped on the
-- forward pass are not masked out here.
function Clip:updateGradInput(input, gradOutput)
   local gradInput = self.gradInput
   gradInput:set(gradOutput)
   return gradInput
end