File name collision fix and minor extensions #125

Open · wants to merge 4 commits into master
46 changes: 39 additions & 7 deletions match_pairs.py
@@ -50,6 +50,7 @@
import numpy as np
import matplotlib.cm as cm
import torch
import pickle


from models.matching import Matching
@@ -62,6 +63,15 @@
torch.set_grad_enabled(False)


def pair_names_to_id(names):
    # Remove the file extension
    names = [str(Path(name).with_suffix('')) for name in names]
    # Replace path separators so the id is a valid file name
    names = [name.replace('/', '__') for name in names]
    # Concatenate both names into a single pair id
    return names[0] + '___' + names[1]


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Image pair matching and pose evaluation with SuperGlue',
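For readers skimming the diff, here is a quick sketch of what the new naming produces; the paths are made up, and the helper is copied from the hunk above. Note that the old stem-based naming would have written both pairs below to the same `img_000_img_001_*` files:

```python
from pathlib import Path

def pair_names_to_id(names):
    # Same logic as the helper added in this PR
    names = [str(Path(name).with_suffix('')) for name in names]
    names = [name.replace('/', '__') for name in names]
    return names[0] + '___' + names[1]

# Two hypothetical pairs whose stems collide under the old '<stem0>_<stem1>' naming
print(pair_names_to_id(('scene_a/img_000.jpg', 'scene_a/img_001.jpg')))
# scene_a__img_000___scene_a__img_001
print(pair_names_to_id(('scene_b/img_000.jpg', 'scene_b/img_001.jpg')))
# scene_b__img_000___scene_b__img_001
```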
@@ -77,6 +87,13 @@
'--output_dir', type=str, default='dump_match_pairs/',
help='Path to the directory in which the .npz results and optionally,'
'the visualization images are written')
parser.add_argument(
'--input_points', type=str, default=None,
help='Path to the directory in which .pkl files with optional custom keypoints '
'are stored. Each file should be named <pair_id>.pkl, using the same pair id '
'derived from the names in the input_pairs file as the files written to '
'output_dir. Each file should contain a pair of [n_points x 2] integer '
'tensors holding the coordinates of the custom keypoints.')

parser.add_argument(
'--max_length', type=int, default=-1,
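As an illustration of the expected --input_points layout, here is a minimal sketch for writing one such .pkl file. The directory name, pair names, and keypoint values are made up, and the (row, col) ordering is an assumption based on the torch.nonzero convention used inside SuperPoint:

```python
import pickle
from pathlib import Path

import torch

# Hypothetical pair id, i.e. pair_names_to_id(('scene_a/img_000.jpg', 'scene_a/img_001.jpg'))
pair_id = 'scene_a__img_000___scene_a__img_001'

# [n_points x 2] integer tensors with (row, col) keypoint coordinates
pts0 = torch.tensor([[120, 64], [300, 211], [45, 512]], dtype=torch.long)
pts1 = torch.tensor([[118, 70], [298, 215], [47, 509]], dtype=torch.long)

input_points_dir = Path('custom_points')  # value later passed via --input_points
input_points_dir.mkdir(exist_ok=True)
with open(input_points_dir / f'{pair_id}.pkl', 'wb') as f:
    pickle.dump((pts0, pts1), f)
```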
@@ -209,11 +226,11 @@
for i, pair in enumerate(pairs):
name0, name1 = pair[:2]
stem0, stem1 = Path(name0).stem, Path(name1).stem
matches_path = output_dir / '{}_{}_matches.npz'.format(stem0, stem1)
eval_path = output_dir / '{}_{}_evaluation.npz'.format(stem0, stem1)
viz_path = output_dir / '{}_{}_matches.{}'.format(stem0, stem1, opt.viz_extension)
viz_eval_path = output_dir / \
'{}_{}_evaluation.{}'.format(stem0, stem1, opt.viz_extension)
pair_id = pair_names_to_id((name0, name1))
matches_path = output_dir / f'{pair_id}_matches.npz'
eval_path = output_dir / f'{pair_id}_evaluation.npz'
viz_path = output_dir / f'{pair_id}_matches.{opt.viz_extension}'
viz_eval_path = output_dir / f'{pair_id}_evaluation.{opt.viz_extension}'

# Handle --cache logic.
do_match = True
@@ -269,17 +286,32 @@
exit(1)
timer.update('load_image')

# Load the optional custom points
if opt.input_points is not None:
input_points_dir = Path(opt.input_points)
with open(input_points_dir / f'{pair_id}.pkl', 'rb') as f:
pts0, pts1 = pickle.load(f)
else:
pts0, pts1 = None, None

if do_match:
# Perform the matching.
pred = matching({'image0': inp0, 'image1': inp1})
matching_args = {'image0': inp0, 'image1': inp1}
if pts0 is not None:
matching_args['points0'] = pts0
if pts1 is not None:
matching_args['points1'] = pts1
pred = matching(matching_args)
pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
matches, conf = pred['matches0'], pred['matching_scores0']
matches1, conf1 = pred['matches1'], pred['matching_scores1']
timer.update('matcher')

# Write the matches to disk.
out_matches = {'keypoints0': kpts0, 'keypoints1': kpts1,
'matches': matches, 'match_confidence': conf}
'matches0': matches, 'match_confidence0': conf,
'matches1': matches1, 'match_confidence1': conf1}
np.savez(str(matches_path), **out_matches)

# Keep the matching keypoints.
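Since the output keys change from 'matches'/'match_confidence' to the suffixed variants above, downstream consumers need a small adjustment. A sketch of reading one of the new .npz files (the file name is illustrative; the -1 convention for unmatched keypoints follows upstream SuperGlue):

```python
import numpy as np

npz = np.load('dump_match_pairs/scene_a__img_000___scene_a__img_001_matches.npz')
kpts0, kpts1 = npz['keypoints0'], npz['keypoints1']
matches0 = npz['matches0']        # for each keypoint in image0: index into kpts1, or -1
conf0 = npz['match_confidence0']

valid = matches0 > -1
mkpts0 = kpts0[valid]             # matched keypoints in image0
mkpts1 = kpts1[matches0[valid]]   # their counterparts in image1
```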
10 changes: 8 additions & 2 deletions models/matching.py
@@ -63,10 +63,16 @@ def forward(self, data):

# Extract SuperPoint (keypoints, scores, descriptors) if not provided
if 'keypoints0' not in data:
pred0 = self.superpoint({'image': data['image0']})
superpoint_args0 = {'image': data['image0']}
if 'points0' in data:
superpoint_args0['points'] = data['points0']
pred0 = self.superpoint(superpoint_args0)
pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
if 'keypoints1' not in data:
pred1 = self.superpoint({'image': data['image1']})
superpoint_args1 = {'image': data['image1']}
if 'points1' in data:
superpoint_args1['points'] = data['points1']
pred1 = self.superpoint(superpoint_args1)
pred = {**pred, **{k+'1': v for k, v in pred1.items()}}

# Batch all features
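A minimal sketch of exercising the new optional inputs directly through Matching, assuming the repository and pretrained weights are available; the image sizes and keypoint values are placeholders:

```python
import torch
from models.matching import Matching

torch.set_grad_enabled(False)
matching = Matching({}).eval()

inp0 = torch.rand(1, 1, 480, 640)   # 1 x 1 x H x W grayscale tensors
inp1 = torch.rand(1, 1, 480, 640)

# Custom keypoints as [n_points x 2] integer (row, col) tensors
pts0 = torch.tensor([[120, 64], [300, 211], [45, 512]])
pts1 = torch.tensor([[118, 70], [298, 215], [47, 509]])

pred = matching({'image0': inp0, 'image1': inp1,
                 'points0': pts0, 'points1': pts1})
kpts0 = pred['keypoints0'][0]       # the custom points, as returned by SuperPoint
```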
19 changes: 15 additions & 4 deletions models/superpoint.py
@@ -144,8 +144,15 @@ def __init__(self, config):

def forward(self, data):
""" Compute keypoints, scores, descriptors for image """
# Pad the image so that height and width are multiples of 8
x = data['image']
original_shape = x.shape
h_pad = (8 - original_shape[2] % 8) % 8
w_pad = (8 - original_shape[3] % 8) % 8
x = nn.functional.pad(x, (0, w_pad, 0, h_pad))

# Shared Encoder
x = self.relu(self.conv1a(data['image']))
x = self.relu(self.conv1a(x))
x = self.relu(self.conv1b(x))
x = self.pool(x)
x = self.relu(self.conv2a(x))
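A quick standalone check of the pad-to-multiple-of-8 step above, using an arbitrary input size:

```python
import torch
from torch import nn

x = torch.rand(1, 1, 479, 641)                        # H, W not divisible by 8
h_pad = (8 - x.shape[2] % 8) % 8                      # 1
w_pad = (8 - x.shape[3] % 8) % 8                      # 7
x = nn.functional.pad(x, (0, w_pad, 0, h_pad))        # pad right and bottom only
assert x.shape[2] % 8 == 0 and x.shape[3] % 8 == 0    # padded to 480 x 648
```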
@@ -167,9 +174,13 @@ def forward(self, data):
scores = simple_nms(scores, self.config['nms_radius'])

# Extract keypoints
keypoints = [
torch.nonzero(s > self.config['keypoint_threshold'])
for s in scores]
if 'points' not in data:
keypoints = [
torch.nonzero(s > self.config['keypoint_threshold'])
for s in scores]
else:
keypoints = [data['points'].long()]

scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

# Discard keypoints near the image borders
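For reference, the score lookup above expects keypoints as (row, col) coordinates, matching what torch.nonzero yields on the score map; a tiny sketch with made-up values:

```python
import torch

scores = torch.rand(1, 480, 640)                                 # batched H x W score maps
keypoints = [torch.tensor([[120, 64], [300, 211], [45, 512]])]   # (row, col) per image
per_kpt_scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
print(per_kpt_scores[0].shape)                                   # torch.Size([3])
```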