Skip to content

Commit

Permalink
Merge pull request #21 from qzcode/master
Browse files Browse the repository at this point in the history
loss_merge
  • Loading branch information
layumi authored Dec 16, 2021
2 parents 284cee7 + e80ed9a commit f3908c7
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 119 deletions.
26 changes: 18 additions & 8 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def forward(self, x):
return x

class two_view_net(nn.Module):
def __init__(self, class_num, droprate, stride = 2, pool = 'avg', share_weight = False, VGG16=False):
def __init__(self, class_num, droprate, stride = 2, pool = 'avg', share_weight = False, VGG16=False, circle=False,):
super(two_view_net, self).__init__()
if VGG16:
self.model_1 = ft_net_VGG16(class_num, stride=stride, pool = pool)
Expand All @@ -208,13 +208,15 @@ def __init__(self, class_num, droprate, stride = 2, pool = 'avg', share_weight =
else:
self.model_2 = ft_net(class_num, stride = stride, pool = pool)

self.classifier = ClassBlock(2048, class_num, droprate)
self.circle = circle

self.classifier = ClassBlock(2048, class_num, droprate, return_f = circle)
if pool =='avg+max':
self.classifier = ClassBlock(4096, class_num, droprate)
self.classifier = ClassBlock(4096, class_num, droprate, return_f = circle)
if VGG16:
self.classifier = ClassBlock(512, class_num, droprate)
self.classifier = ClassBlock(512, class_num, droprate, return_f = circle)
if pool =='avg+max':
self.classifier = ClassBlock(1024, class_num, droprate)
self.classifier = ClassBlock(1024, class_num, droprate, return_f = circle)

def forward(self, x1, x2):
if x1 is None:
Expand All @@ -232,7 +234,7 @@ def forward(self, x1, x2):


class three_view_net(nn.Module):
def __init__(self, class_num, droprate, stride = 2, pool = 'avg', share_weight = False, VGG16=False):
def __init__(self, class_num, droprate, stride = 2, pool = 'avg', share_weight = False, VGG16=False, circle=False):
super(three_view_net, self).__init__()
if VGG16:
self.model_1 = ft_net_VGG16(class_num, stride = stride, pool = pool)
Expand All @@ -248,33 +250,41 @@ def __init__(self, class_num, droprate, stride = 2, pool = 'avg', share_weight =
self.model_3 = ft_net_VGG16(class_num, stride = stride, pool = pool)
else:
self.model_3 = ft_net(class_num, stride = stride, pool = pool)
self.classifier = ClassBlock(2048, class_num, droprate)

self.circle = circle

self.classifier = ClassBlock(2048, class_num, droprate, return_f = circle)
if pool =='avg+max':
self.classifier = ClassBlock(4096, class_num, droprate)

self.classifier = ClassBlock(4096, class_num, droprate, return_f = circle)

def forward(self, x1, x2, x3, x4 = None): # x4 is extra data
if x1 is None:
y1 = None
else:
x1 = self.model_1(x1)
x1 = x1.view(x1.size(0), x1.size(1))
y1 = self.classifier(x1)

if x2 is None:
y2 = None
else:
x2 = self.model_2(x2)
x2 = x2.view(x2.size(0), x2.size(1))
y2 = self.classifier(x2)

if x3 is None:
y3 = None
else:
x3 = self.model_3(x3)
x3 = x3.view(x3.size(0), x3.size(1))
y3 = self.classifier(x3)

if x4 is None:
return y1, y2, y3
else:
x4 = self.model_2(x4)
x4 = x4.view(x4.size(0), x4.size(1))
y4 = self.classifier(x4)
return y1, y2, y3, y4

Expand Down
Loading

0 comments on commit f3908c7

Please sign in to comment.