-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathprovider.lua
132 lines (117 loc) · 4.14 KB
/
provider.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
require 'nn'
require 'image'
require 'xlua'
-- Make FloatTensor the default tensor type, so the plain torch.Tensor()
-- calls in this file create single-precision tensors.
torch.setdefaulttensortype('torch.FloatTensor')
--- Flatten a nested STL-10 sample table into a data tensor and a label tensor.
-- `d` is a sequence of sequences: every sample found in `d[i]` is copied into
-- the next free slot of the output and tagged with label `i`.
-- @param d            table of per-label sample lists
-- @param numSamples   total number of samples expected across all of `d`
-- @param numChannels  channel count of one sample
-- @param height       height of one sample
-- @param width        width of one sample
-- @return ByteTensor of size numSamples x numChannels x height x width, and a
--         ByteTensor of size numSamples holding the labels
-- An assertion fails when fewer samples than `numSamples` were copied.
function parseDataLabel(d, numSamples, numChannels, height, width)
  local images = torch.ByteTensor(numSamples, numChannels, height, width)
  local labels = torch.ByteTensor(numSamples)

  -- Number of output slots written so far.
  local filled = 0
  for label = 1, #d do
    local samples = d[label]
    for k = 1, #samples do
      filled = filled + 1
      images[filled]:copy(samples[k])
      labels[filled] = label
    end
  end

  assert(filled == numSamples)
  return images, labels
end
-- Register a torch class named 'Provider'; the methods defined below are
-- attached to the table returned here.
local Provider = torch.class 'Provider'
-- Download `url` to `dest` unless `dest` already exists.
-- The transfer is written to a temporary ".part" file and renamed only when
-- wget exits successfully, so an interrupted download never leaves a
-- truncated file that a later run would mistake for a complete one.
local function fetch_if_missing(url, dest)
  if paths.filep(dest) then
    return
  end
  local partial = dest .. '.part'
  os.execute('wget -O ' .. partial .. ' ' .. url ..
             ' && mv ' .. partial .. ' ' .. dest)
end

-- Load one split from `path` and return it as a table with float `data`,
-- float `labels`, and a `size()` function returning `count`.
-- The raw table returned by torch.load goes out of scope when this function
-- returns, so it can be reclaimed by the next garbage collection.
local function load_split(path, count, channels, height, width)
  local raw = torch.load(path)
  local data, labels = parseDataLabel(raw.data, count, channels, height, width)
  return {
    -- parseDataLabel yields ByteTensors; convert to float for processing
    data = data:float(),
    labels = labels:float(),
    size = function() return count end
  }
end

--- Build the provider: fetch the STL-10 files if needed and load the
-- training and validation splits into `self.trainData` / `self.valData`.
-- Each split is a table { data = FloatTensor, labels = FloatTensor, size = fn }.
-- Any of the four dataset files (train, val, test, extra) that is missing
-- from the `stl-10` directory is downloaded; previously nothing was fetched
-- once the directory existed, so a failed download was never retried.
-- Errors from torch.load propagate if a required file is still missing.
-- @param full  accepted for interface compatibility; not used by this constructor
function Provider:__init(full)
  local trsize = 4000
  local valsize = 1000 -- the validation split is used as the evaluation set
  local channel = 3
  local height = 96
  local width = 96

  -- download dataset (only the files that are not already present)
  local data_dir = 'stl-10'
  local base_url = 'https://s3.amazonaws.com/dsga1008-spring16/data/a2/'
  if not paths.dirp(data_dir) then
    paths.mkdir(data_dir)
  end
  for _, split in ipairs({ 'train', 'val', 'test', 'extra' }) do
    local file = split .. '.t7b'
    fetch_if_missing(base_url .. file, data_dir .. '/' .. file)
  end

  -- load and parse dataset
  self.trainData = load_split(data_dir .. '/train.t7b',
                              trsize, channel, height, width)
  self.valData = load_split(data_dir .. '/val.t7b',
                            valsize, channel, height, width)

  -- the raw tables are unreachable by now, so this can release them
  collectgarbage()
end
--- Preprocess the training and validation sets in place.
-- Every image is converted from RGB to YUV. The Y channel is normalized
-- locally with a spatial contrastive normalization; the U and V channels are
-- shifted and scaled with the mean and standard deviation measured on the
-- training set. Those four statistics are stored on `self.trainData` as
-- `mean_u`, `std_u`, `mean_v` and `std_v`, and the same values are applied
-- to the validation set.
function Provider:normalize()
  local train_set = self.trainData
  local val_set = self.valData

  print '<trainer> preprocessing data (color space + normalization)'
  collectgarbage()

  -- Single-plane contrastive normalization with a 7-tap gaussian kernel,
  -- shared by both splits.
  local normalization = nn.SpatialContrastiveNormalization(1, image.gaussian1D(7))

  -- Replace every image of `set` by its YUV version with a locally
  -- normalized Y channel, reporting progress along the way.
  local function convert_to_yuv(set)
    for i = 1, set:size() do
      xlua.progress(i, set:size())
      local yuv = image.rgb2yuv(set.data[i])
      yuv[1] = normalization(yuv[{{1}}])
      set.data[i] = yuv
    end
  end

  -- Subtract `mean` from, then divide by `std`, one channel of `set.data`.
  -- `select` returns a view, so the arithmetic modifies the data in place.
  local function standardize(set, channel, mean, std)
    local plane = set.data:select(2, channel)
    plane:add(-mean)
    plane:div(std)
  end

  -- training set: measure U (channel 2) and V (channel 3) statistics,
  -- standardizing each channel right after it has been measured
  convert_to_yuv(train_set)

  local u_plane = train_set.data:select(2, 2)
  local mean_u = u_plane:mean()
  local std_u = u_plane:std()
  standardize(train_set, 2, mean_u, std_u)

  local v_plane = train_set.data:select(2, 3)
  local mean_v = v_plane:mean()
  local std_v = v_plane:std()
  standardize(train_set, 3, mean_v, std_v)

  train_set.mean_u = mean_u
  train_set.std_u = std_u
  train_set.mean_v = mean_v
  train_set.std_v = std_v

  -- validation set: reuse the training statistics
  convert_to_yuv(val_set)
  standardize(val_set, 2, mean_u, std_u)
  standardize(val_set, 3, mean_v, std_v)
end