Skip to content

Commit

Permalink
fix: iou3d build compatibility errors
Browse files Browse the repository at this point in the history
Fix open-mmlab#681.
1. const float -> constexpr float, referring to opencv/opencv#13491 and opencv/opencv#13960
2. rename check_rect_cross function to resolve name conflicts
3. change long into int32_t (unsigned long long seems no problem yet)
  • Loading branch information
yihuajack committed Jul 24, 2022
1 parent c233477 commit fe62793
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
16 changes: 8 additions & 8 deletions pcdet/ops/iou3d_nms/src/iou3d_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){

int boxes_num = boxes.size(0);
const float * boxes_data = boxes.data<float>();
long * keep_data = keep.data<long>();
int32_t * keep_data = keep.data<int32_t>();

const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

Expand All @@ -107,14 +107,14 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){
// unsigned long long *mask_cpu = new unsigned long long [boxes_num * col_blocks];
std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);

// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long),
cudaMemcpyDeviceToHost));

cudaFree(mask_data);

unsigned long long remv_cpu[col_blocks];
memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
std::vector<unsigned long long> remv_cpu(col_blocks);
memset(&remv_cpu[0], 0, col_blocks * sizeof(unsigned long long));

int num_to_keep = 0;

Expand Down Expand Up @@ -145,7 +145,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){

int boxes_num = boxes.size(0);
const float * boxes_data = boxes.data<float>();
long * keep_data = keep.data<long>();
int32_t * keep_data = keep.data<int32_t>();

const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

Expand All @@ -157,14 +157,14 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){
// unsigned long long *mask_cpu = new unsigned long long [boxes_num * col_blocks];
std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);

// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long),
cudaMemcpyDeviceToHost));

cudaFree(mask_data);

unsigned long long remv_cpu[col_blocks];
memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
std::vector<unsigned long long> remv_cpu(col_blocks);
memset(&remv_cpu[0], 0, col_blocks * sizeof(unsigned long long));

int num_to_keep = 0;

Expand Down
6 changes: 3 additions & 3 deletions pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ All Rights Reserved 2019-2020.

// #define DEBUG
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
const float EPS = 1e-8;
constexpr float EPS = 1e-8;
struct Point {
float x, y;
__device__ Point() {}
Expand Down Expand Up @@ -40,7 +40,7 @@ __device__ inline float cross(const Point &p1, const Point &p2, const Point &p0)
return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
}

__device__ int check_rect_cross(const Point &p1, const Point &p2, const Point &q1, const Point &q2){
__device__ int check_rect_cross_cuda(const Point &p1, const Point &p2, const Point &q1, const Point &q2){
int ret = min(p1.x,p2.x) <= max(q1.x,q2.x) &&
min(q1.x,q2.x) <= max(p1.x,p2.x) &&
min(p1.y,p2.y) <= max(q1.y,q2.y) &&
Expand All @@ -62,7 +62,7 @@ __device__ inline int check_in_box2d(const float *box, const Point &p){

__device__ inline int intersection(const Point &p1, const Point &p0, const Point &q1, const Point &q0, Point &ans){
// fast exclusion
if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
if (check_rect_cross_cuda(p0, p1, q0, q1) == 0) return 0;

// check cross standing
float s1 = cross(q0, p1, p0);
Expand Down

0 comments on commit fe62793

Please sign in to comment.