Commit 2fdc2e1: renaming

gauenk committed Dec 6, 2023
1 parent ffc05a8 commit 2fdc2e1
Showing 25 changed files with 1,292 additions and 129 deletions.
34 changes: 17 additions & 17 deletions lib/csrc/agg/wpsum.cpp → lib/csrc/agg/gather_add.cpp
@@ -10,14 +10,14 @@
*************************************/

-void wpsum_int_forward_cuda(
+void gather_add_int_forward_cuda(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists, const torch::Tensor inds,
int ps, int stride0, int pt, int dilation,
bool reflect_bounds, int patch_offset);

-void wpsum_int_backward_cuda(
+void gather_add_int_backward_cuda(
torch::Tensor in_vid_grad,
torch::Tensor dists_grad,
const torch::Tensor out_vid_grad, const torch::Tensor vid,
@@ -30,14 +30,14 @@ void wpsum_int_backward_cuda(
*************************************/

-void wpsum_bilin2d_forward_cuda(
+void gather_add_bilin2d_forward_cuda(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists, const torch::Tensor inds,
int ps, int stride0, int pt, int dilation,
bool reflect_bounds, int patch_offset);

-void wpsum_bilin2d_backward_cuda(
+void gather_add_bilin2d_backward_cuda(
torch::Tensor in_vid_grad,
torch::Tensor dists_grad,
torch::Tensor inds_grad,
@@ -61,7 +61,7 @@ void wpsum_bilin2d_backward_cuda(
***********************/


-void wpsum_int_forward(
+void gather_add_int_forward(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists,
@@ -73,11 +73,11 @@ void wpsum_int_forward(
CHECK_INPUT(in_vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
-wpsum_int_forward_cuda(out_vid,counts,in_vid,dists,inds,
+gather_add_int_forward_cuda(out_vid,counts,in_vid,dists,inds,
ps,stride0,pt,dilation,reflect_bounds,patch_offset);
}

-void wpsum_int_backward( // "in" and "out" w.r.t. forward pass
+void gather_add_int_backward( // "in" and "out" w.r.t. forward pass
torch::Tensor in_vid_grad, torch::Tensor dists_grad,
const torch::Tensor out_vid_grad, const torch::Tensor vid,
const torch::Tensor dists, const torch::Tensor inds,
@@ -88,7 +88,7 @@ void wpsum_int_backward( // "in" and "out" w.r.t. forward pass
CHECK_INPUT(vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
-wpsum_int_backward_cuda(in_vid_grad,dists_grad,
+gather_add_int_backward_cuda(in_vid_grad,dists_grad,
out_vid_grad,vid,dists,inds,
ps,stride0,pt,dilation,reflect_bounds,patch_offset);
}
@@ -101,7 +101,7 @@ void wpsum_int_backward( // "in" and "out" w.r.t. forward pass
***********************/

-void wpsum_bilin2d_forward(
+void gather_add_bilin2d_forward(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists,
@@ -113,11 +113,11 @@ void wpsum_bilin2d_forward(
CHECK_INPUT(in_vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
-wpsum_bilin2d_forward_cuda(out_vid,counts,in_vid,dists,inds,
+gather_add_bilin2d_forward_cuda(out_vid,counts,in_vid,dists,inds,
ps,stride0,pt,dilation,reflect_bounds,patch_offset);
}

-void wpsum_bilin2d_backward( // "in" and "out" w.r.t. forward pass
+void gather_add_bilin2d_backward( // "in" and "out" w.r.t. forward pass
torch::Tensor in_vid_grad,
torch::Tensor dists_grad, torch::Tensor inds_grad,
const torch::Tensor out_vid_grad, const torch::Tensor vid,
@@ -130,7 +130,7 @@ void wpsum_bilin2d_backward( // "in" and "out" w.r.t. forward pass
CHECK_INPUT(vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
-wpsum_bilin2d_backward_cuda(in_vid_grad,dists_grad,
+gather_add_bilin2d_backward_cuda(in_vid_grad,dists_grad,
inds_grad,
out_vid_grad,vid,dists,inds,
ps,stride0,pt,dilation,reflect_bounds,patch_offset);
@@ -144,14 +144,14 @@ void wpsum_bilin2d_backward( // "in" and "out" w.r.t. forward pass
***********************/

-void init_wpsum(py::module &m){
-m.def("wpsum_int_forward", &wpsum_int_forward,
+void init_gather_add(py::module &m){
+m.def("gather_add_int_forward", &gather_add_int_forward,
"WeightedPatchSum Forward (CUDA)");
-m.def("wpsum_int_backward", &wpsum_int_backward,
+m.def("gather_add_int_backward", &gather_add_int_backward,
"WeightedPatchSum Backward (CUDA)");
-m.def("wpsum_bilin2d_forward", &wpsum_bilin2d_forward,
+m.def("gather_add_bilin2d_forward", &gather_add_bilin2d_forward,
"WeightedPatchSum Forward (CUDA)");
-m.def("wpsum_bilin2d_backward", &wpsum_bilin2d_backward,
+m.def("gather_add_bilin2d_backward", &gather_add_bilin2d_backward,
"WeightedPatchSum Backward (CUDA)");

}
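As an aside, here is a minimal sketch of how the renamed init_gather_add could be registered in the extension's top-level pybind11 module. The aggregator file and the companion init_scatter_add call are assumptions for illustration, not part of this commit:

#include <torch/extension.h>

void init_gather_add(py::module &m);
void init_scatter_add(py::module &m);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Register the renamed gather_add bindings alongside the new scatter_add ones.
  init_gather_add(m);
  init_scatter_add(m);
}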
@@ -13,7 +13,7 @@
****************************/

template <typename scalar_t>
-__global__ void wpsum_bilin2d_forward_kernel(
+__global__ void gather_add_bilin2d_forward_kernel(
torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> out_vid,
torch::PackedTensorAccessor32<int,2,torch::RestrictPtrTraits> counts,
const torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> in_vid,
@@ -124,7 +124,7 @@ __global__ void wpsum_bilin2d_forward_kernel(
} // query-loop
}

-void wpsum_bilin2d_forward_cuda(
+void gather_add_bilin2d_forward_cuda(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists, const torch::Tensor inds,
@@ -166,8 +166,9 @@ void wpsum_bilin2d_forward_cuda(
dim3 nblocks(q_blocks,B,HD);

// -- launch kernel --
-AT_DISPATCH_FLOATING_TYPES(in_vid.type(), "wpsum_bilin2d_forward_kernel", ([&] {
-wpsum_bilin2d_forward_kernel<scalar_t><<<nblocks, nthreads>>>(
+AT_DISPATCH_FLOATING_TYPES(in_vid.type(),
+"gather_add_bilin2d_forward_kernel", ([&] {
+gather_add_bilin2d_forward_kernel<scalar_t><<<nblocks, nthreads>>>(
out_vid.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
counts.packed_accessor32<int,2,torch::RestrictPtrTraits>(),
in_vid.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
@@ -187,7 +188,7 @@ void wpsum_bilin2d_forward_cuda(
*************************************/

template <typename scalar_t>
-__global__ void wpsum_bilin2d_backward_kernel(
+__global__ void gather_add_bilin2d_backward_kernel(
torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> in_vid_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> dists_grad,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> inds_grad,
@@ -325,7 +326,7 @@ __global__ void wpsum_bilin2d_backward_kernel(
} // qi
}

-void wpsum_bilin2d_backward_cuda(
+void gather_add_bilin2d_backward_cuda(
torch::Tensor in_vid_grad,
torch::Tensor dists_grad, torch::Tensor inds_grad,
const torch::Tensor out_vid_grad, const torch::Tensor vid,
@@ -355,9 +356,9 @@ void wpsum_bilin2d_backward_cuda(
// fprintf(stdout,"q_threads: %d\n",q_threads);

// launch kernel
-AT_DISPATCH_FLOATING_TYPES(in_vid_grad.type(), "wpsum_bilin2d_backward_vid_kernel",
-([&] {
-wpsum_bilin2d_backward_kernel<scalar_t><<<nblocks, nthreads>>>(
+AT_DISPATCH_FLOATING_TYPES(in_vid_grad.type(),
+"gather_add_bilin2d_backward_vid_kernel", ([&] {
+gather_add_bilin2d_backward_kernel<scalar_t><<<nblocks, nthreads>>>(
in_vid_grad.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
dists_grad.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
inds_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
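Both kernels above reflect out-of-range coordinates via bounds(nl[0],T) and friends. A plausible sketch of that device helper, assuming mirror reflection at the borders; the actual definition lives elsewhere in lib/csrc and is not part of this diff:

// Reflect an index back into [0, N-1], mirroring at the edges.
// Assumption: the input is within one reflection of the valid range.
__device__ __forceinline__ int bounds(int i, int N) {
  if (i < 0) { return -i; }
  if (i >= N) { return 2*(N-1) - i; }
  return i;
}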
File renamed without changes.
@@ -13,7 +13,7 @@
****************************/

template <typename scalar_t>
-__global__ void wpsum_int_forward_kernel(
+__global__ void gather_add_int_forward_kernel(
torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> out_vid,
torch::PackedTensorAccessor32<int,2,torch::RestrictPtrTraits> counts,
const torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> in_vid,
@@ -63,6 +63,23 @@ __global__ void wpsum_int_forward_kernel(
if (qi >= Q){ continue; }
get_pixel_loc<int>(ref,qi,stride0,nW,nHW,H,W);

+// -- non-local index --
+#pragma unroll
+for (int _idx=0; _idx < 3; _idx++){
+nl[_idx] = ref[_idx] + inds[ibatch][ihead][qi][ki][_idx];
+}
+
+// -- check "inf" (but it won't be inf sometimes) --
+valid = (abs(nl[1]) < 1e7) and (abs(nl[2]) < 1e7);
+if (not(valid)){ continue; }
+
+// -- always reflect anchor point --
+nl[0] = bounds(nl[0],T);
+nl[1] = bounds(nl[1],H);
+nl[2] = bounds(nl[2],W);
+
+// -- non-local weight --
+weight = dists[ibatch][ihead][qi][ki];

// -- iterate over patches --
for(int pi=0; pi < ps; pi++){
@@ -72,47 +89,25 @@ __global__ void wpsum_int_forward_kernel(
ref_p[0] = ref[0];
ref_p[1] = ref[1]+dilation*(pi + patch_offset);
ref_p[2] = ref[2]+dilation*(pj + patch_offset);

// -- valid ref pixel only --
check_bounds(valid, ref_p, T, H, W);
if (not valid){ continue; }

-// -- normalize --
+// -- increment legal refs --
if ((ref[0]==0) and (ibatch==0) and (ihead==0) and (ki==0)){
atomicAdd(&counts[ref_p[1]][ref_p[2]],1);
}


-// -- non-local index --
-#pragma unroll
-for (int _idx=0; _idx < 3; _idx++){
-nl[_idx] = ref[_idx] + inds[ibatch][ihead][qi][ki][_idx];
-}
-
-// -- check "inf" (but it won't be inf sometimes) --
-valid = (abs(nl[1]) < 1e7) and (abs(nl[2]) < 1e7);
-if (not(valid)){ continue; }
-
-// -- always reflect anchor point --
-nl[0] = bounds(nl[0],T);
-nl[1] = bounds(nl[1],H);
-nl[2] = bounds(nl[2],W);

// -- non-local pixel index --
nl[1] = nl[1]+dilation*(pi + patch_offset);
nl[1] = reflect_bounds ? bounds(nl[1],H) : nl[1];
nl[2] = nl[2]+dilation*(pj + patch_offset);
nl[2] = reflect_bounds ? bounds(nl[2],W) : nl[2];

// -- valid non-local patches only --
-valid = (nl[0] >= 0) && (nl[0] < T);
-valid = valid && (nl[1] >= 0) && (nl[1] < H);
-valid = valid && (nl[2] >= 0) && (nl[2] < W);
+check_bounds(valid, nl, T, H, W);
+// valid = (nl[0] >= 0) && (nl[0] < T);
+// valid = valid && (nl[1] >= 0) && (nl[1] < H);
+// valid = valid && (nl[2] >= 0) && (nl[2] < W);
if (not valid){ continue; }

-// -- non-local weight --
-weight = dists[ibatch][ihead][qi][ki];

// -- iterate over loop --
for(int pk = 0; pk < pt; pk++){

@@ -135,7 +130,7 @@ __global__ void wpsum_int_forward_kernel(
} // query-loop
}

-void wpsum_int_forward_cuda(
+void gather_add_int_forward_cuda(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists, const torch::Tensor inds,
@@ -176,8 +171,9 @@ void wpsum_int_forward_cuda(
// dim3 nblocks(q_blocks,B,HD);

// -- launch kernel --
-AT_DISPATCH_FLOATING_TYPES(in_vid.type(), "wpsum_int_forward_kernel", ([&] {
-wpsum_int_forward_kernel<scalar_t><<<nblocks, nthreads>>>(
+AT_DISPATCH_FLOATING_TYPES(in_vid.type(),
+"gather_add_int_forward_kernel", ([&] {
+gather_add_int_forward_kernel<scalar_t><<<nblocks, nthreads>>>(
out_vid.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
counts.packed_accessor32<int,2,torch::RestrictPtrTraits>(),
in_vid.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
@@ -197,7 +193,7 @@ void wpsum_int_forward_cuda(
*************************************/

template <typename scalar_t>
-__global__ void wpsum_int_backward_kernel(
+__global__ void gather_add_int_backward_kernel(
torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> in_vid_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> dists_grad,
const torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> out_vid_grad,
@@ -307,7 +303,7 @@ __global__ void wpsum_int_backward_kernel(
} // qi
}

-void wpsum_int_backward_cuda(
+void gather_add_int_backward_cuda(
torch::Tensor in_vid_grad, torch::Tensor dists_grad,
const torch::Tensor out_vid_grad, const torch::Tensor vid,
const torch::Tensor dists, const torch::Tensor inds,
@@ -336,8 +332,9 @@ void wpsum_int_backward_cuda(
// fprintf(stdout,"q_threads: %d\n",q_threads);

// launch kernel
-AT_DISPATCH_FLOATING_TYPES(in_vid_grad.type(), "wpsum_int_backward_vid_kernel", ([&] {
-wpsum_int_backward_kernel<scalar_t><<<nblocks, nthreads>>>(
+AT_DISPATCH_FLOATING_TYPES(in_vid_grad.type(),
+"gather_add_int_backward_vid_kernel", ([&] {
+gather_add_int_backward_kernel<scalar_t><<<nblocks, nthreads>>>(
in_vid_grad.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
dists_grad.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
out_vid_grad.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
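Indexing detail aside, the forward kernels of both variants perform the same weighted gather-add. A condensed sketch of the inner update follows; the flattened 1-D addressing is illustrative only, since the committed kernels index 6-D PackedTensorAccessor32 views:

// Each (query, neighbor) pair scales the non-local pixel by its match
// weight and accumulates it into the reference location. Patches from
// different queries overlap in out_vid, so the update must be atomic.
__device__ void gather_add_pixel(float* out_vid, const float* in_vid,
                                 float weight, int ref_pix, int nl_pix) {
  atomicAdd(&out_vid[ref_pix], weight * in_vid[nl_pix]);
}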
74 changes: 74 additions & 0 deletions lib/csrc/agg/scatter_add.cpp
@@ -0,0 +1,74 @@
// imports
#include <torch/extension.h>
// #include <torch/types.h>
#include <vector>
// #include "pybind.hpp"


// CUDA forward declarations

void scatter_add_int_forward_cuda(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists, const torch::Tensor inds,
int ps, int strideIn, int strideOut, int pt,
int dilation, bool reflect_bounds, int patch_offset);

// C++ interface

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

/*********************************
Using Raster Order
*********************************/


void scatter_add_int_forward(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists,
const torch::Tensor inds,
int ps, int strideIn, int strideOut, int pt,
int dilation, bool reflect_bounds, int patch_offset){
CHECK_INPUT(out_vid);
CHECK_INPUT(counts);
CHECK_INPUT(in_vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
scatter_add_int_forward_cuda(out_vid,counts,in_vid,dists,inds,
ps,strideIn,strideOut,pt,dilation,
reflect_bounds,patch_offset);
}


// void scatter_add_int_backward(
// torch::Tensor out_grad, const torch::Tensor in_grad,
// const torch::Tensor vid, const torch::Tensor weights,
// const torch::Tensor inds, const torch::Tensor labels,
// torch::Tensor stack, torch::Tensor mask, torch::Tensor counts,
// int ps, int pt, int dilation, int stride0, bool reflect_bounds, int patch_offset){
// CHECK_INPUT(vid);
// CHECK_INPUT(weights);
// CHECK_INPUT(inds);
// CHECK_INPUT(labels);
// CHECK_INPUT(stack);
// CHECK_INPUT(mask);
// CHECK_INPUT(counts);
// scatter_add_int_backward_cuda(vid,weights,inds,labels,stack,mask,counts,
// ps,pt,dilation,stride0,
// reflect_bounds,patch_offset);
// }

// -- python bindings --
void init_scatter_add(py::module &m){
m.def("scatter_add_int_forward",
&scatter_add_int_forward,
"Scatter Forward with Int Indexing");
// NOTE: binding the backward pass here would reference the (still
// commented-out) scatter_add_int_backward above and fail to compile,
// so it stays disabled until that function lands.
// m.def("scatter_add_int_backward",
//       &scatter_add_int_backward,
//       "Scatter Backward with Int Indexing");
}
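For reference, a hypothetical host-side call into the new binding; the tensor shapes are assumptions inferred from the gather_add signatures (video laid out as (B, HD, T, C, H, W)), not documented in this commit:

// Hypothetical sizes; real values come from the preceding search stage.
const int B = 1, HD = 1, T = 3, C = 3, H = 64, W = 64;
auto out_vid = torch::zeros({B, HD, T, C, H, W}, in_vid.options());
auto counts = torch::zeros({H, W}, in_vid.options().dtype(torch::kInt32));
// counts records how many patches touched each pixel, so callers can
// normalize out_vid afterward.
scatter_add_int_forward(out_vid, counts, in_vid, dists, inds,
                        ps, strideIn, strideOut, pt,
                        dilation, reflect_bounds, patch_offset);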