| import pickle |
| import time |
|
|
| import numpy as np |
| import torch |
| import tqdm |
|
|
| from pcdet.models import load_data_to_gpu |
| from pcdet.utils import common_utils |
|
|
|
|
def statistics_info(cfg, ret_dict, metric, disp_dict):
    """Accumulate one batch's recall counters into ``metric`` and refresh ``disp_dict``.

    Args:
        cfg: Config object; ``cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST`` lists
            the recall thresholds to accumulate.
        ret_dict: Per-batch statistics returned by the model. Keys
            ``'roi_<thresh>'``, ``'rcnn_<thresh>'`` and ``'gt'`` are read; a
            missing key counts as 0.
        metric: Running counters, updated in place. Must already contain
            ``'gt_num'`` plus ``'recall_roi_<thresh>'`` / ``'recall_rcnn_<thresh>'``
            for every threshold in the list.
        disp_dict: Progress-bar display values, updated in place with a
            ``'recall_<thresh>'`` entry for the FIRST threshold of the list
            (presumably the smallest one — the list order is not checked here).
    """
    thresh_list = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST
    for cur_thresh in thresh_list:
        metric['recall_roi_%s' % str(cur_thresh)] += ret_dict.get('roi_%s' % str(cur_thresh), 0)
        metric['recall_rcnn_%s' % str(cur_thresh)] += ret_dict.get('rcnn_%s' % str(cur_thresh), 0)
    metric['gt_num'] += ret_dict.get('gt', 0)

    # With no thresholds configured there is nothing to display; indexing [0]
    # would otherwise raise IndexError after the counters were already updated.
    if not thresh_list:
        return

    min_thresh = thresh_list[0]
    disp_dict['recall_%s' % str(min_thresh)] = \
        '(%d, %d) / %d' % (metric['recall_roi_%s' % str(min_thresh)],
                           metric['recall_rcnn_%s' % str(min_thresh)],
                           metric['gt_num'])
