Skip to content

Commit

Permalink
updated slic; fixed modded agg; added slic in dev
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Dec 3, 2023
1 parent 99a800b commit 0dc1a46
Show file tree
Hide file tree
Showing 14 changed files with 1,150 additions and 47 deletions.
158 changes: 158 additions & 0 deletions lib/csrc/agg/pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#include <torch/extension.h>

#include <vector>

// CUDA forward declarations

/*************************************
Int Forward
*************************************/

void pool_int_forward_cuda(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists, const torch::Tensor inds,
int ps, int stride0, int pt, int dilation,
bool reflect_bounds, int patch_offset);

void pool_int_backward_cuda(
torch::Tensor in_vid_grad,
torch::Tensor dists_grad,
const torch::Tensor out_vid_grad, const torch::Tensor vid,
const torch::Tensor dists, const torch::Tensor inds,
int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset);

/*************************************
Bilin2d Forward
*************************************/

// void pool_bilin2d_forward_cuda(
// torch::Tensor out_vid, torch::Tensor counts,
// const torch::Tensor in_vid,
// const torch::Tensor dists, const torch::Tensor inds,
// int ps, int stride0, int pt, int dilation,
// bool reflect_bounds, int patch_offset);

// void pool_bilin2d_backward_cuda(
// torch::Tensor in_vid_grad,
// torch::Tensor dists_grad,
// torch::Tensor inds_grad,
// const torch::Tensor out_vid_grad, const torch::Tensor vid,
// const torch::Tensor dists, const torch::Tensor inds,
// int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset);


// C++ interface

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

/***********************
Int Indexing
***********************/


void pool_int_forward(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists,
const torch::Tensor inds,
int ps, int stride0, int pt, int dilation,
bool reflect_bounds, int patch_offset){
CHECK_INPUT(out_vid);
CHECK_INPUT(counts);
CHECK_INPUT(in_vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
pool_int_forward_cuda(out_vid,counts,in_vid,dists,inds,
ps,stride0,pt,dilation,reflect_bounds,patch_offset);
}

void pool_int_backward( // "in" and "out" w.r.t. forward pass
torch::Tensor in_vid_grad, torch::Tensor dists_grad,
const torch::Tensor out_vid_grad, const torch::Tensor vid,
const torch::Tensor dists, const torch::Tensor inds,
int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset){
CHECK_INPUT(in_vid_grad);
CHECK_INPUT(dists_grad);
CHECK_INPUT(out_vid_grad);
CHECK_INPUT(vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
pool_int_backward_cuda(in_vid_grad,dists_grad,
out_vid_grad,vid,dists,inds,
ps,stride0,pt,dilation,reflect_bounds,patch_offset);
}

/***********************
Bilinear2d
***********************/

void pool_bilin2d_forward(
torch::Tensor out_vid, torch::Tensor counts,
const torch::Tensor in_vid,
const torch::Tensor dists,
const torch::Tensor inds,
int ps, int stride0, int pt, int dilation,
bool reflect_bounds, int patch_offset){
CHECK_INPUT(out_vid);
CHECK_INPUT(counts);
CHECK_INPUT(in_vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
// pool_bilin2d_forward_cuda(out_vid,counts,in_vid,dists,inds,
// ps,stride0,pt,dilation,reflect_bounds,patch_offset);
}

void pool_bilin2d_backward( // "in" and "out" w.r.t. forward pass
torch::Tensor in_vid_grad,
torch::Tensor dists_grad, torch::Tensor inds_grad,
const torch::Tensor out_vid_grad, const torch::Tensor vid,
const torch::Tensor dists, const torch::Tensor inds,
int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset){
CHECK_INPUT(in_vid_grad);
CHECK_INPUT(dists_grad);
CHECK_INPUT(inds_grad);
CHECK_INPUT(out_vid_grad);
CHECK_INPUT(vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
// pool_bilin2d_backward_cuda(in_vid_grad,dists_grad,
// inds_grad,
// out_vid_grad,vid,dists,inds,
// ps,stride0,pt,dilation,reflect_bounds,patch_offset);
}

/***********************
Python Bindings
***********************/

void init_pool(py::module &m){
m.def("pool_int_forward", &pool_int_forward,
"WeightedPatchSum Forward (CUDA)");
m.def("pool_int_backward", &pool_int_backward,
"WeightedPatchSum Backward (CUDA)");
// m.def("pool_bilin2d_forward", &pool_bilin2d_forward,
// "WeightedPatchSum Forward (CUDA)");
// m.def("pool_bilin2d_backward", &pool_bilin2d_backward,
// "WeightedPatchSum Backward (CUDA)");

}

Loading

0 comments on commit 0dc1a46

Please sign in to comment.