LeNet.py
import torch.nn as nn
import torch.nn.functional as F
# Conv2d_SBP and Linear_SBP are imported for completeness; only SBP_layer is used below.
from SBP_utils import Conv2d_SBP, Linear_SBP, SBP_layer

class LeNet_GS(nn.Module):
    """LeNet variant with a BatchNorm layer after every conv/fc layer.
    The BatchNorm scale parameters are exposed via group_sparse() so a
    group-sparsity penalty can be applied to them."""

    def __init__(self):
        super(LeNet_GS, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fcbn1 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.fcbn2 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.fcbn1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = self.fcbn2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out

    def group_sparse(self):
        # BatchNorm scale (gamma) vectors used as group-sparsity targets.
        bn_weight_list = [self.bn1.weight, self.bn2.weight, self.fcbn1.weight, self.fcbn2.weight]
        return bn_weight_list
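
# Usage sketch (not part of the original file): group_sparse() exposes the BatchNorm
# scales so an L1 group-sparsity penalty can be added to the training loss. The helper
# name and the weighting factor `lam` are illustrative assumptions, not values from this repo.
def bn_group_lasso_penalty(model, lam=1e-4):
    """Weighted L1 norm of the BatchNorm scales returned by model.group_sparse()."""
    penalty = 0.0
    for w in model.group_sparse():
        penalty = penalty + w.abs().sum()
    return lam * penalty
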
class LeNet(nn.Module):
    """
    Classic LeNet-5 layout, kept here for reference (the implementation below is the
    wider Caffe-style variant with 20/50 conv channels and an 800-500-10 classifier):
    Input - 1x32x32
    C1 - 6@28x28 (5x5 kernel)
    tanh
    S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
    C3 - 16@10x10 (5x5 kernel, partially connected in the original)
    tanh
    S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
    C5 - 120@1x1 (5x5 kernel)
    F6 - 84
    tanh
    F7 - 10 (Output)
    """

    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(800, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = self.conv2(out)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out
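
# Shape check (illustrative, not part of the original file): with 1x28x28 MNIST inputs,
# conv1 gives 20@24x24, pooling 20@12x12, conv2 50@8x8, pooling 50@4x4,
# i.e. 50 * 4 * 4 = 800 features, matching fc1's in_features.
def _sanity_check_lenet():
    import torch
    model = LeNet()
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 1, 28, 28))
    assert logits.shape == (2, 10)
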
class LeNet_SBP(nn.Module):
    """LeNet with an SBP_layer inserted after each conv/fc block.
    In training mode forward() returns (logits, kl_sum); in eval mode it
    returns logits only."""

    def __init__(self):
        super(LeNet_SBP, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(800, 500)
        self.fc2 = nn.Linear(500, 10)
        self.sbp_1 = SBP_layer(20)
        self.sbp_2 = SBP_layer(50)
        self.sbp_3 = SBP_layer(800)
        self.sbp_4 = SBP_layer(500)

    def forward(self, x):
        if self.training:
            out = self.conv1(x)
            out, kl1 = self.sbp_1(out)
            out = F.relu(out)
            out = F.max_pool2d(out, 2)
            out = self.conv2(out)
            out, kl2 = self.sbp_2(out)
            out = F.relu(out)
            out = F.max_pool2d(out, 2)
            out = out.view(out.size(0), -1)
            out, kl3 = self.sbp_3(out)
            out = self.fc1(out)
            out, kl4 = self.sbp_4(out)
            out = F.relu(out)
            out = self.fc2(out)
            # Weighted sum of the per-layer KL terms, to be added to the training loss.
            kl_sum = 0.3 * kl1 + 0.3 * kl2 + 0.2 * kl3 + 0.2 * kl4
            return out, kl_sum
        else:
            # In eval mode each SBP_layer returns only the (pruned) activations.
            out = self.conv1(x)
            out = self.sbp_1(out)
            out = F.relu(out)
            out = F.max_pool2d(out, 2)
            out = self.conv2(out)
            out = self.sbp_2(out)
            out = F.relu(out)
            out = F.max_pool2d(out, 2)
            out = out.view(out.size(0), -1)
            out = self.sbp_3(out)
            out = self.fc1(out)
            out = self.sbp_4(out)
            out = F.relu(out)
            out = self.fc2(out)
            return out

    def layerwise_sparsity(self):
        return [self.sbp_1.layer_sparsity(), self.sbp_2.layer_sparsity(),
                self.sbp_3.layer_sparsity(), self.sbp_4.layer_sparsity()]

    def display_snr(self):
        return [self.sbp_1.display_snr(), self.sbp_2.display_snr(),
                self.sbp_3.display_snr(), self.sbp_4.display_snr()]
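
# Training-step sketch (not part of the original file): in training mode LeNet_SBP
# returns (logits, kl_sum), and the KL term is added to the data loss. The helper name,
# `kl_weight`, and the optimizer/loader are illustrative assumptions, not values from this repo.
def train_step_sbp(model, optimizer, images, labels, kl_weight=1.0):
    model.train()
    optimizer.zero_grad()
    logits, kl_sum = model(images)
    loss = F.cross_entropy(logits, labels) + kl_weight * kl_sum
    loss.backward()
    optimizer.step()
    return loss.item()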