Skip to content

Commit

Permalink
restructure wpsum_int launching
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Dec 3, 2023
1 parent d14ac08 commit d3b0553
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 31 deletions.
15 changes: 15 additions & 0 deletions lib/csrc/agg/wpsum_bilin2d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@ void wpsum_bilin2d_forward_cuda(
int Q = inds.size(2);
int q_per_thread = 1;

// // -- kernel threads --
// int MAX_THREADS = 512;//1024
// int k_threads = 8;
// int q_threads = MAX_THREADS/(k_threads); // num of queries threads per block
// q_threads = min(Q,q_threads);
// int q_blocks = (Q-1)/(q_per_thread*q_threads)+1;
// int k_blocks = (K-1)/(k_threads)+1;
// dim3 nthreads(q_threads,k_threads);
// // fprintf(stdout,
// // "ps,pt,stride0,reflect_bounds,dilation,patch_offset: %d,%d,%d,%d,%d,%d\n",
// // ps,pt,stride0,reflect_bounds,dilation,patch_offset);

// // -- kernel blocks --
// dim3 nblocks(q_blocks,k_blocks,B*HD);

// -- kernel threads --
int MAX_THREADS = 512;//1024
int q_threads = MAX_THREADS/(ps*ps); // num of queries threads per block
Expand Down
82 changes: 53 additions & 29 deletions lib/csrc/agg/wpsum_int_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,17 @@ __global__ void wpsum_int_forward_kernel(
int K = inds.size(3);

// -- batching --
int query_start = (threadIdx.x + blockDim.x*blockIdx.x)*q_per_thread;
int ibatch = blockIdx.y;
int ihead = blockIdx.z;

// -- cuda threads --
int pi = threadIdx.y;
int pj = threadIdx.z;
// int query_start = (threadIdx.x + blockDim.x*blockIdx.x)*q_per_thread;
int query_start = blockIdx.x*blockDim.x+threadIdx.x;
int ki = blockIdx.y*blockDim.y+threadIdx.y;
int ihead = blockIdx.z/B;
int ibatch = (blockIdx.z-ihead*B) % B;
// int ibatch = blockIdx.y;
// int ihead = blockIdx.z;

// // -- cuda threads --
// int pi = threadIdx.y;
// int pj = threadIdx.z;

// -- pixel locations --
int qi;
Expand All @@ -58,22 +62,26 @@ __global__ void wpsum_int_forward_kernel(
if (qi >= Q){ continue; }
get_pixel_loc<int>(ref,qi,stride0,nW,nHW,H,W);

// -- reference pixel index --
ref_p[0] = ref[0];
ref_p[1] = ref[1]+dilation*(pi + patch_offset);
ref_p[2] = ref[2]+dilation*(pj + patch_offset);

// -- valid ref pixel only --
check_bounds(valid, ref_p, T, H, W);
if (not valid){ continue; }

// -- normalize --
if ((ref[0]==0) and (ibatch==0) and (ihead==0)){
atomicAdd(&counts[ref_p[1]][ref_p[2]],1);
}

for(int ki = 0; ki < K; ki++){
// -- iterate over patches --
for(int pi=0; pi < ps; pi++){
for(int pj=0; pj < ps; pj++){

// -- reference pixel index --
ref_p[0] = ref[0];
ref_p[1] = ref[1]+dilation*(pi + patch_offset);
ref_p[2] = ref[2]+dilation*(pj + patch_offset);

// -- valid ref pixel only --
check_bounds(valid, ref_p, T, H, W);
if (not valid){ continue; }

// -- normalize --
if ((ref[0]==0) and (ibatch==0) and (ihead==0) and (ki==0)){
atomicAdd(&counts[ref_p[1]][ref_p[2]],1);
}


// -- non-local index --
#pragma unroll
for (int _idx=0; _idx < 3; _idx++){
Expand Down Expand Up @@ -118,8 +126,8 @@ __global__ void wpsum_int_forward_kernel(

} // nfeatures-loop
} // pt-loop
} // k-loop
} // query-loop
}} // pi,pj
} // query-loop
}

