211 lines
8.5 KiB
Python
211 lines
8.5 KiB
Python
import _init_path
|
|
import argparse
|
|
import datetime
|
|
import glob
|
|
import os
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from tensorboardX import SummaryWriter
|
|
|
|
from eval_utils import eval_utils
|
|
from pcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
|
|
from pcdet.datasets import build_dataloader
|
|
from pcdet.models import build_network
|
|
from pcdet.utils import common_utils
|
|
|
|
|
|
def parse_config():
    """Parse command-line arguments and load the experiment configuration.

    Loads ``--cfg_file`` into the module-level ``cfg`` via ``cfg_from_yaml_file``.
    It then derives ``cfg.TAG`` (the config file's stem) and
    ``cfg.EXP_GROUP_PATH`` (the config's directory path with its first
    component and the file name removed). Finally it seeds NumPy's global RNG
    and applies any ``--set KEY VALUE ...`` overrides via ``cfg_from_list``.

    Returns:
        tuple: ``(args, cfg)`` -- the parsed ``argparse.Namespace`` and the
        (mutated, module-level) ``cfg`` object.
    """
    parser = argparse.ArgumentParser(description='arg parser')
    # Required: TAG / EXP_GROUP_PATH are derived from this path below, so a
    # missing value would otherwise surface as an opaque TypeError.
    parser.add_argument('--cfg_file', type=str, required=True, help='specify the config for training')

    parser.add_argument('--batch_size', type=int, default=None, required=False, help='batch size for training')
    parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
    parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment')
    parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from')
    parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
    parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training')
    # Newer PyTorch launchers pass '--local-rank' (hyphen); accept both
    # spellings. argparse derives dest from the first long option -> 'local_rank'.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=None,
                        help='local rank for distributed training')
    parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER,
                        help='set extra config keys if needed')

    parser.add_argument('--max_waiting_mins', type=int, default=30, help='max waiting minutes')
    parser.add_argument('--start_epoch', type=int, default=0, help='')
    parser.add_argument('--eval_tag', type=str, default='default', help='eval tag for this experiment')
    parser.add_argument('--eval_all', action='store_true', default=False, help='whether to evaluate all checkpoints')
    parser.add_argument('--ckpt_dir', type=str, default=None, help='specify a ckpt directory to be evaluated if needed')
    parser.add_argument('--save_to_file', action='store_true', default=False, help='')
    parser.add_argument('--infer_time', action='store_true', default=False, help='calculate inference latency')

    args = parser.parse_args()

    cfg_from_yaml_file(args.cfg_file, cfg)
    cfg.TAG = Path(args.cfg_file).stem
    cfg.EXP_GROUP_PATH = '/'.join(args.cfg_file.split('/')[1:-1])  # remove 'cfgs' and 'xxxx.yaml'

    np.random.seed(1024)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs, cfg)

    return args, cfg
|
|
|
|
|
|
def eval_single_ckpt(model, test_loader, args, eval_output_dir, logger, epoch_id, dist_test=False):
    """Load the checkpoint named by ``args.ckpt`` into ``model`` and evaluate it once.

    Weights come from ``model.load_params_from_file``, which also receives
    ``args.pretrained_model`` and ``to_cpu=dist_test``. The model is then moved
    to CUDA and ``eval_utils.eval_one_epoch`` runs with results directed to
    ``eval_output_dir``. Nothing is returned.
    """
    load_kwargs = dict(
        filename=args.ckpt,
        logger=logger,
        to_cpu=dist_test,
        pre_trained_path=args.pretrained_model,
    )
    model.load_params_from_file(**load_kwargs)
    model.cuda()

    # single evaluation pass over the test loader
    eval_utils.eval_one_epoch(
        cfg, args, model, test_loader, epoch_id, logger,
        dist_test=dist_test,
        result_dir=eval_output_dir,
    )
