Skip to content

Commit

Permalink
update int searches with queryStride to allow different spatial sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Dec 4, 2023
1 parent b04ed45 commit ffc05a8
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 121 deletions.
22 changes: 11 additions & 11 deletions lib/csrc/search/non_local_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ void non_local_search_int_forward_cuda(
const torch::Tensor vid0, const torch::Tensor vid1,
const torch::Tensor flows,
torch::Tensor dists, torch::Tensor inds,
int ps, int k, int stride0, int stride1, int dilation, int pt,
bool reflect_bounds, bool full_ws, int patch_offset,
int off_Hq, int off_Wq, int dist_type);
int ps, int k, int stride0, int stride1, int strideQ,
int dilation, int pt, bool reflect_bounds, bool full_ws,
int patch_offset, int off_Hq, int off_Wq, int dist_type);

void non_local_search_bilin2d_forward_cuda(
const torch::Tensor vid0, const torch::Tensor vid1,
Expand All @@ -25,7 +25,7 @@ void non_local_search_int_vid_backward_cuda(
torch::Tensor grad_vid0, torch::Tensor grad_vid1,
const torch::Tensor vid0, const torch::Tensor vid1,
const torch::Tensor grad_dists, const torch::Tensor inds,
int ps, int pt, int stride0, int dilation,
int ps, int pt, int stride0, int strideQ, int dilation,
bool reflect_bounds, int patch_offset,
int off_Hq, int off_Wq, int dist_type);

Expand Down Expand Up @@ -57,7 +57,7 @@ void non_local_search_bilin2d_vidflows_backward_cuda(
void non_local_search_int_forward(
const torch::Tensor vid0, const torch::Tensor vid1,
const torch::Tensor flows, torch::Tensor dists, torch::Tensor inds,
int ps, int k, int stride0, int stride1, int dilation, int pt,
int ps, int k, int stride0, int stride1, int strideQ, int dilation, int pt,
bool reflect_bounds, bool full_ws, int patch_offset,
int off_Hq, int off_Wq, int dist_type){
CHECK_INPUT(vid0);
Expand All @@ -66,9 +66,9 @@ void non_local_search_int_forward(
CHECK_INPUT(dists);
CHECK_INPUT(inds);
non_local_search_int_forward_cuda(vid0, vid1, flows, dists, inds,
ps, k, stride0, stride1, dilation, pt,
reflect_bounds, full_ws, patch_offset,
off_Hq, off_Wq, dist_type);
ps, k, stride0, stride1, strideQ,
dilation, pt, reflect_bounds, full_ws,
patch_offset, off_Hq, off_Wq, dist_type);
}

void non_local_search_bilin2d_forward(
Expand All @@ -92,18 +92,18 @@ void non_local_search_int_vid_backward(
torch::Tensor grad_vid0, torch::Tensor grad_vid1,
const torch::Tensor vid0, const torch::Tensor vid1,
const torch::Tensor grad_dists, const torch::Tensor inds,
int ps, int pt, int stride0, int dilation,
int ps, int pt, int stride0, int strideQ, int dilation,
bool reflect_bounds, int patch_offset,
int off_Hq, int off_Wq, int dist_type) {

CHECK_INPUT(grad_vid0);
CHECK_INPUT(grad_vid1);
CHECK_INPUT(vid0);
CHECK_INPUT(vid1);
CHECK_INPUT(grad_dists);
CHECK_INPUT(inds);
non_local_search_int_vid_backward_cuda(grad_vid0, grad_vid1, vid0, vid1,
grad_dists, inds, ps, pt, stride0, dilation,
grad_dists, inds, ps, pt,
stride0, strideQ, dilation,
reflect_bounds, patch_offset,
off_Hq, off_Wq, dist_type);

Expand Down
71 changes: 38 additions & 33 deletions lib/csrc/search/non_local_search_int_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ __global__ void non_local_search_int_forward_kernel(
const torch::PackedTensorAccessor32<int,7,torch::RestrictPtrTraits> flows,
torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> dists,
torch::PackedTensorAccessor32<int,7,torch::RestrictPtrTraits> inds,
int ws, int wt, int ps, int pt, int stride0, int stride1, int dilation,
int ws, int wt, int ps, int pt,
int stride0, int stride1, int strideQ, int dilation,
bool reflect_bounds, bool full_ws, int patch_offset,
int nH0, int nW0, int nHW0, int st_offset,
int nH, int nW, int nHW, int st_offset,
int off_Hq, int off_Wq, int q_per_thread, int ws_per_thread, int wt_per_thread){

// -- unpack shape --
Expand Down Expand Up @@ -62,6 +63,7 @@ __global__ void non_local_search_int_forward_kernel(

// -- decls --
int ref_patch[3];
int adj_patch[3];
int prop_patch[3];
int frame_anchor[2];
int ref_pix[3];
Expand Down Expand Up @@ -92,9 +94,10 @@ __global__ void non_local_search_int_forward_kernel(
if (qi >= Q){ continue; }

// -- pixel location from query index --
get_pixel_loc(ref_patch,qi,stride0,nW0,nHW0,qH,qW);
n_hi = ref_patch[1] / stride0;
n_wi = ref_patch[2] / stride0;
get_pixel_loc(ref_patch,qi,strideQ,nW,nHW,qH,qW);
get_pixel_loc(adj_patch,qi,stride0,nW,nHW,kH,kW);
n_hi = ref_patch[1] / strideQ;
n_wi = ref_patch[2] / strideQ;

// -- check bounds of pixel location --
// check_bounds(valid_ref_patch,ref_patch,T,H,W);
Expand All @@ -113,19 +116,19 @@ __global__ void non_local_search_int_forward_kernel(
// ---------------------------------------

// -- select time --
prop_patch[0] = ref_patch[0] + st_i; // t_next
prop_patch[0] = adj_patch[0] + st_i; // t_next
prop_patch[0] = (prop_patch[0] > t_max) ? t_max - st_i : prop_patch[0];

// -- offset with flows --
if (st_i >= st_offset){
auto flows_t = flows[ibatch][ihead_f][ref_patch[0]][st_i-st_offset];
frame_anchor[0] = ref_patch[1] + flows_t[1][n_hi][n_wi];
frame_anchor[1] = ref_patch[2] + flows_t[0][n_hi][n_wi];
auto flows_t = flows[ibatch][ihead_f][adj_patch[0]][st_i-st_offset];
frame_anchor[0] = adj_patch[1] + flows_t[1][n_hi][n_wi];
frame_anchor[1] = adj_patch[2] + flows_t[0][n_hi][n_wi];
frame_anchor[0] = bounds(frame_anchor[0],kH);
frame_anchor[1] = bounds(frame_anchor[1],kW);
}else{
frame_anchor[0] = ref_patch[1];
frame_anchor[1] = ref_patch[2];
frame_anchor[0] = adj_patch[1];
frame_anchor[1] = adj_patch[2];
}

// -- search region offsets --
Expand Down Expand Up @@ -167,9 +170,9 @@ __global__ void non_local_search_int_forward_kernel(
// -- assignent --
if (!valid){ dist = invalid; }
dists[ibatch][ihead][qi][st_i][ws_i][ws_j] = dist;
inds[ibatch][ihead][qi][st_i][ws_i][ws_j][0] = prop_patch[0] - ref_patch[0];
inds[ibatch][ihead][qi][st_i][ws_i][ws_j][1] = prop_patch[1] - ref_patch[1];
inds[ibatch][ihead][qi][st_i][ws_i][ws_j][2] = prop_patch[2] - ref_patch[2];
inds[ibatch][ihead][qi][st_i][ws_i][ws_j][0] = prop_patch[0] - adj_patch[0];
inds[ibatch][ihead][qi][st_i][ws_i][ws_j][1] = prop_patch[1] - adj_patch[1];
inds[ibatch][ihead][qi][st_i][ws_i][ws_j][2] = prop_patch[2] - adj_patch[2];

}
}
Expand All @@ -181,9 +184,9 @@ void non_local_search_int_forward_cuda(
const torch::Tensor vid0, const torch::Tensor vid1,
const torch::Tensor flows,
torch::Tensor dists, torch::Tensor inds,
int ps, int k, int stride0, int stride1, int dilation, int pt,
bool reflect_bounds, bool full_ws, int patch_offset,
int off_Hq, int off_Wq, int dist_type){
int ps, int k, int stride0, int stride1, int strideQ,
int dilation, int pt, bool reflect_bounds, bool full_ws,
int patch_offset, int off_Hq, int off_Wq, int dist_type){

// -- derived quantities --
int kH = vid1.size(4);
Expand Down Expand Up @@ -232,7 +235,7 @@ void non_local_search_int_forward_cuda(
flows.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
dists.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
ws, wt, ps, pt, stride0, stride1, dilation,
ws, wt, ps, pt, stride0, stride1, strideQ, dilation,
reflect_bounds, full_ws, patch_offset, nH, nW, nHW,
st_offset, off_Hq, off_Wq, q_per_thread, ws_per_thread, wt_per_thread);
}));
Expand All @@ -245,7 +248,7 @@ void non_local_search_int_forward_cuda(
flows.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
dists.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
ws, wt, ps, pt, stride0, stride1, dilation,
ws, wt, ps, pt, stride0, stride1, strideQ, dilation,
reflect_bounds, full_ws, patch_offset, nH, nW, nHW,
st_offset, off_Hq, off_Wq, q_per_thread, ws_per_thread, wt_per_thread);
}));
Expand All @@ -270,7 +273,7 @@ __global__ void non_local_search_int_vid_backward_kernel(
const torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> vid1,
const torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> grad_dists,
const torch::PackedTensorAccessor32<int,7,torch::RestrictPtrTraits> inds,
int ps, int pt, int stride0, int dilation, bool reflect_bounds,
int ps, int pt, int stride0, int strideQ, int dilation, bool reflect_bounds,
int patch_offset, int off_Hq, int off_Wq, int ftrs_per_thread) {

// -- shape --
Expand All @@ -284,14 +287,15 @@ __global__ void non_local_search_int_vid_backward_kernel(
int qW = vid0.size(5);
int kH = vid1.size(4);
int kW = vid1.size(5);
int nH0 = inds.size(3);
int nW0 = inds.size(4);
int nHW0 = nH0*nW0;
int nH = inds.size(3);
int nW = inds.size(4);
int nHW = nH*nW;
int K = inds.size(5);
int Q = T*nH0*nW0;
int Q = T*nH*nW;

// -- fwd decl registers --
int ref_patch[3];
int adj_patch[3];
int prop_patch[3];
int ref[3];
int prop[3];
Expand Down Expand Up @@ -320,16 +324,17 @@ __global__ void non_local_search_int_vid_backward_kernel(
if ((qi < Q) && (ki < K)){

// -- pixel location from query index --
get_pixel_loc(ref_patch,qi,stride0,nW0,nHW0,qH,qW);
get_pixel_loc(ref_patch,qi,strideQ,nW,nHW,qH,qW);
get_pixel_loc(adj_patch,qi,stride0,nW,nHW,kH,kW);
int ti = ref_patch[0];
int nh = ref_patch[1]/stride0;
int nw = ref_patch[2]/stride0;
int nh = ref_patch[1]/strideQ;
int nw = ref_patch[2]/strideQ;

// -- proposed location --
weight = grad_dists[ibatch][ihead][ti][nh][nw][ki];
prop_patch[0] = ref_patch[0] + inds[ibatch][ihead][ti][nh][nw][ki][0];
prop_patch[1] = ref_patch[1] + inds[ibatch][ihead][ti][nh][nw][ki][1];
prop_patch[2] = ref_patch[2] + inds[ibatch][ihead][ti][nh][nw][ki][2];
prop_patch[0] = adj_patch[0] + inds[ibatch][ihead][ti][nh][nw][ki][0];
prop_patch[1] = adj_patch[1] + inds[ibatch][ihead][ti][nh][nw][ki][1];
prop_patch[2] = adj_patch[2] + inds[ibatch][ihead][ti][nh][nw][ki][2];

// -- update patch --
update_bwd_patch_int<scalar_t,DIST_TYPE>(
Expand All @@ -348,7 +353,7 @@ void non_local_search_int_vid_backward_cuda(
torch::Tensor grad_vid0, torch::Tensor grad_vid1,
const torch::Tensor vid0, const torch::Tensor vid1,
const torch::Tensor grad_dists, const torch::Tensor inds,
int ps, int pt, int stride0, int dilation,
int ps, int pt, int stride0, int strideQ, int dilation,
bool reflect_bounds, int patch_offset,
int off_Hq, int off_Wq, int dist_type) {

Expand Down Expand Up @@ -406,7 +411,7 @@ void non_local_search_int_vid_backward_cuda(
vid1.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
grad_dists.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
ps, pt, stride0, dilation, reflect_bounds, patch_offset,
ps, pt, stride0, strideQ, dilation, reflect_bounds, patch_offset,
off_Hq, off_Wq, ftrs_per_thread);
}));
}else if (dist_type == 1){ // l2
Expand All @@ -419,7 +424,7 @@ void non_local_search_int_vid_backward_cuda(
vid1.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
grad_dists.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
ps, pt, stride0, dilation, reflect_bounds, patch_offset,
ps, pt, stride0, strideQ, dilation, reflect_bounds, patch_offset,
off_Hq, off_Wq, ftrs_per_thread);
}));
}else{
Expand Down
6 changes: 3 additions & 3 deletions lib/csrc/search/refinement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
void refinement_int_forward_cuda(
const torch::Tensor vid0, const torch::Tensor vid1, const torch::Tensor flows,
torch::Tensor dists, torch::Tensor inds,
int ws, int ps, int stride0, int stride1, int dilation, int pt,
int ws, int ps, int stride0, int stride1, int strideQ, int dilation, int pt,
bool restrict_radius, bool reflect_bounds, bool full_ws,
int patch_offset, int off_Hq, int off_Wq, int dist_type);

Expand Down Expand Up @@ -39,7 +39,7 @@ void refinement_bilin2d_vidflows_backward_cuda(
void refinement_int_forward(
const torch::Tensor vid0, const torch::Tensor vid1,
const torch::Tensor flows, torch::Tensor dists, torch::Tensor inds,
int ws, int ps, int stride0, int stride1, int dilation, int pt,
int ws, int ps, int stride0, int stride1, int strideQ, int dilation, int pt,
bool restrict_radius, bool reflect_bounds, bool full_ws,
int patch_offset, int off_Hq, int off_Wq, int dist_type){
CHECK_INPUT(vid0);
Expand All @@ -48,7 +48,7 @@ void refinement_int_forward(
CHECK_INPUT(dists);
CHECK_INPUT(inds);
refinement_int_forward_cuda(vid0, vid1, flows, dists, inds,
ws, ps, stride0, stride1, dilation, pt,
ws, ps, stride0, stride1, strideQ, dilation, pt,
restrict_radius, reflect_bounds, full_ws,
patch_offset, off_Hq, off_Wq, dist_type);
}
Expand Down
38 changes: 22 additions & 16 deletions lib/csrc/search/refinement_int_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ __global__ void refinement_forward_kernel(
const torch::PackedTensorAccessor32<int,7,torch::RestrictPtrTraits> flows,
torch::PackedTensorAccessor32<scalar_t,8,torch::RestrictPtrTraits> dists,
torch::PackedTensorAccessor32<int,9,torch::RestrictPtrTraits> inds,
int wr, int ws, int ps, int pt, int stride0, int stride1, int dilation,
int wr, int ws, int ps, int pt,
int stride0, int stride1, int strideQ, int dilation,
bool reflect_bounds, bool full_ws, bool restrict_radius, int patch_offset,
int off_Hq, int off_Wq, int q_per_thread, int k_per_thread, int wr_per_thread){

Expand Down Expand Up @@ -68,6 +69,7 @@ __global__ void refinement_forward_kernel(
int prop_patch[3];
int prop_pix[3];
int ref_patch[3];
int adj_patch[3];
int ref_pix[3];
bool valid;
bool valid_prop[4];
Expand All @@ -85,10 +87,11 @@ __global__ void refinement_forward_kernel(
if (qi >= Q){ continue; }

// -- pixel location from query index --
get_pixel_loc(ref_patch,qi,stride0,nW,nHW,qH,qW);
ti = ref_patch[0];
nh = ref_patch[1]/stride0;
nw = ref_patch[2]/stride0;
get_pixel_loc(ref_patch,qi,strideQ,nW,nHW,qH,qW);
get_pixel_loc(adj_patch,qi,stride0,nW,nHW,kH,kW);
ti = adj_patch[0];
nh = adj_patch[1]/stride0;
nw = adj_patch[2]/stride0;

// -- check bounds of pixel location --
// check_bounds(valid_ref[3],ref_patch,T,H,W);
Expand All @@ -101,9 +104,9 @@ __global__ void refinement_forward_kernel(
if (ki >= Ks){ continue; }

// -- unpack base --
prop_patch[0] = ref_patch[0] + flows[ibatch][ihead_f][ti][nh][nw][ki][0];
prop_center[0] = ref_patch[1] + flows[ibatch][ihead_f][ti][nh][nw][ki][1];
prop_center[1] = ref_patch[2] + flows[ibatch][ihead_f][ti][nh][nw][ki][2];
prop_patch[0] = adj_patch[0] + flows[ibatch][ihead_f][ti][nh][nw][ki][0];
prop_center[0] = adj_patch[1] + flows[ibatch][ihead_f][ti][nh][nw][ki][1];
prop_center[1] = adj_patch[2] + flows[ibatch][ihead_f][ti][nh][nw][ki][2];
prop_patch[0] = bounds(prop_patch[0],T);
prop_center[0] = bounds(prop_center[0],kH);
prop_center[1] = bounds(prop_center[1],kW);
Expand Down Expand Up @@ -164,9 +167,9 @@ __global__ void refinement_forward_kernel(
// -- assignent --
if (!valid){ dist = invalid; }
dists[ibatch][ihead][ti][nh][nw][ki][wh][ww] = dist;
inds[ibatch][ihead][ti][nh][nw][ki][wh][ww][0] = prop_patch[0]-ref_patch[0];
inds[ibatch][ihead][ti][nh][nw][ki][wh][ww][1] = prop_patch[1]-ref_patch[1];
inds[ibatch][ihead][ti][nh][nw][ki][wh][ww][2] = prop_patch[2]-ref_patch[2];
inds[ibatch][ihead][ti][nh][nw][ki][wh][ww][0] = prop_patch[0]-adj_patch[0];
inds[ibatch][ihead][ti][nh][nw][ki][wh][ww][1] = prop_patch[1]-adj_patch[1];
inds[ibatch][ihead][ti][nh][nw][ki][wh][ww][2] = prop_patch[2]-adj_patch[2];

} // ww
} // wh
Expand All @@ -177,9 +180,10 @@ __global__ void refinement_forward_kernel(
void refinement_int_forward_cuda(
const torch::Tensor vid0, const torch::Tensor vid1,
const torch::Tensor flows, torch::Tensor dists, torch::Tensor inds,
int ws, int ps, int stride0, int stride1, int dilation, int pt,
bool restrict_radius, bool reflect_bounds, bool full_ws,
int patch_offset, int off_Hq, int off_Wq, int dist_type){
int ws, int ps, int stride0, int stride1, int strideQ,
int dilation, int pt, bool restrict_radius, bool reflect_bounds,
bool full_ws, int patch_offset, int off_Hq, int off_Wq, int dist_type){


// dists.shape = (B,HD,T,nH,nW,K,wr,wr)
// inds.shape = (B,HD,T,nH,nW,K,wr,wr,3)
Expand Down Expand Up @@ -214,6 +218,8 @@ void refinement_int_forward_cuda(
// q_per_thread = ((Q - 1) / nquery_blocks) + 1;
// dim3 nblocks(nquery_blocks,B,HD);

// int strideQ = ps;

// launch kernel
if (dist_type == 0){
AT_DISPATCH_FLOATING_TYPES(vid0.type(),"refinement_forward_kernel", ([&] {
Expand All @@ -223,7 +229,7 @@ void refinement_int_forward_cuda(
flows.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
dists.packed_accessor32<scalar_t,8,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,9,torch::RestrictPtrTraits>(),
wr, ws, ps, pt, stride0, stride1, dilation,
wr, ws, ps, pt, stride0, stride1, strideQ, dilation,
restrict_radius, reflect_bounds, full_ws, patch_offset,
off_Hq, off_Wq, q_per_thread, k_per_thread, wr_per_thread);
}));
Expand All @@ -235,7 +241,7 @@ void refinement_int_forward_cuda(
flows.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
dists.packed_accessor32<scalar_t,8,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,9,torch::RestrictPtrTraits>(),
wr, ws, ps, pt, stride0, stride1, dilation,
wr, ws, ps, pt, stride0, stride1, strideQ, dilation,
restrict_radius, reflect_bounds, full_ws, patch_offset,
off_Hq, off_Wq, q_per_thread, k_per_thread, wr_per_thread);
}));
Expand Down
5 changes: 2 additions & 3 deletions lib/stnls/agg/scatter_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def run(flows,flows_k,ws,wt,stride0,stride1,H,W,full_ws):

# -- check --
nvalid = (names[...,0] >= 0).float().sum(2)
# if full_ws:
# print(int(nvalid.sum().item()),Q*K)
# # assert(int(nvalid.sum().item()) == Q*K)
if full_ws:
assert(int(nvalid.sum().item()) == Q*K)

return names,labels

Expand Down
2 changes: 1 addition & 1 deletion lib/stnls/agg/scatter_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def run(tensor,flows_k,labels,stride0,stride1,H,W,invalid=th.inf):
# -- unpack shapes --
B,HD,T,nH0,nW0,K = tensor.shape[:6]
Q0 = T*nH0*nW0
S = labels.max().int()+1
S = labels.max().long().item()+1
tensor = tensor.reshape(B,HD,Q0,K,-1)
M = tensor.shape[-1]
nH1 = (H-1)//stride1+1
Expand Down
Loading

0 comments on commit ffc05a8

Please sign in to comment.