-
Notifications
You must be signed in to change notification settings - Fork 15
/
main.py
96 lines (79 loc) · 3.55 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
""" Segment and generate Locus descriptor for each scan in a sequence. """
import sys
import os
import glob
import yaml
import argparse
from tqdm import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), 'segmentation'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'descriptor_generation'))
from utils.kitti_dataloader import *
from utils.augment_scans import *
from segmentation.extract_segments import *
from segmentation.extract_segment_features import *
from descriptor_generation.locus_descriptor import *
# Per-stage timers; their .avg values are reported after the main loop.
seg_timer, feat_timer, desc_timer = Timer(), Timer(), Timer()

# Load params
parser = argparse.ArgumentParser()
parser.add_argument("--seq", default='02', help="KITTI sequence number")
# Restrict to known values: any other string would previously skip
# augmentation silently while still tagging the output file name with it.
parser.add_argument("--aug_type", default='none', choices=['none', 'occ', 'rot', 'ds'],
                    help="Scan augmentation type ['occ', 'rot', 'ds']")
parser.add_argument("--aug_param", default=0, type=float, help="Scan augmentation parameter")
args = parser.parse_args()
print('Sequence: ', args.seq, ', Augmentation: ', args.aug_type, ', Param: ', args.aug_param)

# 'config.yml' is resolved relative to the current working directory.
# The context manager guarantees the file handle is closed after parsing.
with open('config.yml', 'r') as cfg_file:
    cfg_params = yaml.load(cfg_file, Loader=yaml.FullLoader)
desc_params = cfg_params['descriptor_generation']
seg_params = cfg_params['segmentation']
# Load data
basedir = cfg_params['paths']['KITTI_dataset']
# os.path.join works whether or not the configured dataset root ends in '/'.
sequence_path = os.path.join(basedir, 'sequences', args.seq)
velodyne_dir = os.path.join(sequence_path, 'velodyne')
bin_files = sorted(glob.glob(os.path.join(velodyne_dir, '*.bin')))
if not bin_files:
    raise FileNotFoundError(f"No .bin scans found in {velodyne_dir}")
scans = yield_bin_scans(bin_files)
transforms, _ = load_poses_from_txt(os.path.join(sequence_path, 'poses.txt'))
rel_transforms = get_delta_pose(transforms)

# Setup database variables
num_queries = len(rel_transforms)
# The main loop pulls one scan per query with next(scans). Assuming one scan
# per .bin file (TODO confirm against yield_bin_scans), fail here instead of
# with a bare StopIteration part-way through processing.
if len(bin_files) < num_queries:
    raise ValueError(f"Found {len(bin_files)} scans in {velodyne_dir} "
                     f"but {num_queries} queries from poses.txt")
segments_database, features_database = [], []
seg_corres_database, locus_descriptor_database = [], []
# The dict holds references to the same list objects, so appends/in-place
# edits made in the main loop are visible to get_locus_descriptor.
database_dict = {'segments_database': segments_database,
                 'features_database': features_database,
                 'seg_corres_database': seg_corres_database,
                 'rel_transforms': rel_transforms}
# Main loop: for each query index, load the next scan, optionally augment it,
# then run segmentation -> segment features -> Locus descriptor, appending
# each stage's output to its database list and timing each stage.
for query_idx in tqdm(range(num_queries)):
    # Load LiDAR scan point cloud
    scan = next(scans)
    # Drop the last column (presumably the per-point intensity channel of the
    # KITTI .bin format — confirm against yield_bin_scans).
    scan = scan[:, :-1]

    # Optional scan augmentation for robustness tests
    if args.aug_type == 'rot':
        scan, rot_mat = augmented_scan(scan, args.aug_type, args.aug_param)
        # Fold the applied rotation into this scan's pose (in place), presumably
        # so pose-based steps stay consistent with the rotated point cloud.
        transforms[query_idx][:3,:3] = np.dot(transforms[query_idx][:3,:3], rot_mat)
        if query_idx > 0:
            # Recompute the previous->current relative pose. database_dict
            # holds the same rel_transforms list, so the descriptor step below
            # sees the updated value.
            database_dict['rel_transforms'][query_idx-1] = get_delta_pose([transforms[query_idx-1], transforms[query_idx]])[0]
    elif args.aug_type == 'occ':
        scan = augmented_scan(scan, args.aug_type, args.aug_param)
    # NOTE(review): 'ds' is advertised in the --aug_type help text but has no
    # branch here, so a 'ds' run processes unaugmented scans while the output
    # file is still tagged '_ds<param>'. Confirm whether augmented_scan
    # supports 'ds' and add a branch for it.

    # Extract segments
    seg_timer.tic()
    segments = get_segments(scan, seg_params)
    segments_database.append(segments)
    seg_timer.toc()

    # Extract segment features
    feat_timer.tic()
    features = get_segment_features(segments)
    features_database.append(features)
    feat_timer.toc()

    # Generate 'Locus' global descriptor. It receives database_dict, which
    # references the segment/feature lists appended to above.
    desc_timer.tic()
    locus_descriptor = get_locus_descriptor(query_idx, desc_params, database_dict)
    locus_descriptor_database.append(locus_descriptor)
    desc_timer.toc()
print('Average time per scan:')
print(f"--- seg: {seg_timer.avg}s, feat: {feat_timer.avg}s, desc: {desc_timer.avg}s ---")

# Output: <save_dir>/<seq>/locus_descriptor_<fb_mode>[_<aug_type><int(aug_param)>].pickle
# os.path.join works whether or not the configured save_dir ends in '/'.
save_dir = os.path.join(cfg_params['paths']['save_dir'], args.seq)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(save_dir, exist_ok=True)
desc_file_name = 'locus_descriptor_' + desc_params['fb_mode']
if args.aug_type != 'none':
    # NOTE: aug_param is truncated to int here, so e.g. 0.5 and 0.9 both map to
    # '0' and overwrite each other. Left unchanged because other scripts
    # presumably look files up by this name — confirm before changing.
    desc_file_name = desc_file_name + '_' + args.aug_type + str(int(args.aug_param))
save_pickle(locus_descriptor_database, os.path.join(save_dir, desc_file_name + '.pickle'))