114 lines
3.9 KiB
Plaintext
114 lines
3.9 KiB
Plaintext
|
|
#include <math.h>
|
||
|
|
#include <stdio.h>
|
||
|
|
#include <stdlib.h>
|
||
|
|
#include <curand_kernel.h>
|
||
|
|
|
||
|
|
#include "voxel_query_gpu.h"
|
||
|
|
#include "cuda_utils.h"
|
||
|
|
|
||
|
|
|
||
|
|
// Voxel-neighbourhood ball query, one thread per query centre.
//
// Expected launch layout: 1-D grid of 1-D blocks with at least M threads in
// total; threads with a global index >= M exit immediately. No shared memory.
//
// For query `pt_idx` the kernel walks the voxel window
//   [z - z_range, z + z_range] x [y - y_range, y + y_range] x [x - x_range, x + x_range]
// around the query's voxel coordinate (clipped to the grid), reads the point
// index stored for each voxel in `point_indices`, and keeps the first
// `nsample` points whose squared distance to the query is <= radius^2.
//
// :param M: number of query centres
// :param R1, R2, R3: voxel grid extent along z, y, x
// :param nsample: number of output slots per query
// :param radius: ball radius (compared as radius^2 against squared distance)
// :param z_range, y_range, x_range: half-width of the voxel search window
// :param new_xyz: read as 3 floats (x, y, z) per query
// :param xyz: read as 3 floats (x, y, z) per point, indexed by values from point_indices
// :param new_coords: (M1 + M2 ..., 4) centers of the ball query, read as
//                    (batch_idx, z, y, x) per query
// :param point_indices: (B, Z, Y, X); a negative entry marks an empty voxel
// output:
//      idx: (M1 + M2, nsample)
//           On the first hit every slot of the row is filled with that point
//           index, so rows with fewer than nsample hits are padded with the
//           first neighbour. Later hits overwrite slots 1, 2, ... in order.
//           If there is no hit at all only idx[0] is written (set to -1);
//           the remaining slots of that row are left untouched.
__global__ void voxel_query_kernel_stack(int M, int R1, int R2, int R3, int nsample,
            float radius, int z_range, int y_range, int x_range, const float *new_xyz,
            const float *xyz, const int *new_coords, const int *point_indices, int *idx) {
    const int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt_idx >= M) return;

    // 64-bit offsets: pt_idx * nsample (and the voxel index below) can exceed
    // INT_MAX for large inputs.
    new_xyz += (long long)pt_idx * 3;
    new_coords += (long long)pt_idx * 4;
    idx += (long long)pt_idx * nsample;

    // NOTE: an earlier revision initialised a per-thread curandState here for
    // reservoir sampling. The sampling itself was disabled (only the first
    // nsample hits are kept), so the unused RNG setup has been removed.

    const float radius2 = radius * radius;
    const float new_x = new_xyz[0];
    const float new_y = new_xyz[1];
    const float new_z = new_xyz[2];

    const int batch_idx = new_coords[0];
    const int new_coords_z = new_coords[1];
    const int new_coords_y = new_coords[2];
    const int new_coords_x = new_coords[3];

    int cnt = 0;  // number of neighbours stored so far
    for (int dz = -z_range; dz <= z_range; ++dz) {
        const int z_coord = new_coords_z + dz;
        if (z_coord < 0 || z_coord >= R1) continue;
        // Flattened offset of plane (batch_idx, z_coord) in units of rows.
        const long long plane = ((long long)batch_idx * R1 + z_coord) * R2;

        for (int dy = -y_range; dy <= y_range; ++dy) {
            const int y_coord = new_coords_y + dy;
            if (y_coord < 0 || y_coord >= R2) continue;
            // Flattened offset of row (batch_idx, z_coord, y_coord, 0).
            const long long row = (plane + y_coord) * R3;

            for (int dx = -x_range; dx <= x_range; ++dx) {
                const int x_coord = new_coords_x + dx;
                if (x_coord < 0 || x_coord >= R3) continue;

                const int neighbor_idx = point_indices[row + x_coord];
                if (neighbor_idx < 0) continue;  // empty voxel

                const float *pt = xyz + (long long)neighbor_idx * 3;
                const float ddx = pt[0] - new_x;
                const float ddy = pt[1] - new_y;
                const float ddz = pt[2] - new_z;
                const float dist2 = ddx * ddx + ddy * ddy + ddz * ddz;
                if (dist2 > radius2) continue;

                if (cnt < nsample) {
                    if (cnt == 0) {
                        // First hit: pad the whole row with it.
                        for (int l = 0; l < nsample; ++l) {
                            idx[l] = neighbor_idx;
                        }
                    } else {
                        idx[cnt] = neighbor_idx;
                    }
                    ++cnt;
                    // Row is full and nothing below can modify it any more,
                    // so stop scanning the remaining voxels.
                    if (cnt == nsample) return;
                }
            }
        }
    }

    // No neighbour found: flag the row for the caller.
    if (cnt == 0) idx[0] = -1;
}
|
||
|
|
|
||
|
|
|
||
|
|
// Host-side launcher for voxel_query_kernel_stack.
//
// Launches one thread per query centre on the default stream using a 1-D grid
// of DIVUP(M, THREADS_PER_BLOCK) blocks with THREADS_PER_BLOCK threads each.
// The launch is not followed by a device synchronisation; only launch errors
// reported by cudaGetLastError() are checked here.
//
// :param M: number of query centres; M <= 0 is a no-op
// :param R1, R2, R3, nsample, radius, z_range, y_range, x_range: forwarded to the kernel
// :param new_xyz, xyz: forwarded to the kernel (dereferenced in device code)
// :param new_coords: (M1 + M2 ..., 4) centers of the voxel query
// :param point_indices: (B, Z, Y, X)
// output:
//      idx: (M1 + M2, nsample)
//
// On a launch error the message is printed to stderr and the process exits
// with status -1.
void voxel_query_kernel_launcher_stack(int M, int R1, int R2, int R3, int nsample,
    float radius, int z_range, int y_range, int x_range, const float *new_xyz,
    const float *xyz, const int *new_coords, const int *point_indices, int *idx) {
    // Nothing to query. A grid with zero blocks is an invalid launch
    // configuration, which would otherwise be reported as a kernel failure
    // below and terminate the process.
    if (M <= 0) return;

    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK));  // 1-D grid over the query centres
    dim3 threads(THREADS_PER_BLOCK);

    voxel_query_kernel_stack<<<blocks, threads>>>(M, R1, R2, R3, nsample, radius,
        z_range, y_range, x_range, new_xyz, xyz, new_coords, point_indices, idx);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    cudaError_t err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
|