|
|
|
|
|
|
def get_no_evaluated_ckpt(ckpt_dir, ckpt_record_file, args):
    """Return the oldest checkpoint in ``ckpt_dir`` that has not been evaluated yet.

    Candidates are the files matching ``*checkpoint_epoch_*.pth`` in
    ``ckpt_dir``, visited oldest-first by modification time. A candidate is
    skipped in any of these cases:

    * its epoch tag contains ``'optim'`` (presumably optimizer-state files);
    * its tag is not numeric;
    * its epoch is already listed in ``ckpt_record_file``;
    * its epoch is below ``args.start_epoch``.

    Args:
        ckpt_dir: directory holding the checkpoints (str or Path).
        ckpt_record_file: text file with one evaluated epoch number per line;
            blank lines are ignored.
        args: namespace providing ``start_epoch`` (int).

    Returns:
        tuple: ``(epoch_id, ckpt_path)``. ``epoch_id`` is the epoch tag string
        taken from the file name. Returns ``(-1, None)`` when no candidate
        remains.
    """
    ckpt_list = glob.glob(os.path.join(ckpt_dir, '*checkpoint_epoch_*.pth'))
    ckpt_list.sort(key=os.path.getmtime)

    # Close the record file deterministically; skip blank lines instead of
    # crashing in float(''). A set gives O(1) membership tests.
    with open(ckpt_record_file, 'r') as f:
        evaluated_epochs = {float(line) for line in f if line.strip()}

    for cur_ckpt in ckpt_list:
        # Match on the file name only, so a parent directory whose name
        # contains 'checkpoint_epoch_' cannot leak into the epoch tag.
        num_list = re.findall(r'checkpoint_epoch_(.*)\.pth', os.path.basename(cur_ckpt))
        if not num_list:
            continue

        epoch_id = num_list[-1]
        if 'optim' in epoch_id:
            continue
        try:
            epoch_num = float(epoch_id)
        except ValueError:
            # Non-numeric tag (e.g. a renamed/backup file): not an epoch checkpoint.
            continue
        if epoch_num not in evaluated_epochs and int(epoch_num) >= args.start_epoch:
            return epoch_id, cur_ckpt
    return -1, None
|
|
|
|
|
|
def repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir, dist_test=False):
    """Evaluate every checkpoint in ``ckpt_dir``, polling for new ones as they appear.

    Record keeping: each evaluated epoch id is appended to
    ``eval_list_<test split>.txt`` in ``eval_output_dir``.
    ``get_no_evaluated_ckpt`` reads that file, so recorded epochs are skipped.

    Polling: when no unevaluated checkpoint exists, the loop sleeps in
    ``wait_second`` steps. Once at least one checkpoint has been evaluated, it
    stops after more than ``args.max_waiting_mins`` minutes without a new one.
    Before the first evaluation it keeps waiting indefinitely.

    Logging: on ``cfg.LOCAL_RANK == 0`` the dict returned by
    ``eval_utils.eval_one_epoch`` is written as TensorBoard scalars.
    """
    # evaluated ckpt record; open in append mode so the file exists before the first read
    ckpt_record_file = eval_output_dir / ('eval_list_%s.txt' % cfg.DATA_CONFIG.DATA_SPLIT['test'])
    with open(ckpt_record_file, 'a'):
        pass

    # tensorboard log (rank 0 only; other ranks keep tb_log = None)
    tb_log = None
    if cfg.LOCAL_RANK == 0:
        tb_log = SummaryWriter(log_dir=str(eval_output_dir / ('tensorboard_%s' % cfg.DATA_CONFIG.DATA_SPLIT['test'])))

    wait_second = 30
    total_time = 0
    first_eval = True

    try:
        while True:
            # check whether there is checkpoint which is not evaluated
            cur_epoch_id, cur_ckpt = get_no_evaluated_ckpt(ckpt_dir, ckpt_record_file, args)
            if cur_epoch_id == -1 or int(float(cur_epoch_id)) < args.start_epoch:
                if cfg.LOCAL_RANK == 0:
                    print('Wait %s seconds for next check (progress: %.1f / %d minutes): %s \r'
                          % (wait_second, total_time * 1.0 / 60, args.max_waiting_mins, ckpt_dir), end='', flush=True)
                time.sleep(wait_second)
                # Count the interval actually slept (was a duplicated literal 30).
                total_time += wait_second
                if total_time > args.max_waiting_mins * 60 and not first_eval:
                    break
                continue

            total_time = 0
            first_eval = False

            model.load_params_from_file(filename=cur_ckpt, logger=logger, to_cpu=dist_test)
            model.cuda()

            # start evaluation
            cur_result_dir = eval_output_dir / ('epoch_%s' % cur_epoch_id) / cfg.DATA_CONFIG.DATA_SPLIT['test']
            tb_dict = eval_utils.eval_one_epoch(
                cfg, args, model, test_loader, cur_epoch_id, logger, dist_test=dist_test,
                result_dir=cur_result_dir
            )

            if tb_log is not None:
                for key, val in tb_dict.items():
                    tb_log.add_scalar(key, val, cur_epoch_id)

            # record this epoch which has been evaluated
            with open(ckpt_record_file, 'a') as f:
                print('%s' % cur_epoch_id, file=f)
            logger.info('Epoch %s has been evaluated' % cur_epoch_id)
    finally:
        # Flush and close the event file so the final epoch's scalars are not lost.
        if tb_log is not None:
            tb_log.close()