void wpsum_int_forward_cuda(
Expand All @@ -133,18 +141,34 @@ void wpsum_int_forward_cuda(
int B = inds.size(0);
int HD = inds.size(1);
int Q = inds.size(2);
int K = inds.size(3);
int q_per_thread = 2;

// -- kernel threads --
int MAX_THREADS = 1024;
int q_threads = MAX_THREADS/(ps*ps); // num of queries threads per block
int MAX_THREADS = 512;//1024
int k_threads = 8;
int q_threads = MAX_THREADS/(k_threads); // num of queries threads per block
q_threads = min(Q,q_threads);
int q_blocks = (Q-1)/(q_per_thread*q_threads)+1;
dim3 nthreads(q_threads,ps,ps);
// fprintf(stdout,"ps,reflect_bounds,patch_offset: %d,%d,%d\n",ps,reflect_bounds,patch_offset);

int k_blocks = (K-1)/(k_threads)+1;
dim3 nthreads(q_threads,k_threads);
// fprintf(stdout,
// "ps,pt,stride0,reflect_bounds,dilation,patch_offset: %d,%d,%d,%d,%d,%d\n",
// ps,pt,stride0,reflect_bounds,dilation,patch_offset);
// -- kernel blocks --
dim3 nblocks(q_blocks,B,HD);
dim3 nblocks(q_blocks,k_blocks,B*HD);


// // -- kernel threads --
// int MAX_THREADS = 1024;
// int q_threads = MAX_THREADS/(ps*ps); // num of queries threads per block
// q_threads = min(Q,q_threads);
// int q_blocks = (Q-1)/(q_per_thread*q_threads)+1;
// dim3 nthreads(q_threads,ps,ps);
// // fprintf(stdout,"ps,reflect_bounds,patch_offset: %d,%d,%d\n",ps,reflect_bounds,patch_offset);

// // -- kernel blocks --
// dim3 nblocks(q_blocks,B,HD);

// -- launch kernel --
AT_DISPATCH_FLOATING_TYPES(in_vid.type(), "wpsum_int_forward_kernel", ([&] {
Expand Down
9 changes: 9 additions & 0 deletions lib/stnls/agg/scatter_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@

def run(flows,flows_k,ws,wt,stride0,stride1,H,W,full_ws):

"""
There is a minimum and maximum (ws) depending on (stride0)
- [max] we don't want overlap of query points
- [min] we don't want skipped key points
"""

# -- unpack shapes --
B,HD,T,nH,nW,K,_ = flows_k.shape
# B,HD,T,W_t,2,nH,nW = flows.shape
Expand Down
7 changes: 5 additions & 2 deletions lib/stnls/agg/scatter_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@ def run(tensor,flows_k,labels,stride0,stride1,H,W):

# return flows_k

def run_topk(weights,flows_k,K,descending=True):
def run_topk(weights,flows_k,labels,K,descending=True):

# -- reshape --
B,HD,Q,S,_ = flows_k.shape
weights = rearrange(weights,'b hd q s -> (b hd q) s')
flows_k = rearrange(flows_k,'b hd q s tr -> (b hd q) s tr')
labels = rearrange(labels,'b hd q s -> (b hd q) s')
# names = rearrange(names,'b hd s t nh nw tw -> (b hd t nh nw) s tw')
device = weights.device

Expand All @@ -90,6 +91,7 @@ def run_topk(weights,flows_k,K,descending=True):

# -- get topk --
weights = th.gather(weights,-1,order)
labels = th.gather(labels,-1,order)

flows_topk = -th.inf*th.ones(weights.shape+(3,),device=device,dtype=flows_k.dtype)
for i in range(flows_k.shape[-1]):
Expand All @@ -101,7 +103,8 @@ def run_topk(weights,flows_k,K,descending=True):

# -- unpack --
weights = rearrange(weights,'(b hd q) k -> b hd q k',b=B,hd=HD)
labels = rearrange(labels,'(b hd q) k -> b hd q k',b=B,hd=HD)
flows_topk = rearrange(flows_topk,'(b hd q) k tr -> b hd q k tr',b=B,hd=HD)
flows_topk = flows_topk.type(flows_k.dtype)

return weights,flows_topk
return weights,flows_topk,labels
3 changes: 3 additions & 0 deletions lib/stnls/agg/wpsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def forward(ctx, vid, weights, flows, ps, stride0,
# -- normalize --
H,W = vid.shape[-2:]
# print(counts)
# print(counts.sum(-1))
# print(counts.sum(-2))
# exit()
counts = counts.view((1,1,1,1,H,W))
out_vid = out_vid / (counts+eps)
assert th.all(counts>1e-3)
Expand Down

0 comments on commit d3b0553

Please sign in to comment.