|
|
|
|
def eval_one_epoch(cfg, args, model, dataloader, epoch_id, logger, dist_test=False, result_dir=None):
    """Run inference over ``dataloader``, gather predictions and evaluate them.

    Args:
        cfg: Experiment config. Reads ``cfg.LOCAL_RANK``,
            ``cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST`` and
            ``cfg.MODEL.POST_PROCESSING.EVAL_METRIC``.
        args: Command-line namespace. Reads ``args.save_to_file`` and the
            optional ``args.infer_time`` flag (per-batch timing in ms).
        model: Detector, called as ``model(batch_dict)`` and returning
            ``(pred_dicts, ret_dict)``.
        dataloader: Evaluation loader whose ``dataset`` provides
            ``class_names``, ``generate_prediction_dicts`` and ``evaluation``.
        epoch_id: Used only in log messages.
        logger: Logger receiving progress and result messages.
        dist_test: If True, wrap the model in ``DistributedDataParallel`` and
            merge predictions/counters from all ranks before evaluating.
        result_dir: Path-like output directory. It receives ``result.pkl``,
            ``final_result/data`` and the distributed ``tmpdir``. Required
            despite the ``None`` default.

    Returns:
        dict: Recall values (``'recall/roi_<t>'``, ``'recall/rcnn_<t>'``) merged
        with the dataset's evaluation results on rank 0; ``{}`` on other ranks.

    Raises:
        ValueError: If ``result_dir`` is None.
    """
    if result_dir is None:
        # The None default is kept only for signature compatibility; every
        # output of this function is written under result_dir.
        raise ValueError('eval_one_epoch requires a result_dir')
    result_dir.mkdir(parents=True, exist_ok=True)

    final_output_dir = result_dir / 'final_result' / 'data'
    if args.save_to_file:
        final_output_dir.mkdir(parents=True, exist_ok=True)

    recall_thresh_list = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST
    metric = {'gt_num': 0}
    for cur_thresh in recall_thresh_list:
        metric['recall_roi_%s' % str(cur_thresh)] = 0
        metric['recall_rcnn_%s' % str(cur_thresh)] = 0

    dataset = dataloader.dataset
    class_names = dataset.class_names
    det_annos = []

    measure_infer_time = getattr(args, 'infer_time', False)
    if measure_infer_time:
        infer_time_meter = common_utils.AverageMeter()

    logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
    if dist_test:
        num_gpus = torch.cuda.device_count()
        local_rank = cfg.LOCAL_RANK % num_gpus
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False
        )
    model.eval()

    if cfg.LOCAL_RANK == 0:
        progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True)
    # Wall-clock start of the whole evaluation pass (used for sec_per_example).
    start_time = time.time()
    for batch_dict in dataloader:
        load_data_to_gpu(batch_dict)

        if measure_infer_time:
            # Use a separate per-batch timestamp. Reusing start_time here would
            # clobber the whole-run timer, so sec_per_example would cover only
            # the last batch.
            batch_start_time = time.time()

        with torch.no_grad():
            pred_dicts, batch_ret_dict = model(batch_dict)

        disp_dict = {}

        if measure_infer_time:
            # CUDA work is queued asynchronously. Synchronize so the measured
            # time includes the forward pass, not just the kernel launches.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            inference_time = time.time() - batch_start_time
            infer_time_meter.update(inference_time * 1000)  # milliseconds
            disp_dict['infer_time'] = f'{infer_time_meter.val:.2f}({infer_time_meter.avg:.2f})'

        statistics_info(cfg, batch_ret_dict, metric, disp_dict)
        annos = dataset.generate_prediction_dicts(
            batch_dict, pred_dicts, class_names,
            output_path=final_output_dir if args.save_to_file else None
        )
        det_annos += annos
        if cfg.LOCAL_RANK == 0:
            progress_bar.set_postfix(disp_dict)
            progress_bar.update()

    if cfg.LOCAL_RANK == 0:
        progress_bar.close()

    if dist_test:
        _, world_size = common_utils.get_dist_info()
        det_annos = common_utils.merge_results_dist(det_annos, len(dataset), tmpdir=result_dir / 'tmpdir')
        metric = common_utils.merge_results_dist([metric], world_size, tmpdir=result_dir / 'tmpdir')

    logger.info('*************** Performance of EPOCH %s *****************' % epoch_id)
    # max(..., 1) avoids ZeroDivisionError on an empty dataset.
    sec_per_example = (time.time() - start_time) / max(len(dataset), 1)
    logger.info('Generate label finished(sec_per_example: %.4f second).' % sec_per_example)

    if cfg.LOCAL_RANK != 0:
        return {}

    ret_dict = {}
    if dist_test:
        # After merging, metric is a list with one counter dict per rank;
        # sum the counters key-wise across the first world_size entries.
        metric = {
            key: sum(rank_metric[key] for rank_metric in metric[:world_size])
            for key in metric[0]
        }

    gt_num_cnt = metric['gt_num']
    for cur_thresh in recall_thresh_list:
        cur_roi_recall = metric['recall_roi_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
        cur_rcnn_recall = metric['recall_rcnn_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
        logger.info('recall_roi_%s: %f' % (cur_thresh, cur_roi_recall))
        logger.info('recall_rcnn_%s: %f' % (cur_thresh, cur_rcnn_recall))
        ret_dict['recall/roi_%s' % str(cur_thresh)] = cur_roi_recall
        ret_dict['recall/rcnn_%s' % str(cur_thresh)] = cur_rcnn_recall

    total_pred_objects = sum(len(anno['name']) for anno in det_annos)
    logger.info('Average predicted number of objects(%d samples): %.3f'
                % (len(det_annos), total_pred_objects / max(1, len(det_annos))))

    with open(result_dir / 'result.pkl', 'wb') as f:
        pickle.dump(det_annos, f)

    # Former debug prints routed through the logger like the rest of this function.
    logger.info('length of det_annos: %d' % len(det_annos))
    logger.debug('dataset: %s', dataset)
    result_str, result_dict = dataset.evaluation(
        det_annos, class_names,
        eval_metric=cfg.MODEL.POST_PROCESSING.EVAL_METRIC,
        output_path=final_output_dir
    )
    logger.info('result_dict: %s' % (result_dict.keys(),))
    logger.info(result_str)
    ret_dict.update(result_dict)
    logger.info('Result is saved to %s' % result_dir)
    logger.info('****************Evaluation done.*****************')
    return ret_dict
|
|
|
|
if __name__ == '__main__':
    # Running this file directly performs no work; it is meant to be imported.
    ...
|
|