diff --git a/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp b/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp index d41da8ad0..62d3c9db3 100644 --- a/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp +++ b/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp @@ -95,7 +95,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){ int boxes_num = boxes.size(0); const float * boxes_data = boxes.data(); - long * keep_data = keep.data(); + int32_t * keep_data = keep.data(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); @@ -107,14 +107,14 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){ // unsigned long long *mask_cpu = new unsigned long long [boxes_num * col_blocks]; std::vector mask_cpu(boxes_num * col_blocks); -// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); + // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long), cudaMemcpyDeviceToHost)); cudaFree(mask_data); - unsigned long long remv_cpu[col_blocks]; - memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long)); + std::vector remv_cpu(col_blocks); + memset(&remv_cpu[0], 0, col_blocks * sizeof(unsigned long long)); int num_to_keep = 0; @@ -145,7 +145,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){ int boxes_num = boxes.size(0); const float * boxes_data = boxes.data(); - long * keep_data = keep.data(); + int32_t * keep_data = keep.data(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); @@ -157,14 +157,14 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){ // unsigned long long *mask_cpu = new unsigned long long [boxes_num * col_blocks]; std::vector mask_cpu(boxes_num * col_blocks); -// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); + // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long), cudaMemcpyDeviceToHost)); cudaFree(mask_data); - unsigned long long remv_cpu[col_blocks]; - memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long)); + std::vector remv_cpu(col_blocks); + memset(&remv_cpu[0], 0, col_blocks * sizeof(unsigned long long)); int num_to_keep = 0; diff --git a/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu b/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu index e5e305cdb..c17d052f2 100644 --- a/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu +++ b/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu @@ -11,7 +11,7 @@ All Rights Reserved 2019-2020. // #define DEBUG const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; -const float EPS = 1e-8; +constexpr float EPS = 1e-8; struct Point { float x, y; __device__ Point() {} @@ -40,7 +40,7 @@ __device__ inline float cross(const Point &p1, const Point &p2, const Point &p0) return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y); } -__device__ int check_rect_cross(const Point &p1, const Point &p2, const Point &q1, const Point &q2){ +__device__ int check_rect_cross_cuda(const Point &p1, const Point &p2, const Point &q1, const Point &q2){ int ret = min(p1.x,p2.x) <= max(q1.x,q2.x) && min(q1.x,q2.x) <= max(p1.x,p2.x) && min(p1.y,p2.y) <= max(q1.y,q2.y) && @@ -62,7 +62,7 @@ __device__ inline int check_in_box2d(const float *box, const Point &p){ __device__ inline int intersection(const Point &p1, const Point &p0, const Point &q1, const Point &q0, Point &ans){ // fast exclusion - if (check_rect_cross(p0, p1, q0, q1) == 0) return 0; + if (check_rect_cross_cuda(p0, p1, q0, q1) == 0) return 0; // check cross standing float s1 = cross(q0, p1, p0);