"""Data augmentation pipeline for point-cloud detection datasets.

Repository path of this module: pcdet/datasets/augmentor/data_augmentor.py
"""
from functools import partial

import numpy as np
from PIL import Image

from ...utils import common_utils
from . import augmentor_utils, database_sampler


class DataAugmentor(object):
    """Build and run a configurable queue of data augmentors.

    Each entry of the augmentor config names a method of this class
    (``cur_cfg.NAME``). Calling that method with only ``config`` returns a
    callable that accepts ``data_dict=...``; ``forward`` runs those
    callables in order on a sample dictionary.
    """

    def __init__(self, root_path, augmentor_configs, class_names, logger=None):
        """Store the dataset context and build the augmentor queue.

        Args:
            root_path: forwarded to ``DataBaseSampler`` by ``gt_sampling``.
            augmentor_configs: either a list of augmentor configs, or an
                object exposing ``AUG_CONFIG_LIST`` (and optionally
                ``DISABLE_AUG_LIST``).
            class_names: forwarded to ``DataBaseSampler`` by ``gt_sampling``.
            logger: optional logger; it is dropped when the object is pickled.
        """
        self.root_path = root_path
        self.class_names = class_names
        self.logger = logger
        self.data_augmentor_queue = self._build_augmentor_queue(augmentor_configs)

    def _build_augmentor_queue(self, augmentor_configs):
        """Instantiate one augmentor per enabled config entry, in order."""
        if isinstance(augmentor_configs, list):
            # A plain list carries no disable list: every entry is used.
            aug_config_list, disabled_names = augmentor_configs, ()
        else:
            aug_config_list = augmentor_configs.AUG_CONFIG_LIST
            # A missing DISABLE_AUG_LIST is treated as "nothing disabled".
            disabled_names = getattr(augmentor_configs, 'DISABLE_AUG_LIST', ())

        return [
            getattr(self, cur_cfg.NAME)(config=cur_cfg)
            for cur_cfg in aug_config_list
            if cur_cfg.NAME not in disabled_names
        ]

    def disable_augmentation(self, augmentor_configs):
        """Rebuild the augmentor queue from ``augmentor_configs``.

        The queue is replaced entirely; entries whose ``NAME`` appears in
        ``augmentor_configs.DISABLE_AUG_LIST`` are skipped.
        """
        self.data_augmentor_queue = self._build_augmentor_queue(augmentor_configs)

    def gt_sampling(self, config=None):
        """Create the ``DataBaseSampler`` used as the ground-truth sampling augmentor."""
        db_sampler = database_sampler.DataBaseSampler(
            root_path=self.root_path,
            sampler_cfg=config,
            class_names=self.class_names,
            logger=self.logger
        )
        return db_sampler

    def __getstate__(self):
        """Return the instance state without the logger."""
        state = dict(self.__dict__)
        # pop() instead of del: an instance restored by __setstate__ from an
        # older pickle may not carry a 'logger' entry at all.
        state.pop('logger', None)
        return state

    def __setstate__(self, d):
        """Restore the instance state; the logger comes back as ``None``."""
        self.__dict__.update(d)
        # The logger is never pickled, so make sure the attribute still exists.
        self.__dict__.setdefault('logger', None)

    def random_world_flip(self, data_dict=None, config=None):
        """Randomly flip the scene along each axis in ``config['ALONG_AXIS_LIST']``.

        Without ``data_dict`` this returns a partial bound to ``config``.
        Otherwise it updates ``gt_boxes`` and ``points``, records the flip
        decision under ``flip_x`` / ``flip_y``, and applies the same decision
        to ``roi_boxes`` (shape ``(num_frame, num_rois, dim)``) when present.
        """
        if data_dict is None:
            return partial(self.random_world_flip, config=config)

        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for cur_axis in config['ALONG_AXIS_LIST']:
            assert cur_axis in ['x', 'y']
            flip_func = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)
            gt_boxes, points, enable = flip_func(gt_boxes, points, return_flip=True)
            data_dict['flip_%s' % cur_axis] = enable

            if 'roi_boxes' in data_dict:
                num_frame, num_rois, dim = data_dict['roi_boxes'].shape
                # Dummy points are passed; only the boxes are of interest and
                # the flip decision is forced to match the one drawn above.
                roi_boxes, _, _ = flip_func(
                    data_dict['roi_boxes'].reshape(-1, dim), np.zeros([1, 3]),
                    return_flip=True, enable=enable
                )
                data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois, dim)

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict

    def random_world_rotation(self, data_dict=None, config=None):
        """Apply a random global rotation drawn from ``config['WORLD_ROT_ANGLE']``.

        A scalar angle ``a`` is expanded to the range ``[-a, a]``. The sampled
        rotation is stored under ``noise_rot`` and reused for ``roi_boxes``
        when that key is present.
        """
        if data_dict is None:
            return partial(self.random_world_rotation, config=config)

        rot_range = config['WORLD_ROT_ANGLE']
        if not isinstance(rot_range, list):
            rot_range = [-rot_range, rot_range]

        gt_boxes, points, noise_rot = augmentor_utils.global_rotation(
            data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True
        )
        if 'roi_boxes' in data_dict:
            num_frame, num_rois, dim = data_dict['roi_boxes'].shape
            # Dummy points; the already-sampled rotation is passed through.
            roi_boxes, _, _ = augmentor_utils.global_rotation(
                data_dict['roi_boxes'].reshape(-1, dim), np.zeros([1, 3]),
                rot_range=rot_range, return_rot=True, noise_rotation=noise_rot
            )
            data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois, dim)

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        data_dict['noise_rot'] = noise_rot
        return data_dict

    def random_world_scaling(self, data_dict=None, config=None):
        """Apply a random global scale drawn from ``config['WORLD_SCALE_RANGE']``.

        The sampled scale is stored under ``noise_scale``. When ``roi_boxes``
        is present it is scaled together with the ground-truth boxes.
        """
        if data_dict is None:
            return partial(self.random_world_scaling, config=config)

        if 'roi_boxes' in data_dict:
            gt_boxes, roi_boxes, points, noise_scale = augmentor_utils.global_scaling_with_roi_boxes(
                data_dict['gt_boxes'], data_dict['roi_boxes'], data_dict['points'],
                config['WORLD_SCALE_RANGE'], return_scale=True
            )
            data_dict['roi_boxes'] = roi_boxes
        else:
            gt_boxes, points, noise_scale = augmentor_utils.global_scaling(
                data_dict['gt_boxes'], data_dict['points'],
                config['WORLD_SCALE_RANGE'], return_scale=True
            )

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        data_dict['noise_scale'] = noise_scale
        return data_dict

    def random_image_flip(self, data_dict=None, config=None):
        """Randomly flip images, depth maps and boxes along the configured axes.

        Only ``'horizontal'`` is accepted in ``config['ALONG_AXIS_LIST']``.
        Reads ``images``, ``depth_maps``, ``gt_boxes`` and ``calib`` from
        ``data_dict`` and writes the first three back.
        """
        if data_dict is None:
            return partial(self.random_image_flip, config=config)

        images = data_dict["images"]
        depth_maps = data_dict["depth_maps"]
        gt_boxes = data_dict['gt_boxes']
        calib = data_dict["calib"]
        for cur_axis in config['ALONG_AXIS_LIST']:
            assert cur_axis in ['horizontal']
            images, depth_maps, gt_boxes = getattr(augmentor_utils, 'random_image_flip_%s' % cur_axis)(
                images, depth_maps, gt_boxes, calib,
            )

        data_dict['images'] = images
        data_dict['depth_maps'] = depth_maps
        data_dict['gt_boxes'] = gt_boxes
        return data_dict

    def random_world_translation(self, data_dict=None, config=None):
        """Shift the whole scene by Gaussian noise, one std per axis.

        ``config['NOISE_TRANSLATE_STD']`` must have three entries. The sampled
        offset (shape ``(1, 3)``) is added in place to the first three columns
        of ``points``, ``gt_boxes`` and, when present, ``roi_boxes``; it is
        also stored under ``noise_translate``.
        """
        if data_dict is None:
            return partial(self.random_world_translation, config=config)

        noise_translate_std = config['NOISE_TRANSLATE_STD']
        assert len(noise_translate_std) == 3
        noise_translate = np.array([
            np.random.normal(0, noise_translate_std[0], 1),
            np.random.normal(0, noise_translate_std[1], 1),
            np.random.normal(0, noise_translate_std[2], 1),
        ], dtype=np.float32).T

        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        points[:, :3] += noise_translate
        gt_boxes[:, :3] += noise_translate

        if 'roi_boxes' in data_dict:
            # roi_boxes is (num_frame, num_rois, dim) elsewhere in this class,
            # so index the last axis; this is identical to [:, :3] for 2-D input.
            data_dict['roi_boxes'][..., :3] += noise_translate

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        data_dict['noise_translate'] = noise_translate
        return data_dict

    def random_local_translation(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.

        Applies ``augmentor_utils.random_local_translation_along_<axis>`` for
        each axis in ``config['ALONG_AXIS_LIST']`` (``'x'``, ``'y'`` or ``'z'``)
        using ``config['LOCAL_TRANSLATION_RANGE']``.
        """
        if data_dict is None:
            return partial(self.random_local_translation, config=config)

        offset_range = config['LOCAL_TRANSLATION_RANGE']
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for cur_axis in config['ALONG_AXIS_LIST']:
            assert cur_axis in ['x', 'y', 'z']
            gt_boxes, points = getattr(augmentor_utils, 'random_local_translation_along_%s' % cur_axis)(
                gt_boxes, points, offset_range,
            )

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict

    def random_local_rotation(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.

        Applies ``augmentor_utils.local_rotation`` with the range from
        ``config['LOCAL_ROT_ANGLE']``; a scalar ``a`` becomes ``[-a, a]``.
        """
        if data_dict is None:
            return partial(self.random_local_rotation, config=config)

        rot_range = config['LOCAL_ROT_ANGLE']
        if not isinstance(rot_range, list):
            rot_range = [-rot_range, rot_range]
        gt_boxes, points = augmentor_utils.local_rotation(
            data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range
        )

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict

    def random_local_scaling(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.

        Applies ``augmentor_utils.local_scaling`` with
        ``config['LOCAL_SCALE_RANGE']``.
        """
        if data_dict is None:
            return partial(self.random_local_scaling, config=config)

        gt_boxes, points = augmentor_utils.local_scaling(
            data_dict['gt_boxes'], data_dict['points'], config['LOCAL_SCALE_RANGE']
        )

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict

    @staticmethod
    def _frustum_dropout(data_dict, config, func_prefix):
        """Run ``augmentor_utils.<func_prefix>_<direction>`` for each configured direction."""
        intensity_range = config['INTENSITY_RANGE']
        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
        for direction in config['DIRECTION']:
            assert direction in ['top', 'bottom', 'left', 'right']
            gt_boxes, points = getattr(augmentor_utils, '%s_%s' % (func_prefix, direction))(
                gt_boxes, points, intensity_range,
            )

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict

    def random_world_frustum_dropout(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.

        Applies ``augmentor_utils.global_frustum_dropout_<direction>`` for each
        entry of ``config['DIRECTION']`` with ``config['INTENSITY_RANGE']``.
        """
        if data_dict is None:
            return partial(self.random_world_frustum_dropout, config=config)
        return self._frustum_dropout(data_dict, config, 'global_frustum_dropout')

    def random_local_frustum_dropout(self, data_dict=None, config=None):
        """
        Please check the correctness of it before using.

        Applies ``augmentor_utils.local_frustum_dropout_<direction>`` for each
        entry of ``config['DIRECTION']`` with ``config['INTENSITY_RANGE']``.
        """
        if data_dict is None:
            return partial(self.random_local_frustum_dropout, config=config)
        return self._frustum_dropout(data_dict, config, 'local_frustum_dropout')

    def random_local_pyramid_aug(self, data_dict=None, config=None):
        """
        Refer to the paper:
            SE-SSD: Self-Ensembling Single-Stage Object Detector From Point Cloud

        Chains pyramid dropout, sparsify and swap from ``augmentor_utils``,
        driven by the ``DROP_PROB``, ``SPARSIFY_*`` and ``SWAP_*`` config keys.
        """
        if data_dict is None:
            return partial(self.random_local_pyramid_aug, config=config)

        gt_boxes, points = data_dict['gt_boxes'], data_dict['points']

        gt_boxes, points, pyramids = augmentor_utils.local_pyramid_dropout(
            gt_boxes, points, config['DROP_PROB']
        )
        gt_boxes, points, pyramids = augmentor_utils.local_pyramid_sparsify(
            gt_boxes, points, config['SPARSIFY_PROB'], config['SPARSIFY_MAX_NUM'], pyramids
        )
        gt_boxes, points = augmentor_utils.local_pyramid_swap(
            gt_boxes, points, config['SWAP_PROB'], config['SWAP_MAX_NUM'], pyramids
        )

        data_dict['gt_boxes'] = gt_boxes
        data_dict['points'] = points
        return data_dict

    def imgaug(self, data_dict=None, config=None):
        """Randomly flip and rotate every image in ``data_dict['camera_imgs']``.

        A left-right flip is drawn per image when ``config.RAND_FLIP`` is set,
        and the rotation angle is drawn uniformly from ``config.ROT_LIM``.
        The decisions are written into slots 2 (flip) and 3 (rotate) of the
        matching ``img_process_infos`` entry.
        """
        if data_dict is None:
            return partial(self.imgaug, config=config)

        imgs = data_dict["camera_imgs"]
        img_process_infos = data_dict['img_process_infos']
        new_imgs = []
        for img, img_process_info in zip(imgs, img_process_infos):
            flip = False
            if config.RAND_FLIP and np.random.choice([0, 1]):
                flip = True
            rotate = np.random.uniform(*config.ROT_LIM)
            # aug images
            if flip:
                img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
            img = img.rotate(rotate)
            img_process_info[2] = flip
            img_process_info[3] = rotate
            new_imgs.append(img)

        data_dict["camera_imgs"] = new_imgs
        return data_dict

    def forward(self, data_dict):
        """
        Args:
            data_dict:
                points: (N, 3 + C_in)
                gt_boxes: optional, (N, 7) [x, y, z, dx, dy, dz, heading]
                gt_names: optional, (N), string
                ...

        Returns:
            The same dictionary after every queued augmentor has run, with
            ``road_plane`` removed and ``gt_boxes_mask`` applied and removed.
        """
        for cur_augmentor in self.data_augmentor_queue:
            data_dict = cur_augmentor(data_dict=data_dict)

        # gt_boxes is documented as optional, so only normalise the heading
        # column (index 6) when boxes are actually present.
        if 'gt_boxes' in data_dict:
            data_dict['gt_boxes'][:, 6] = common_utils.limit_period(
                data_dict['gt_boxes'][:, 6], offset=0.5, period=2 * np.pi
            )
        # if 'calib' in data_dict:
        #     data_dict.pop('calib')
        if 'road_plane' in data_dict:
            data_dict.pop('road_plane')
        if 'gt_boxes_mask' in data_dict:
            gt_boxes_mask = data_dict['gt_boxes_mask']
            data_dict['gt_boxes'] = data_dict['gt_boxes'][gt_boxes_mask]
            data_dict['gt_names'] = data_dict['gt_names'][gt_boxes_mask]
            if 'gt_boxes2d' in data_dict:
                data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][gt_boxes_mask]

            data_dict.pop('gt_boxes_mask')
        return data_dict