-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
153 lines (125 loc) · 5.6 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.optim import lr_scheduler
from torch.autograd import Variable
def squash(x):
    """Apply the capsule "squash" non-linearity along dimension 2.

    Every vector lying along dim 2 keeps its direction while its length
    ``n`` is mapped to ``n**2 / (1 + n**2)``, i.e. into the range [0, 1).

    Args:
        x: tensor whose capsule vectors lie along dim 2, e.g.
            ``(batch, capsules, capsule_dim)``.

    Returns:
        A tensor of the same shape as ``x``.
    """
    lengths2 = x.pow(2).sum(dim=2, keepdim=True)
    lengths = lengths2.sqrt()
    # n**2 / (1 + n**2) / n simplifies to n / (1 + n**2). Using the
    # simplified form avoids the 0 / 0 = NaN that the literal formula
    # produces in the forward pass for an all-zero capsule vector.
    scale = lengths / (1. + lengths2)
    return x * scale
class AgreementRouting(nn.Module):
    """Routing-by-agreement between an input and an output capsule layer.

    Args:
        input_caps: number of input capsules.
        output_caps: number of output capsules.
        n_iterations: number of agreement refinement steps performed in
            ``forward`` after the initial routing pass.
    """

    def __init__(self, input_caps, output_caps, n_iterations):
        super(AgreementRouting, self).__init__()
        self.n_iterations = n_iterations
        # Learnable routing logits, shared by every sample in the batch.
        self.b = nn.Parameter(torch.zeros((input_caps, output_caps)))

    def forward(self, u_predict):
        """Route prediction vectors to the output capsules.

        Args:
            u_predict: tensor of shape
                ``(batch, input_caps, output_caps, output_dim)``.

        Returns:
            Tensor of shape ``(batch, output_caps, output_dim)``.
        """
        batch_size, input_caps, output_caps, _ = u_predict.size()

        # Initial pass: coupling coefficients come from the learned logits
        # alone, normalised over the output capsules (explicit dim=1; the
        # implicit-dim form of softmax is deprecated).
        c = F.softmax(self.b, dim=1)
        s = (c.unsqueeze(2) * u_predict).sum(dim=1)
        v = squash(s)

        if self.n_iterations > 0:
            b_batch = self.b.expand((batch_size, input_caps, output_caps))
            for _ in range(self.n_iterations):
                # Broadcast the current outputs over the input capsules.
                v = v.unsqueeze(1)
                # Agreement = dot product between prediction and output.
                b_batch = b_batch + (u_predict * v).sum(-1)
                # Softmax over the output-capsule axis, per sample and per
                # input capsule.
                c = F.softmax(b_batch, dim=2).unsqueeze(3)
                s = (c * u_predict).sum(dim=1)
                v = squash(s)

        return v
