-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcosplace.py
46 lines (39 loc) · 1.42 KB
/
cosplace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
'''
Code for loading models trained with CosPlace as a global features extractor
for geolocalization through image retrieval.
Multiple models are available with different backbones. Below is a summary of
the available models (backbone: list of available output descriptor
dimensionalities). For example, you can use a model based on a ResNet50 with
descriptor dimensionality 1024.
ResNet18: [32, 64, 128, 256, 512]
ResNet50: [32, 64, 128, 256, 512, 1024, 2048]
ResNet101: [32, 64, 128, 256, 512, 1024, 2048]
ResNet152: [32, 64, 128, 256, 512, 1024, 2048]
VGG16: [ 64, 128, 256, 512]
CosPlace paper: https://arxiv.org/abs/2204.02287
'''
import torch
import torchvision.transforms as tvf
from ..utils.base_model import BaseModel
class CosPlace(BaseModel):
    '''Global descriptor extractor wrapping a CosPlace model from torch.hub.

    The network is obtained from the ``gmberton/CosPlace`` hub repository via
    its ``get_trained_model`` entry point, using the ``backbone`` and
    ``fc_output_dim`` values found in the configuration.
    '''

    # Configuration values used when the caller does not override them.
    default_conf = {
        'backbone': 'ResNet50',
        'fc_output_dim': 2048,
    }
    # Keys read from the ``data`` dict in ``_forward``.
    required_inputs = ['image']

    def _init(self, conf):
        '''Load the hub model in eval mode and set up input normalization.

        Args:
            conf: mapping providing ``'backbone'`` and ``'fc_output_dim'``,
                both forwarded as keyword arguments to the hub entry point.
        '''
        hub_model = torch.hub.load(
            'gmberton/CosPlace',
            'get_trained_model',
            backbone=conf['backbone'],
            fc_output_dim=conf['fc_output_dim'],
        )
        # eval() hands back the module itself, switched to inference mode.
        self.net = hub_model.eval()

        # Per-channel statistics applied to the input image before the
        # network (presumably the ImageNet mean/std -- TODO confirm).
        self.norm_rgb = tvf.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def _forward(self, data):
        '''Normalize ``data['image']`` and run it through the network.

        Args:
            data: mapping with an ``'image'`` entry; assumed to be a
                3-channel image tensor accepted by ``tvf.Normalize`` --
                TODO confirm against the caller.

        Returns:
            A dict with the network output under ``'global_descriptor'``.
        '''
        normalized = self.norm_rgb(data['image'])
        return {'global_descriptor': self.net(normalized)}