diff --git a/pcdet/models/__init__.py b/pcdet/models/__init__.py new file mode 100644 index 0000000..7049bb4 --- /dev/null +++ b/pcdet/models/__init__.py @@ -0,0 +1,54 @@ +from collections import namedtuple + +import numpy as np +import torch + +from .detectors import build_detector + +try: + import kornia +except: + pass + # print('Warning: kornia is not installed. This package is only required by CaDDN') + + + +def build_network(model_cfg, num_class, dataset): + model = build_detector( + model_cfg=model_cfg, num_class=num_class, dataset=dataset + ) + return model + + +def load_data_to_gpu(batch_dict): + for key, val in batch_dict.items(): + if key == 'camera_imgs': + batch_dict[key] = val.cuda() + elif not isinstance(val, np.ndarray): + continue + elif key in ['frame_id', 'metadata', 'calib', 'image_paths','ori_shape','img_process_infos']: + continue + elif key in ['images']: + batch_dict[key] = kornia.image_to_tensor(val).float().cuda().contiguous() + elif key in ['image_shape']: + batch_dict[key] = torch.from_numpy(val).int().cuda() + else: + batch_dict[key] = torch.from_numpy(val).float().cuda() + + +def model_fn_decorator(): + ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict']) + + def model_func(model, batch_dict): + load_data_to_gpu(batch_dict) + ret_dict, tb_dict, disp_dict = model(batch_dict) + + loss = ret_dict['loss'].mean() + if hasattr(model, 'update_global_step'): + model.update_global_step() + else: + model.module.update_global_step() + + return ModelReturn(loss, tb_dict, disp_dict) + + return model_func