Add File
This commit is contained in:
71
pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.h
Normal file
71
pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.h
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
/*
|
||||||
|
Vector-pool aggregation based local feature aggregation for point cloud.
|
||||||
|
PV-RCNN++: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection
|
||||||
|
https://arxiv.org/abs/2102.00463
|
||||||
|
|
||||||
|
Written by Shaoshuai Shi
|
||||||
|
All Rights Reserved 2020.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef _STACK_VECTOR_POOL_GPU_H
|
||||||
|
#define _STACK_VECTOR_POOL_GPU_H
|
||||||
|
|
||||||
|
#include <torch/serialize/tensor.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime_api.h>
|
||||||
|
|
||||||
|
|
||||||
|
int query_stacked_local_neighbor_idxs_kernel_launcher_stack(
|
||||||
|
const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz, const int *new_xyz_batch_cnt,
|
||||||
|
int *stack_neighbor_idxs, int *start_len, int *cumsum, int avg_length_of_neighbor_idxs,
|
||||||
|
float max_neighbour_distance, int batch_size, int M, int nsample, int neighbor_type);
|
||||||
|
|
||||||
|
int query_stacked_local_neighbor_idxs_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor,
|
||||||
|
at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
|
||||||
|
at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cumsum_tensor,
|
||||||
|
int avg_length_of_neighbor_idxs, float max_neighbour_distance, int nsample, int neighbor_type);
|
||||||
|
|
||||||
|
|
||||||
|
int query_three_nn_by_stacked_local_idxs_kernel_launcher_stack(
|
||||||
|
const float *support_xyz, const float *new_xyz, const float *new_xyz_grid_centers,
|
||||||
|
int *new_xyz_grid_idxs, float *new_xyz_grid_dist2,
|
||||||
|
const int *stack_neighbor_idxs, const int *start_len,
|
||||||
|
int M, int num_total_grids);
|
||||||
|
|
||||||
|
int query_three_nn_by_stacked_local_idxs_wrapper_stack(at::Tensor support_xyz_tensor,
|
||||||
|
at::Tensor new_xyz_tensor, at::Tensor new_xyz_grid_centers_tensor,
|
||||||
|
at::Tensor new_xyz_grid_idxs_tensor, at::Tensor new_xyz_grid_dist2_tensor,
|
||||||
|
at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor,
|
||||||
|
int M, int num_total_grids);
|
||||||
|
|
||||||
|
|
||||||
|
int vector_pool_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor,
|
||||||
|
at::Tensor support_features_tensor, at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
|
||||||
|
at::Tensor new_features_tensor, at::Tensor new_local_xyz,
|
||||||
|
at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor,
|
||||||
|
int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, int use_xyz,
|
||||||
|
int num_max_sum_points, int nsample, int neighbor_type, int pooling_type);
|
||||||
|
|
||||||
|
|
||||||
|
int vector_pool_kernel_launcher_stack(
|
||||||
|
const float *support_xyz, const float *support_features, const int *xyz_batch_cnt,
|
||||||
|
const float *new_xyz, float *new_features, float * new_local_xyz, const int *new_xyz_batch_cnt,
|
||||||
|
int *point_cnt_of_grid, int *grouped_idxs,
|
||||||
|
int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance,
|
||||||
|
int batch_size, int N, int M, int num_c_in, int num_c_out, int num_total_grids, int use_xyz,
|
||||||
|
int num_max_sum_points, int nsample, int neighbor_type, int pooling_type);
|
||||||
|
|
||||||
|
|
||||||
|
int vector_pool_grad_wrapper_stack(at::Tensor grad_new_features_tensor,
|
||||||
|
at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor,
|
||||||
|
at::Tensor grad_support_features_tensor);
|
||||||
|
|
||||||
|
|
||||||
|
void vector_pool_grad_kernel_launcher_stack(
|
||||||
|
const float *grad_new_features, const int *point_cnt_of_grid, const int *grouped_idxs,
|
||||||
|
float *grad_support_features, int N, int M, int num_c_out, int num_c_in, int num_total_grids,
|
||||||
|
int num_max_sum_points);
|
||||||
|
|
||||||
|
#endif
|
||||||
Reference in New Issue
Block a user