|
|
|
|
|
|
def main():
    """Set up (optionally distributed) evaluation and run it.

    Parses arguments and config, initialises distributed mode when
    ``--launcher`` is not ``'none'``, and derives the per-GPU batch size.
    It then creates the output/eval directories and the logger, and builds the
    test dataloader and network. Under ``torch.no_grad()`` it evaluates either
    the single ``--ckpt`` checkpoint or, with ``--eval_all``, every checkpoint
    in the checkpoint directory.

    Raises:
        ValueError: if an explicit ``--batch_size`` is not divisible by the
            number of GPUs.
    """
    args, cfg = parse_config()

    if args.infer_time:
        # Presumably forces synchronous CUDA kernel launches so latency
        # measurements reflect actual GPU work.
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    if args.launcher == 'none':
        dist_test = False
        total_gpus = 1
    else:
        if args.local_rank is None:
            args.local_rank = int(os.environ.get('LOCAL_RANK', '0'))

        total_gpus, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
            args.tcp_port, args.local_rank, backend='nccl'
        )
        dist_test = True

    if args.batch_size is None:
        args.batch_size = cfg.OPTIMIZATION.BATCH_SIZE_PER_GPU
    else:
        # Explicit check instead of assert: asserts are stripped under `python -O`.
        if args.batch_size % total_gpus != 0:
            raise ValueError('Batch size should match the number of gpus')
        args.batch_size = args.batch_size // total_gpus

    output_dir = cfg.ROOT_DIR / 'output' / cfg.EXP_GROUP_PATH / cfg.TAG / args.extra_tag
    output_dir.mkdir(parents=True, exist_ok=True)

    eval_output_dir = output_dir / 'eval'

    if not args.eval_all:
        # Use the last run of digits in the checkpoint path as the epoch id.
        num_list = re.findall(r'\d+', args.ckpt) if args.ckpt is not None else []
        epoch_id = num_list[-1] if num_list else 'no_number'
        eval_output_dir = eval_output_dir / ('epoch_%s' % epoch_id) / cfg.DATA_CONFIG.DATA_SPLIT['test']
    else:
        eval_output_dir = eval_output_dir / 'eval_all_default'

    if args.eval_tag is not None:
        eval_output_dir = eval_output_dir / args.eval_tag

    eval_output_dir.mkdir(parents=True, exist_ok=True)
    log_file = eval_output_dir / ('log_eval_%s.txt' % datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
    logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK)

    # log to file
    logger.info('**********************Start logging**********************')
    gpu_list = os.environ.get('CUDA_VISIBLE_DEVICES', 'ALL')
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    if dist_test:
        logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
    for key, val in vars(args).items():
        logger.info('{:16} {}'.format(key, val))
    log_config_to_file(cfg, logger=logger)

    ckpt_dir = args.ckpt_dir if args.ckpt_dir is not None else output_dir / 'ckpt'

    test_set, test_loader, sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        class_names=cfg.CLASS_NAMES,
        batch_size=args.batch_size,
        dist=dist_test, workers=args.workers, logger=logger, training=False
    )

    model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=test_set)
    with torch.no_grad():
        if args.eval_all:
            repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir, dist_test=dist_test)
        else:
            # epoch_id is always bound here: it is set in the `not args.eval_all` branch above.
            eval_single_ckpt(model, test_loader, args, eval_output_dir, logger, epoch_id, dist_test=dist_test)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
|