class CapsLayer(nn.Module):
    """Capsule layer: per-input-capsule linear maps followed by routing.

    Args:
        input_caps: number of input capsules.
        input_dim: dimensionality of each input capsule.
        output_caps: number of output capsules.
        output_dim: dimensionality of each output capsule.
        routing_module: callable applied to the prediction tensor of shape
            ``(batch, input_caps, output_caps, output_dim)``.
    """

    def __init__(self, input_caps, input_dim, output_caps, output_dim, routing_module):
        super(CapsLayer, self).__init__()
        self.input_dim = input_dim
        self.input_caps = input_caps
        self.output_dim = output_dim
        self.output_caps = output_caps
        # One (input_dim x output_caps*output_dim) matrix per input capsule.
        self.weights = nn.Parameter(
            torch.Tensor(input_caps, input_dim, output_caps * output_dim)
        )
        self.routing_module = routing_module
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weights uniformly from +/- 1/sqrt(input_caps)."""
        bound = 1. / math.sqrt(self.input_caps)
        self.weights.data.uniform_(-bound, bound)

    def forward(self, caps_output):
        """Compute prediction vectors and hand them to the routing module.

        Args:
            caps_output: tensor of shape ``(batch, input_caps, input_dim)``.

        Returns:
            Whatever ``routing_module`` returns for the prediction tensor.
        """
        # (batch, input_caps, 1, input_dim) @ (input_caps, input_dim, out)
        row_vectors = caps_output.unsqueeze(2)
        votes = torch.matmul(row_vectors, self.weights)
        votes = votes.view(
            votes.size(0), self.input_caps, self.output_caps, self.output_dim
        )
        return self.routing_module(votes)
class PrimaryCapsLayer(nn.Module):
    """Convolutional layer whose feature maps are regrouped into capsules.

    Args:
        input_channels: channels of the incoming feature map.
        output_caps: number of capsule types produced per spatial location.
        output_dim: dimensionality of each capsule vector.
        kernel_size: kernel size of the underlying convolution.
        stride: stride of the underlying convolution.
    """

    def __init__(self, input_channels, output_caps, output_dim, kernel_size, stride):
        super(PrimaryCapsLayer, self).__init__()
        self.conv = nn.Conv2d(
            input_channels,
            output_caps * output_dim,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.input_channels = input_channels
        self.output_caps = output_caps
        self.output_dim = output_dim

    def forward(self, input):
        """Convolve ``input`` and return squashed capsule vectors.

        Returns:
            Tensor of shape ``(N, output_caps * H * W, output_dim)`` where
            ``H`` and ``W`` are the spatial sizes of the convolution output.
        """
        features = self.conv(input)
        batch, _, height, width = features.size()
        # Split the channel axis into (capsule type, capsule component).
        grouped = features.view(batch, self.output_caps, self.output_dim, height, width)
        # Move the component axis last so each capsule vector is contiguous,
        # then flatten capsule type and spatial position into one axis.
        capsules = grouped.permute(0, 1, 3, 4, 2).contiguous()
        capsules = capsules.view(batch, -1, self.output_dim)
        return squash(capsules)
class CapsNet(nn.Module):
    """Capsule network: conv -> primary capsules -> routed class capsules.

    Args:
        routing_iterations: number of routing iterations passed to
            ``AgreementRouting``.
        n_classes: number of output (class) capsules.
    """

    def __init__(self, routing_iterations, n_classes=10):
        super(CapsNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        # 32 capsule types of dimension 8. The capsule count below assumes a
        # 6x6 spatial grid here, which is what a 28x28 input yields after
        # the two 9x9 convolutions (stride 1, then stride 2).
        self.primaryCaps = PrimaryCapsLayer(256, 32, 8, kernel_size=9, stride=2)
        self.num_primaryCaps = 32 * 6 * 6
        router = AgreementRouting(self.num_primaryCaps, n_classes, routing_iterations)
        self.digitCaps = CapsLayer(self.num_primaryCaps, 8, n_classes, 16, router)

    def forward(self, input):
        """Return the class capsules and their lengths.

        Returns:
            A tuple ``(capsules, probs)`` where ``probs`` is the Euclidean
            length of every capsule along dim 2.
        """
        hidden = F.relu(self.conv1(input))
        class_caps = self.digitCaps(self.primaryCaps(hidden))
        probs = torch.sqrt((class_caps ** 2).sum(dim=2))
        return class_caps, probs
class ReconstructionNet(nn.Module):
    """Three-layer MLP decoder that reconstructs a 784-value output from
    the capsule of the target class.

    Args:
        n_dim: dimensionality of each class capsule.
        n_classes: number of class capsules.
    """

    def __init__(self, n_dim=16, n_classes=10):
        super(ReconstructionNet, self).__init__()
        self.fc1 = nn.Linear(n_dim * n_classes, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 784)
        self.n_dim = n_dim
        self.n_classes = n_classes

    def forward(self, x, target):
        """Decode the capsule selected by ``target``.

        Args:
            x: capsule tensor of shape ``(batch, n_classes, n_dim)``.
            target: integer (int64) class indices with ``batch`` elements,
                on the same device as ``x``.

        Returns:
            Tensor of shape ``(batch, 784)`` with values in [0, 1].
        """
        # Build the one-hot mask directly on the input's device instead of
        # creating it on the CPU and calling .cuda(), which always targets
        # the *current* CUDA device rather than the one holding ``x``.
        mask = torch.zeros(x.size(0), self.n_classes, device=x.device)
        mask.scatter_(1, target.view(-1, 1), 1.)
        # Zero out every capsule except the one belonging to the target.
        x = x * mask.unsqueeze(2)
        x = x.view(-1, self.n_dim * self.n_classes)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc3(x))
class CapsNetWithReconstruction(nn.Module):
    """Chain a capsule network with a reconstruction decoder.

    Args:
        capsnet: callable returning ``(capsules, probs)`` for an input.
        reconstruction_net: callable taking ``(capsules, target)`` and
            returning the reconstruction.
    """

    def __init__(self, capsnet, reconstruction_net):
        super(CapsNetWithReconstruction, self).__init__()
        self.capsnet = capsnet
        self.reconstruction_net = reconstruction_net

    def forward(self, x, target):
        """Return ``(reconstruction, probs)`` for input ``x`` and ``target``."""
        capsules, probs = self.capsnet(x)
        return self.reconstruction_net(capsules, target), probs
class MarginLoss(nn.Module):
    """Margin loss on capsule lengths.

    For the target class the loss is ``relu(m_pos - length) ** 2``; for
    every other class it is ``lambda_ * relu(length - m_neg) ** 2``.

    Args:
        m_pos: margin the target-class length should exceed.
        m_neg: margin the non-target lengths should stay below.
        lambda_: weight of the non-target term.
    """

    def __init__(self, m_pos, m_neg, lambda_):
        super(MarginLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, lengths, targets, size_average=True):
        """Compute the margin loss.

        Args:
            lengths: tensor of shape ``(batch, n_classes)`` with capsule
                lengths.
            targets: integer (int64) class indices with ``batch`` elements.
            size_average: if true, return the mean over all
                ``batch * n_classes`` entries; otherwise return their sum.

        Returns:
            A scalar tensor.
        """
        # One-hot encode the targets as float32 on the targets' own device.
        # This replaces the CPU tensor + .cuda() + Variable round trip,
        # which always landed on the *current* CUDA device.
        one_hot = torch.zeros(
            lengths.size(), dtype=torch.float32, device=targets.device
        )
        one_hot.scatter_(1, targets.view(-1, 1), 1.)

        present = one_hot * F.relu(self.m_pos - lengths).pow(2)
        absent = self.lambda_ * (1. - one_hot) * F.relu(lengths - self.m_neg).pow(2)
        losses = present + absent
        return losses.mean() if size_average else losses.sum()