-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeometry.py
175 lines (140 loc) · 6.77 KB
/
geometry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import torch
@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
    """Warp kpts0 from I0 to I1 using depth, intrinsics K and relative pose Rt.

    Also checks covisibility and depth consistency: a warped keypoint is
    valid only if its source depth is non-zero, it lands strictly inside
    I1, and its reprojected depth agrees with depth1 within a relative
    error of 0.2 (hard-coded).

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y> pixel coordinates in I0.
        depth0 (torch.Tensor): [N, H, W] depth map of I0.
        depth1 (torch.Tensor): [N, H, W] depth map of I1.
        T_0to1 (torch.Tensor): [N, 4, 4] rigid transform cam0 -> cam1.
        K0 (torch.Tensor): [N, 3, 3] intrinsics of I0.
        K1 (torch.Tensor): [N, 3, 3] intrinsics of I1.
    Returns:
        valid_mask (torch.Tensor): [N, L] bool mask of valid warps.
        w_kpts0 (torch.Tensor): [N, L, 2] warped keypoints <x0_hat, y1_hat> in I1.
    """
    kpts0_long = kpts0.round().long()

    # Sample source depth at the rounded keypoint locations; zero depth
    # means "no measurement" and the point cannot be warped.
    kpts0_depth = torch.stack(
        [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
    )  # (N, L)
    nonzero_mask = kpts0_depth != 0

    # Unproject pixels to cam0 coordinates.
    kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid transform into cam1.
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project into I1; +1e-4 avoids division by zero depth.
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4)  # (N, L, 2)

    # Covisibility check: warped points must land strictly inside I1.
    h, w = depth1.shape[1:3]
    covisible_mask = (w_kpts0[:, :, 0] > 0) & (w_kpts0[:, :, 0] < w - 1) & \
        (w_kpts0[:, :, 1] > 0) & (w_kpts0[:, :, 1] < h - 1)
    w_kpts0_long = w_kpts0.long()
    # Redirect out-of-bounds points to (0, 0) so the gather below stays in range.
    w_kpts0_long[~covisible_mask, :] = 0

    # Depth-consistency check against I1's depth map. Where w_kpts0_depth
    # is 0 the division yields inf/nan, which compares False — masked out.
    w_kpts0_depth = torch.stack(
        [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
    )  # (N, L)
    consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2

    valid_mask = nonzero_mask & covisible_mask & consistent_mask
    return valid_mask, w_kpts0
@torch.no_grad()
def warp_kpts_chd(kpts0, depth0, depth1, height_map0, height_map_info0, T0, T1, K0, K1):
    """Warp kpts0 from I0 to I1 with depth, K and Rt, compensating for
    height differences using a ground height map of the real crop.

    Also checks covisibility and depth consistency: a warped keypoint is
    valid only if its source depth and queried height are non-zero, it
    lands strictly inside I1, and its reprojected depth agrees with
    depth1 within a relative error of 0.2 (hard-coded).

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y> pixel coordinates in I0.
        depth0 (torch.Tensor): [N, H, W] depth map of I0.
        depth1 (torch.Tensor): [N, H, W] depth map of I1.
        height_map0 (torch.Tensor): [N, H, W] ground height map.
        height_map_info0 (torch.Tensor): [N, 5] - [cell_size, x_min, y_min, x_max, y_max].
        T0 (torch.Tensor): [N, 4, 4] cam0 -> ground transform.
        T1 (torch.Tensor): [N, 4, 4] cam1 -> ground transform.
        K0 (torch.Tensor): [N, 3, 3] intrinsics of I0.
        K1 (torch.Tensor): [N, 3, 3] intrinsics of I1.
    Returns:
        valid_mask (torch.Tensor): [N, L] bool mask of valid warps.
        w_kpts0 (torch.Tensor): [N, L, 2] warped keypoints <x0_hat, y1_hat> in I1.
    """
    cell_size = height_map_info0[:, 0].reshape(-1, 1, 1)
    x_min = height_map_info0[:, 1].reshape(-1, 1, 1)
    y_min = height_map_info0[:, 2].reshape(-1, 1, 1)
    xy_min = torch.cat([x_min, y_min], dim=-1)
    kpts0_long = kpts0.round().long()

    # Sample source depth at the rounded keypoint locations; zero depth
    # means "no measurement" and the point cannot be warped.
    kpts0_depth = torch.stack(
        [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
    )  # (N, L)
    nonzero_mask = kpts0_depth != 0

    # Unproject pixels to cam0 coordinates.
    kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Transform to the ground frame and take the planar x, y coordinates.
    w_kpts0_ground = T0[:, :3, :3] @ kpts0_cam + T0[:, :3, [3]]  # (N, 3, L)
    w_kpts0_ground_xy = w_kpts0_ground[:, :2, :].transpose(1, 2)  # (N, L, 2)

    # Convert ground coordinates to height-map cell indices and clamp
    # them into the valid grid range.
    kpts0_height_map_indices = (
        (w_kpts0_ground_xy - xy_min) / cell_size
    ).to(torch.int64)
    num_y, num_x = height_map0.shape[1], height_map0.shape[2]
    clipped_x_indices = torch.clamp(kpts0_height_map_indices[:, :, 0], 0, num_x - 1)
    clipped_y_indices = torch.clamp(kpts0_height_map_indices[:, :, 1], 0, num_y - 1)

    # Query the height map as [batch, row(y), col(x)], matching the
    # [N, H, W] layout and the clamp bounds above. (The previous code
    # indexed [batch, x, y], which contradicts the clamps and can go out
    # of bounds on non-square maps.)
    batch_indices = torch.arange(kpts0.shape[0], device=kpts0.device)
    kpts0_height_map = height_map0[
        batch_indices[:, None],
        clipped_y_indices,
        clipped_x_indices
    ]  # (N, L)
    height_nonzero_mask = kpts0_height_map != 0
    nonzero_mask = nonzero_mask & height_nonzero_mask

    # Replace z with the height-map value (height-difference compensation).
    w_kpts0_ground[:, 2, :] = kpts0_height_map

    # Transform the adjusted ground points into cam1.
    T1_inv = T1.inverse()
    w_kpts0_cam = T1_inv[:, :3, :3] @ w_kpts0_ground + T1_inv[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project into I1; +1e-4 avoids division by zero depth.
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4)  # (N, L, 2)
    # If height is invalid or source depth is 0, warp the point to the
    # top-left corner so it fails the covisibility check below.
    w_kpts0[~nonzero_mask] = 0

    # Covisibility check: warped points must land strictly inside I1.
    h, w = depth1.shape[1:3]
    covisible_mask = (w_kpts0[:, :, 0] > 0) & (w_kpts0[:, :, 0] < w - 1) & \
        (w_kpts0[:, :, 1] > 0) & (w_kpts0[:, :, 1] < h - 1)
    w_kpts0_long = w_kpts0.long()
    # Redirect out-of-bounds points to (0, 0) so the gather below stays in range.
    w_kpts0_long[~covisible_mask, :] = 0

    # Depth-consistency check against I1's depth map. Where w_kpts0_depth
    # is 0 the division yields inf/nan, which compares False — masked out.
    w_kpts0_depth = torch.stack(
        [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
    )  # (N, L)
    consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2

    valid_mask = nonzero_mask & covisible_mask & consistent_mask
    return valid_mask, w_kpts0