-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathfire_network.py
130 lines (109 loc) · 5.75 KB
/
fire_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
import os
import torch
from torch import nn
import torchvision
from cirtorch.networks import imageretrievalnet
from how import layers
from how.layers import functional as HF
from lit import LocalfeatureIntegrationTransformer
from how.networks.how_net import HOWNet, CORERCF_SIZE
class FIReNet(HOWNet):
    """HOWNet variant that inserts a LIT module after the backbone.

    Per scale, ``get_superfeatures`` applies: ``features`` (backbone) -> ``lit``
    -> ``attention`` (strengths, computed on the LIT output) -> optional
    ``smoothing`` -> optional ``dim_reduction``.
    """

    def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
        """Build the network.

        :param nn.Module features: convolutional backbone
        :param nn.Module attention: module mapping LIT outputs to strengths
        :param nn.Module lit: LIT module, called as ``o, attn = lit(o)``
        :param dim_reduction: dimensionality reduction module, or None
        :param dict meta: network metadata (e.g. 'backbone_dim', 'outputdim')
        :param dict runtime: runtime options (e.g. 'training_scales')
        """
        # None is passed in HOWNet's third (smoothing) slot: FIRe does not use
        # HOW's smoothing layer. NOTE(review): confirm HOWNet stores it as self.smoothing.
        super().__init__(features, attention, None, dim_reduction, meta, runtime)
        self.lit = lit
        # When True, forward() returns the global descriptor instead of super-features.
        self.return_global = False

    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer.

        The sub-modules are shared with this network, not duplicated; only the
        meta dict is new, with 'outputdim' reset to 'backbone_dim'.
        """
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict.

        The sub-modules and meta dict are shared with this network.
        """
        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)

    def parameter_groups(self):
        """Return torch optimizer parameter groups.

        Every present trainable sub-module contributes a default group; the
        dimensionality reduction layer, if any, gets its own group with
        ``lr=0.0`` so it is not updated.
        """
        # Named `modules` (not `layers`) so the imported `how.layers` package is not shadowed.
        modules = [self.features, self.attention, self.smoothing, self.lit]
        parameters = [{'params': m.parameters()} for m in modules if m is not None]
        if self.dim_reduction is not None:
            # Do not update dimensionality reduction layer
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters

    def get_superfeatures(self, x, *, scales):
        """Compute super-features, LIT attention maps and strengths at every scale.

        :param torch.Tensor x: input batch; it is resized with bilinear
            interpolation, which requires a 4-D (BxCxHxW) tensor
        :param scales: iterable of scale factors for ``nn.functional.interpolate``
        :return: a tuple ``(feats, attns, strengths)`` of three lists, each with
            one entry per scale, in the order of ``scales``:
            - feats: super-features after optional smoothing / dim reduction
              (documented as a BxDxNx1 tensor)
            - attns: attention maps returned by the LIT module
              (documented as a BxNxHxW tensor)
            - strengths: output of the attention module on the LIT super-features
        """
        feats = []
        attns = []
        strengths = []
        for s in scales:
            xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
            o = self.features(xs)
            o, attn = self.lit(o)
            # Strengths are computed before smoothing / dim reduction are applied.
            strength = self.attention(o)
            if self.smoothing is not None:
                o = self.smoothing(o)
            if self.dim_reduction is not None:
                o = self.dim_reduction(o)
            feats.append(o)
            attns.append(attn)
            strengths.append(strength)
        return feats, attns, strengths

    def forward(self, x):
        """Run the network at ``runtime['training_scales']``.

        Returns the global descriptor when ``self.return_global`` is set,
        otherwise the ``(feats, attns, strengths)`` tuple of ``get_superfeatures``.
        """
        scales = self.runtime['training_scales']
        if self.return_global:
            return self.forward_global(x, scales=scales)
        return self.get_superfeatures(x, scales=scales)

    def forward_global(self, x, *, scales):
        """Return global descriptor.

        Passes the per-scale super-features and strengths to ``HF.weighted_spoc``.
        """
        feats, _, strengths = self.get_superfeatures(x, scales=scales)
        return HF.weighted_spoc(feats, strengths)

    def forward_local(self, x, *, features_num, scales):
        """Return selected super features.

        Selection is delegated to ``HF.how_select_local`` with ``features_num``.
        """
        feats, _, strengths = self.get_superfeatures(x, scales=scales)
        return HF.how_select_local(feats, strengths, scales=scales, features_num=features_num)
def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
    """Initialize FIRe network

    :param str architecture: Network backbone architecture (e.g. resnet18)
    :param str pretrained: Path to a local checkpoint file whose 'state_dict'
        holds the trained weights (including the LIT module), or None to keep
        random initialization
    :param int skip_layer: How many layers of blocks should be skipped (from the end)
    :param dict dim_reduction: Options for the dimensionality reduction layer
        (falsy to build the network without one)
    :param dict lit: Options for the LIT layer; must contain 'dim', its output dimension
    :param dict runtime: Runtime options to be stored in the network
    :return FIReNet: Initialized network
    :raises ValueError: if the architecture is not supported
    :raises FileNotFoundError: if ``pretrained`` is given but is not an existing file
    :raises RuntimeError: if the checkpoint lacks keys outside the dimensionality
        reduction layer, or has extra keys outside the 'fc' layer
    """
    # Random init only: the trained weights (including the LIT module) come from
    # `pretrained` below. NOTE(review): newer torchvision deprecates `pretrained=`
    # in favour of `weights=None`.
    net_in = getattr(torchvision.models, architecture)(pretrained=False)

    # Take convolutional layers as features, always ends with ReLU to make last activations non-negative
    if architecture.startswith(('alexnet', 'vgg')):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('resnet'):
        # Drop the trailing global pooling and classifier.
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
    elif architecture.startswith('squeezenet'):
        features = list(net_in.features.children())
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))

    # Guarded because features[:-0] would yield an empty list.
    if skip_layer > 0:
        features = features[:-skip_layer]
    # Assumes each skipped block halves the backbone's channel count.
    backbone_dim = imageretrievalnet.OUTPUT_DIM[architecture] // (2 ** skip_layer)

    att_layer = layers.attention.L2Attention()
    lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)
    reduction_layer = None
    if dim_reduction:
        reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])

    meta = {
        "architecture": architecture,
        "backbone_dim": lit['dim'],
        "outputdim": reduction_layer.out_channels if reduction_layer is not None else lit['dim'],
        "corercf_size": CORERCF_SIZE[architecture] // (2 ** skip_layer),
    }
    net = FIReNet(nn.Sequential(*features), att_layer, lit_layer, reduction_layer, meta, runtime)

    if pretrained is not None:
        # Explicit raises instead of `assert`: asserts vanish under `python -O`,
        # which would silently leave the network randomly initialized.
        if not os.path.isfile(pretrained):
            raise FileNotFoundError(pretrained)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(pretrained, map_location='cpu')
        missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
        # Only dimensionality-reduction weights may be absent and only classifier
        # ('fc') weights may be extra; anything else means a mismatched checkpoint.
        bad_missing = [k for k in missing if 'dim_reduction' not in k]
        bad_unexpected = [k for k in unexpected if 'fc' not in k]
        if bad_missing or bad_unexpected:
            raise RuntimeError(
                "Loading did not go well: missing keys {}, unexpected keys {}".format(
                    bad_missing, bad_unexpected))
    return net