'''
implement the feature extractions for light CNN
@author: Alfred Xiang Wu
@date: 2017.07.04
'''
from __future__ import print_function
import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import cv2
from light_cnn import LightCNN_9Layers, LightCNN_29Layers, LightCNN_29Layers_v2
from load_imglist import ImageList

parser = argparse.ArgumentParser(description='PyTorch LightCNN Feature Extracting')
parser.add_argument('--arch', '-a', metavar='ARCH', default='LightCNN')
parser.add_argument('--cuda', '-c', default=True)
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--model', default='', type=str, metavar='Model',
                    help='model type: LightCNN-9, LightCNN-29, LightCNN-29v2')
parser.add_argument('--root_path', default='', type=str, metavar='PATH',
                    help='root path of face images (default: none).')
parser.add_argument('--img_list', default='', type=str, metavar='PATH',
                    help='list of face images for feature extraction (default: none).')
parser.add_argument('--save_path', default='', type=str, metavar='PATH',
                    help='save root path for features of face images.')
parser.add_argument('--num_classes', default=79077, type=int,
                    metavar='N', help='number of classes (default: 79077)')

def main():
    global args
    args = parser.parse_args()

    # build the requested LightCNN variant
    if args.model == 'LightCNN-9':
        model = LightCNN_9Layers(num_classes=args.num_classes)
    elif args.model == 'LightCNN-29':
        model = LightCNN_29Layers(num_classes=args.num_classes)
    elif args.model == 'LightCNN-29v2':
        model = LightCNN_29Layers_v2(num_classes=args.num_classes)
    else:
        raise ValueError('Unknown model type: {}'.format(args.model))

    model.eval()
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    # load pretrained weights from the checkpoint, if given
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    img_list  = read_list(args.img_list)
    transform = transforms.Compose([transforms.ToTensor()])
    count     = 0
    input     = torch.zeros(1, 1, 128, 128)
    for img_name in img_list:
        count = count + 1
        # images are read as 128x128 grayscale, one image per forward pass
        img = cv2.imread(os.path.join(args.root_path, img_name), cv2.IMREAD_GRAYSCALE)
        img = np.reshape(img, (128, 128, 1))
        img = transform(img)
        input[0, :, :, :] = img

        start = time.time()
        if args.cuda:
            input = input.cuda()
        # volatile is a no-op on PyTorch >= 0.4; torch.no_grad() is the modern equivalent
        input_var   = torch.autograd.Variable(input, volatile=True)
        _, features = model(input_var)
        end = time.time() - start
        print("{} ({}/{}). Time: {}".format(os.path.join(args.root_path, img_name), count, len(img_list), end))
        save_feature(args.save_path, img_name, features.data.cpu().numpy()[0])

def read_list(list_path):
    img_list = []
    with open(list_path, 'r') as f:
        for line in f.readlines():
            img_path = line.strip().split()
            img_list.append(img_path[0])
    print('There are {} images.'.format(len(img_list)))
    return img_list

def save_feature(save_path, img_name, features):
    img_path = os.path.join(save_path, img_name)
    img_dir  = os.path.dirname(img_path) + '/'
    if not os.path.exists(img_dir):
        os.makedirs(img_dir)
    fname = os.path.splitext(img_path)[0]
    fname = fname + '.feat'
    # write the feature vector to disk as raw bytes (float32 by default)
    fid = open(fname, 'wb')
    fid.write(features)
    fid.close()

if __name__ == '__main__':
    main()
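
# ----------------------------------------------------------------------
# Minimal usage sketch. The command line, checkpoint name, and paths
# below are illustrative assumptions, not files shipped with the
# repository; substitute the checkpoint matching the chosen --model.
#
#   python extract_features.py --model LightCNN-29v2 \
#       --resume ./checkpoint.pth.tar \
#       --root_path ./aligned_faces/ \
#       --img_list ./img_list.txt \
#       --save_path ./features/
#
# Each .feat file written by save_feature() holds the raw bytes of one
# feature vector, so (assuming the default float32 model output) it can
# be read back with NumPy, e.g.:
#
#   import numpy as np
#   feat = np.fromfile('./features/some_face.feat', dtype=np.float32)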