diff --git a/pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu b/pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu new file mode 100644 index 0000000..38c0063 --- /dev/null +++ b/pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu @@ -0,0 +1,73 @@ +/* +batch version of ball query, modified from the original implementation of official PointNet++ codes. +Written by Shaoshuai Shi +All Rights Reserved 2018. +*/ + +#include +#include +#include + +#include "ball_query_gpu.h" +#include "cuda_utils.h" + + +__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, + const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= m) return; + + new_xyz += bs_idx * m * 3 + pt_idx * 3; + xyz += bs_idx * n * 3; + idx += bs_idx * m * nsample + pt_idx * nsample; + + float radius2 = radius * radius; + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + int cnt = 0; + for (int k = 0; k < n; ++k) { + float x = xyz[k * 3 + 0]; + float y = xyz[k * 3 + 1]; + float z = xyz[k * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); + if (d2 < radius2){ + if (cnt == 0){ + for (int l = 0; l < nsample; ++l) { + idx[l] = k; + } + } + idx[cnt] = k; + ++cnt; + if (cnt >= nsample) break; + } + } +} + + +void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ + const float *new_xyz, const float *xyz, int *idx) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + + cudaError_t err; + + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +}